Reformat Python code with psf/black

cclauss 2019-11-20 08:58:47 +01:00
parent f915352138
commit aa13f95ca5
746 changed files with 103596 additions and 76051 deletions
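
Black applies one fixed style, and the hunks below show its three recurring effects: string literals normalized to double quotes, lines wrapped at the 88-character default, and calls that no longer fit on one line exploded to one argument per line with a trailing comma. As background (not part of the commit itself), a tree-wide reformat like this is normally produced by running `black .` from the repository root. A minimal sketch of driving the same formatter from Python, assuming the `black` package is installed; `format_str` and `FileMode` are Black's documented programmatic entry points:

    import black

    # One of the lines this commit touches, in its pre-Black form.
    source = "STATIC_DIR = os.path.join(ROOT_DIR, 'static/')\n"

    # format_str() returns the reformatted text; FileMode() selects Black's
    # defaults (88-character lines, double-quoted strings).
    print(black.format_str(source, mode=black.FileMode()))
    # STATIC_DIR = os.path.join(ROOT_DIR, "static/")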


@@ -7,28 +7,33 @@ from util.config.provider import get_config_provider
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 CONF_DIR = os.getenv("QUAYCONF", os.path.join(ROOT_DIR, "conf/"))
-STATIC_DIR = os.path.join(ROOT_DIR, 'static/')
-STATIC_LDN_DIR = os.path.join(STATIC_DIR, 'ldn/')
-STATIC_FONTS_DIR = os.path.join(STATIC_DIR, 'fonts/')
-STATIC_WEBFONTS_DIR = os.path.join(STATIC_DIR, 'webfonts/')
-TEMPLATE_DIR = os.path.join(ROOT_DIR, 'templates/')
+STATIC_DIR = os.path.join(ROOT_DIR, "static/")
+STATIC_LDN_DIR = os.path.join(STATIC_DIR, "ldn/")
+STATIC_FONTS_DIR = os.path.join(STATIC_DIR, "fonts/")
+STATIC_WEBFONTS_DIR = os.path.join(STATIC_DIR, "webfonts/")
+TEMPLATE_DIR = os.path.join(ROOT_DIR, "templates/")
 
-IS_TESTING = 'TEST' in os.environ
-IS_BUILDING = 'BUILDING' in os.environ
-IS_KUBERNETES = 'KUBERNETES_SERVICE_HOST' in os.environ
-OVERRIDE_CONFIG_DIRECTORY = os.path.join(CONF_DIR, 'stack/')
+IS_TESTING = "TEST" in os.environ
+IS_BUILDING = "BUILDING" in os.environ
+IS_KUBERNETES = "KUBERNETES_SERVICE_HOST" in os.environ
+OVERRIDE_CONFIG_DIRECTORY = os.path.join(CONF_DIR, "stack/")
 
-config_provider = get_config_provider(OVERRIDE_CONFIG_DIRECTORY, 'config.yaml', 'config.py',
-                                      testing=IS_TESTING, kubernetes=IS_KUBERNETES)
+config_provider = get_config_provider(
+    OVERRIDE_CONFIG_DIRECTORY,
+    "config.yaml",
+    "config.py",
+    testing=IS_TESTING,
+    kubernetes=IS_KUBERNETES,
+)
 
 
 def _get_version_number_changelog():
     try:
-        with open(os.path.join(ROOT_DIR, 'CHANGELOG.md')) as f:
-            return re.search(r'(v[0-9]+\.[0-9]+\.[0-9]+)', f.readline()).group(0)
+        with open(os.path.join(ROOT_DIR, "CHANGELOG.md")) as f:
+            return re.search(r"(v[0-9]+\.[0-9]+\.[0-9]+)", f.readline()).group(0)
     except IOError:
-        return ''
+        return ""
 
 
 def _get_git_sha():


@@ -1,22 +1,30 @@
 from enum import Enum, unique
 
 from data.migrationutil import DefinedDataMigration, MigrationPhase
 
 @unique
 class ERTMigrationFlags(Enum):
     """ Flags for the encrypted robot token migration. """
-    READ_OLD_FIELDS = 'read-old'
-    WRITE_OLD_FIELDS = 'write-old'
+
+    READ_OLD_FIELDS = "read-old"
+    WRITE_OLD_FIELDS = "write-old"
 
 ActiveDataMigration = DefinedDataMigration(
-    'encrypted_robot_tokens',
-    'ENCRYPTED_ROBOT_TOKEN_MIGRATION_PHASE',
+    "encrypted_robot_tokens",
+    "ENCRYPTED_ROBOT_TOKEN_MIGRATION_PHASE",
     [
-        MigrationPhase('add-new-fields', 'c13c8052f7a6', [ERTMigrationFlags.READ_OLD_FIELDS,
-                                                          ERTMigrationFlags.WRITE_OLD_FIELDS]),
-        MigrationPhase('backfill-then-read-only-new',
-                       '703298a825c2', [ERTMigrationFlags.WRITE_OLD_FIELDS]),
-        MigrationPhase('stop-writing-both', '703298a825c2', []),
-        MigrationPhase('remove-old-fields', 'c059b952ed76', []),
-    ]
+        MigrationPhase(
+            "add-new-fields",
+            "c13c8052f7a6",
+            [ERTMigrationFlags.READ_OLD_FIELDS, ERTMigrationFlags.WRITE_OLD_FIELDS],
+        ),
+        MigrationPhase(
+            "backfill-then-read-only-new",
+            "703298a825c2",
+            [ERTMigrationFlags.WRITE_OLD_FIELDS],
+        ),
+        MigrationPhase("stop-writing-both", "703298a825c2", []),
+        MigrationPhase("remove-old-fields", "c059b952ed76", []),
+    ],
 )

app.py (242 changed lines)

@@ -16,8 +16,14 @@ from werkzeug.exceptions import HTTPException
 import features
 
-from _init import (config_provider, CONF_DIR, IS_KUBERNETES, IS_TESTING, OVERRIDE_CONFIG_DIRECTORY,
-                   IS_BUILDING)
+from _init import (
+    config_provider,
+    CONF_DIR,
+    IS_KUBERNETES,
+    IS_TESTING,
+    OVERRIDE_CONFIG_DIRECTORY,
+    IS_BUILDING,
+)
 from auth.auth_context import get_authenticated_user
 from avatars.avatars import Avatar
@@ -35,7 +41,11 @@ from data.userevent import UserEventsBuilderModule
 from data.userfiles import Userfiles
 from data.users import UserAuthentication
 from data.registry_model import registry_model
-from path_converters import RegexConverter, RepositoryPathConverter, APIRepositoryPathConverter
+from path_converters import (
+    RegexConverter,
+    RepositoryPathConverter,
+    APIRepositoryPathConverter,
+)
 from oauth.services.github import GithubOAuthService
 from oauth.services.gitlab import GitLabOAuthService
 from oauth.loginmanager import OAuthLoginManager
@@ -62,13 +72,13 @@ from util.security.instancekeys import InstanceKeys
 from util.security.signing import Signer
 
-OVERRIDE_CONFIG_YAML_FILENAME = os.path.join(CONF_DIR, 'stack/config.yaml')
-OVERRIDE_CONFIG_PY_FILENAME = os.path.join(CONF_DIR, 'stack/config.py')
-OVERRIDE_CONFIG_KEY = 'QUAY_OVERRIDE_CONFIG'
+OVERRIDE_CONFIG_YAML_FILENAME = os.path.join(CONF_DIR, "stack/config.yaml")
+OVERRIDE_CONFIG_PY_FILENAME = os.path.join(CONF_DIR, "stack/config.py")
+OVERRIDE_CONFIG_KEY = "QUAY_OVERRIDE_CONFIG"
 
-DOCKER_V2_SIGNINGKEY_FILENAME = 'docker_v2.pem'
-INIT_SCRIPTS_LOCATION = '/conf/init/'
+DOCKER_V2_SIGNINGKEY_FILENAME = "docker_v2.pem"
+INIT_SCRIPTS_LOCATION = "/conf/init/"
 
 app = Flask(__name__)
 logger = logging.getLogger(__name__)
@@ -80,11 +90,13 @@ is_building = IS_BUILDING
 if is_testing:
     from test.testconfig import TestConfig
-    logger.debug('Loading test config.')
+
+    logger.debug("Loading test config.")
     app.config.from_object(TestConfig())
 else:
     from config import DefaultConfig
-    logger.debug('Loading default config.')
+
+    logger.debug("Loading default config.")
     app.config.from_object(DefaultConfig())
     app.teardown_request(database.close_db_filter)
@@ -92,49 +104,60 @@ else:
 config_provider.update_app_config(app.config)
 
 # Update any configuration found in the override environment variable.
-environ_config = json.loads(os.environ.get(OVERRIDE_CONFIG_KEY, '{}'))
+environ_config = json.loads(os.environ.get(OVERRIDE_CONFIG_KEY, "{}"))
 app.config.update(environ_config)
 
 # Fix remote address handling for Flask.
-if app.config.get('PROXY_COUNT', 1):
-    app.wsgi_app = ProxyFix(app.wsgi_app, num_proxies=app.config.get('PROXY_COUNT', 1))
+if app.config.get("PROXY_COUNT", 1):
+    app.wsgi_app = ProxyFix(app.wsgi_app, num_proxies=app.config.get("PROXY_COUNT", 1))
 
 # Ensure the V3 upgrade key is specified correctly. If not, simply fail.
 # TODO: Remove for V3.1.
-if not is_testing and not is_building and app.config.get('SETUP_COMPLETE', False):
-    v3_upgrade_mode = app.config.get('V3_UPGRADE_MODE')
+if not is_testing and not is_building and app.config.get("SETUP_COMPLETE", False):
+    v3_upgrade_mode = app.config.get("V3_UPGRADE_MODE")
     if v3_upgrade_mode is None:
-        raise Exception('Configuration flag `V3_UPGRADE_MODE` must be set. Please check the upgrade docs')
+        raise Exception(
+            "Configuration flag `V3_UPGRADE_MODE` must be set. Please check the upgrade docs"
+        )
 
-    if (v3_upgrade_mode != 'background'
-        and v3_upgrade_mode != 'complete'
-        and v3_upgrade_mode != 'production-transition'
-        and v3_upgrade_mode != 'post-oci-rollout'
-        and v3_upgrade_mode != 'post-oci-roll-back-compat'):
-        raise Exception('Invalid value for config `V3_UPGRADE_MODE`. Please check the upgrade docs')
+    if (
+        v3_upgrade_mode != "background"
+        and v3_upgrade_mode != "complete"
+        and v3_upgrade_mode != "production-transition"
+        and v3_upgrade_mode != "post-oci-rollout"
+        and v3_upgrade_mode != "post-oci-roll-back-compat"
+    ):
+        raise Exception(
+            "Invalid value for config `V3_UPGRADE_MODE`. Please check the upgrade docs"
+        )
 
 # Split the registry model based on config.
 # TODO: Remove once we are fully on the OCI data model.
-registry_model.setup_split(app.config.get('OCI_NAMESPACE_PROPORTION') or 0,
-                           app.config.get('OCI_NAMESPACE_WHITELIST') or set(),
-                           app.config.get('V22_NAMESPACE_WHITELIST') or set(),
-                           app.config.get('V3_UPGRADE_MODE'))
+registry_model.setup_split(
+    app.config.get("OCI_NAMESPACE_PROPORTION") or 0,
+    app.config.get("OCI_NAMESPACE_WHITELIST") or set(),
+    app.config.get("V22_NAMESPACE_WHITELIST") or set(),
+    app.config.get("V3_UPGRADE_MODE"),
+)
 
 # Allow user to define a custom storage preference for the local instance.
-_distributed_storage_preference = os.environ.get('QUAY_DISTRIBUTED_STORAGE_PREFERENCE', '').split()
+_distributed_storage_preference = os.environ.get(
+    "QUAY_DISTRIBUTED_STORAGE_PREFERENCE", ""
+).split()
 if _distributed_storage_preference:
-    app.config['DISTRIBUTED_STORAGE_PREFERENCE'] = _distributed_storage_preference
+    app.config["DISTRIBUTED_STORAGE_PREFERENCE"] = _distributed_storage_preference
 
 # Generate a secret key if none was specified.
-if app.config['SECRET_KEY'] is None:
-    logger.debug('Generating in-memory secret key')
-    app.config['SECRET_KEY'] = generate_secret_key()
+if app.config["SECRET_KEY"] is None:
+    logger.debug("Generating in-memory secret key")
+    app.config["SECRET_KEY"] = generate_secret_key()
 
 # If the "preferred" scheme is https, then http is not allowed. Therefore, ensure we have a secure
 # session cookie.
-if (app.config['PREFERRED_URL_SCHEME'] == 'https' and
-    not app.config.get('FORCE_NONSECURE_SESSION_COOKIE', False)):
-    app.config['SESSION_COOKIE_SECURE'] = True
+if app.config["PREFERRED_URL_SCHEME"] == "https" and not app.config.get(
+    "FORCE_NONSECURE_SESSION_COOKIE", False
+):
+    app.config["SESSION_COOKIE_SECURE"] = True
 
 # Load features from config.
 features.import_features(app.config)
@@ -145,7 +168,7 @@ logger.debug("Loaded config", extra={"config": app.config})
 
 class RequestWithId(Request):
-    request_gen = staticmethod(urn_generator(['request']))
+    request_gen = staticmethod(urn_generator(["request"]))
 
     def __init__(self, *args, **kwargs):
         super(RequestWithId, self).__init__(*args, **kwargs)
@@ -154,20 +177,31 @@ class RequestWithId(Request):
 
 @app.before_request
 def _request_start():
-    if os.getenv('PYDEV_DEBUG', None):
+    if os.getenv("PYDEV_DEBUG", None):
         import pydevd
-        host, port = os.getenv('PYDEV_DEBUG').split(':')
-        pydevd.settrace(host, port=int(port), stdoutToServer=True, stderrToServer=True, suspend=False)
 
-    logger.debug('Starting request: %s (%s)', request.request_id, request.path,
-                 extra={"request_id": request.request_id})
+        host, port = os.getenv("PYDEV_DEBUG").split(":")
+        pydevd.settrace(
+            host,
+            port=int(port),
+            stdoutToServer=True,
+            stderrToServer=True,
+            suspend=False,
+        )
+
+    logger.debug(
+        "Starting request: %s (%s)",
+        request.request_id,
+        request.path,
+        extra={"request_id": request.request_id},
+    )
 
-DEFAULT_FILTER = lambda x: '[FILTERED]'
+
+DEFAULT_FILTER = lambda x: "[FILTERED]"
 FILTERED_VALUES = [
-    {'key': ['password'], 'fn': DEFAULT_FILTER},
-    {'key': ['user', 'password'], 'fn': DEFAULT_FILTER},
-    {'key': ['blob'], 'fn': lambda x: x[0:8]}
+    {"key": ["password"], "fn": DEFAULT_FILTER},
+    {"key": ["user", "password"], "fn": DEFAULT_FILTER},
+    {"key": ["blob"], "fn": lambda x: x[0:8]},
 ]
@@ -181,7 +215,7 @@ def _request_end(resp):
     values = request.values.to_dict()
 
     if jsonbody and not isinstance(jsonbody, dict):
-        jsonbody = {'_parsererror': jsonbody}
+        jsonbody = {"_parsererror": jsonbody}
 
     if isinstance(values, dict):
         filter_logs(values, FILTERED_VALUES)
@@ -201,23 +235,24 @@ def _request_end(resp):
     if request.user_agent is not None:
         extra["user-agent"] = request.user_agent.string
 
-    logger.debug('Ending request: %s (%s)', request.request_id, request.path, extra=extra)
+    logger.debug(
+        "Ending request: %s (%s)", request.request_id, request.path, extra=extra
+    )
     return resp
 
 root_logger = logging.getLogger()
 
 app.request_class = RequestWithId
 
 # Register custom converters.
-app.url_map.converters['regex'] = RegexConverter
-app.url_map.converters['repopath'] = RepositoryPathConverter
-app.url_map.converters['apirepopath'] = APIRepositoryPathConverter
+app.url_map.converters["regex"] = RegexConverter
+app.url_map.converters["repopath"] = RepositoryPathConverter
+app.url_map.converters["apirepopath"] = APIRepositoryPathConverter
 
 Principal(app, use_sessions=False)
 
-tf = app.config['DB_TRANSACTION_FACTORY']
+tf = app.config["DB_TRANSACTION_FACTORY"]
 model_cache = get_model_cache(app.config)
 avatar = Avatar(app)
@@ -225,10 +260,14 @@ login_manager = LoginManager(app)
 mail = Mail(app)
 prometheus = PrometheusPlugin(app)
 metric_queue = MetricQueue(prometheus)
-chunk_cleanup_queue = WorkQueue(app.config['CHUNK_CLEANUP_QUEUE_NAME'], tf, metric_queue=metric_queue)
+chunk_cleanup_queue = WorkQueue(
+    app.config["CHUNK_CLEANUP_QUEUE_NAME"], tf, metric_queue=metric_queue
+)
 instance_keys = InstanceKeys(app)
 ip_resolver = IPResolver(app)
-storage = Storage(app, metric_queue, chunk_cleanup_queue, instance_keys, config_provider, ip_resolver)
+storage = Storage(
+    app, metric_queue, chunk_cleanup_queue, instance_keys, config_provider, ip_resolver
+)
 userfiles = Userfiles(app, storage)
 log_archive = LogArchive(app, storage)
 analytics = Analytics(app)
@@ -246,42 +285,82 @@ build_canceller = BuildCanceller(app)
 start_cloudwatch_sender(metric_queue, app)
 
-github_trigger = GithubOAuthService(app.config, 'GITHUB_TRIGGER_CONFIG')
-gitlab_trigger = GitLabOAuthService(app.config, 'GITLAB_TRIGGER_CONFIG')
+github_trigger = GithubOAuthService(app.config, "GITHUB_TRIGGER_CONFIG")
+gitlab_trigger = GitLabOAuthService(app.config, "GITLAB_TRIGGER_CONFIG")
 
 oauth_login = OAuthLoginManager(app.config)
 oauth_apps = [github_trigger, gitlab_trigger]
 
-image_replication_queue = WorkQueue(app.config['REPLICATION_QUEUE_NAME'], tf,
-                                    has_namespace=False, metric_queue=metric_queue)
-dockerfile_build_queue = WorkQueue(app.config['DOCKERFILE_BUILD_QUEUE_NAME'], tf,
-                                   metric_queue=metric_queue,
-                                   reporter=BuildMetricQueueReporter(metric_queue),
-                                   has_namespace=True)
-notification_queue = WorkQueue(app.config['NOTIFICATION_QUEUE_NAME'], tf, has_namespace=True,
-                               metric_queue=metric_queue)
-secscan_notification_queue = WorkQueue(app.config['SECSCAN_NOTIFICATION_QUEUE_NAME'], tf,
-                                       has_namespace=False,
-                                       metric_queue=metric_queue)
-export_action_logs_queue = WorkQueue(app.config['EXPORT_ACTION_LOGS_QUEUE_NAME'], tf,
-                                     has_namespace=True,
-                                     metric_queue=metric_queue)
+image_replication_queue = WorkQueue(
+    app.config["REPLICATION_QUEUE_NAME"],
+    tf,
+    has_namespace=False,
+    metric_queue=metric_queue,
+)
+dockerfile_build_queue = WorkQueue(
+    app.config["DOCKERFILE_BUILD_QUEUE_NAME"],
+    tf,
+    metric_queue=metric_queue,
+    reporter=BuildMetricQueueReporter(metric_queue),
+    has_namespace=True,
+)
+notification_queue = WorkQueue(
+    app.config["NOTIFICATION_QUEUE_NAME"],
+    tf,
+    has_namespace=True,
+    metric_queue=metric_queue,
+)
+secscan_notification_queue = WorkQueue(
+    app.config["SECSCAN_NOTIFICATION_QUEUE_NAME"],
+    tf,
+    has_namespace=False,
+    metric_queue=metric_queue,
+)
+export_action_logs_queue = WorkQueue(
+    app.config["EXPORT_ACTION_LOGS_QUEUE_NAME"],
+    tf,
+    has_namespace=True,
+    metric_queue=metric_queue,
+)
 
 # Note: We set `has_namespace` to `False` here, as we explicitly want this queue to not be emptied
 # when a namespace is marked for deletion.
-namespace_gc_queue = WorkQueue(app.config['NAMESPACE_GC_QUEUE_NAME'], tf, has_namespace=False,
-                               metric_queue=metric_queue)
+namespace_gc_queue = WorkQueue(
+    app.config["NAMESPACE_GC_QUEUE_NAME"],
+    tf,
+    has_namespace=False,
+    metric_queue=metric_queue,
+)
 
-all_queues = [image_replication_queue, dockerfile_build_queue, notification_queue,
-              secscan_notification_queue, chunk_cleanup_queue, namespace_gc_queue]
+all_queues = [
+    image_replication_queue,
+    dockerfile_build_queue,
+    notification_queue,
+    secscan_notification_queue,
+    chunk_cleanup_queue,
+    namespace_gc_queue,
+]
 
-url_scheme_and_hostname = URLSchemeAndHostname(app.config['PREFERRED_URL_SCHEME'], app.config['SERVER_HOSTNAME'])
-secscan_api = SecurityScannerAPI(app.config, storage, app.config['SERVER_HOSTNAME'], app.config['HTTPCLIENT'],
-                                 uri_creator=get_blob_download_uri_getter(app.test_request_context('/'), url_scheme_and_hostname),
-                                 instance_keys=instance_keys)
+url_scheme_and_hostname = URLSchemeAndHostname(
+    app.config["PREFERRED_URL_SCHEME"], app.config["SERVER_HOSTNAME"]
+)
+secscan_api = SecurityScannerAPI(
+    app.config,
+    storage,
+    app.config["SERVER_HOSTNAME"],
+    app.config["HTTPCLIENT"],
+    uri_creator=get_blob_download_uri_getter(
+        app.test_request_context("/"), url_scheme_and_hostname
+    ),
+    instance_keys=instance_keys,
+)
 
-repo_mirror_api = RepoMirrorAPI(app.config, app.config['SERVER_HOSTNAME'], app.config['HTTPCLIENT'],
-                                instance_keys=instance_keys)
+repo_mirror_api = RepoMirrorAPI(
+    app.config,
+    app.config["SERVER_HOSTNAME"],
+    app.config["HTTPCLIENT"],
+    instance_keys=instance_keys,
+)
 
 tuf_metadata_api = TUFMetadataAPI(app, app.config)
@@ -293,8 +372,12 @@ else:
     docker_v2_signing_key = RSAKey(key=RSA.generate(2048))
 
 # Configure the database.
-if app.config.get('DATABASE_SECRET_KEY') is None and app.config.get('SETUP_COMPLETE', False):
-    raise Exception('Missing DATABASE_SECRET_KEY in config; did you perhaps forget to add it?')
+if app.config.get("DATABASE_SECRET_KEY") is None and app.config.get(
+    "SETUP_COMPLETE", False
+):
+    raise Exception(
+        "Missing DATABASE_SECRET_KEY in config; did you perhaps forget to add it?"
+    )
 
 database.configure(app.config)
@@ -306,9 +389,10 @@ model.config.register_repo_cleanup_callback(tuf_metadata_api.delete_metadata)
 
 @login_manager.user_loader
 def load_user(user_uuid):
-    logger.debug('User loader loading deferred user with uuid: %s', user_uuid)
+    logger.debug("User loader loading deferred user with uuid: %s", user_uuid)
     return LoginWrappedDBUser(user_uuid)
 
 logs_model.configure(app.config)
 get_app_url = partial(get_app_url, app.config)


@@ -1,5 +1,6 @@
 # NOTE: Must be before we import or call anything that may be synchronous.
 from gevent import monkey
+
 monkey.patch_all()
 
 import os
@@ -17,6 +18,6 @@ import registry
 import secscan
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     logging.config.fileConfig(logfile_path(debug=True), disable_existing_loggers=False)
-    application.run(port=5000, debug=True, threaded=True, host='0.0.0.0')
+    application.run(port=5000, debug=True, threaded=True, host="0.0.0.0")


@@ -1,19 +1,23 @@
 from flask import _request_ctx_stack
 
+
 def get_authenticated_context():
     """ Returns the auth context for the current request context, if any. """
-    return getattr(_request_ctx_stack.top, 'authenticated_context', None)
+    return getattr(_request_ctx_stack.top, "authenticated_context", None)
 
+
 def get_authenticated_user():
     """ Returns the authenticated user, if any, or None if none. """
     context = get_authenticated_context()
     return context.authed_user if context else None
 
+
 def get_validated_oauth_token():
     """ Returns the authenticated and validated OAuth access token, if any, or None if none. """
     context = get_authenticated_context()
     return context.authed_oauth_token if context else None
 
+
 def set_authenticated_context(auth_context):
     """ Sets the auth context for the current request context to that given. """
     ctx = _request_ctx_stack.top


@@ -16,6 +16,7 @@ from auth.scopes import scopes_from_scope_string
 logger = logging.getLogger(__name__)
 
+
 @add_metaclass(ABCMeta)
 class AuthContext(object):
     """
@@ -105,8 +106,16 @@ class ValidatedAuthContext(AuthContext):
     """ ValidatedAuthContext represents the loaded, authenticated and validated auth information
         for the current request context.
     """
-    def __init__(self, user=None, token=None, oauthtoken=None, robot=None, appspecifictoken=None,
-                 signed_data=None):
+
+    def __init__(
+        self,
+        user=None,
+        token=None,
+        oauthtoken=None,
+        robot=None,
+        appspecifictoken=None,
+        signed_data=None,
+    ):
         # Note: These field names *MUST* match the string values of the kinds defined in
         # ContextEntityKind.
         self.user = user
@@ -138,7 +147,9 @@ class ValidatedAuthContext(AuthContext):
         """
         authed_user = self._authed_user()
         if authed_user is not None and not authed_user.enabled:
-            logger.warning('Attempt to reference a disabled user/robot: %s', authed_user.username)
+            logger.warning(
+                "Attempt to reference a disabled user/robot: %s", authed_user.username
+            )
             return None
 
         return authed_user
@@ -155,7 +166,7 @@ class ValidatedAuthContext(AuthContext):
             return self.appspecifictoken.user
 
         if self.signed_data:
-            return model.user.get_user(self.signed_data['user_context'])
+            return model.user.get_user(self.signed_data["user_context"])
 
         return self.user if self.user else self.robot
@@ -174,17 +185,19 @@ class ValidatedAuthContext(AuthContext):
         """ Returns the identity for the auth context. """
         if self.oauthtoken:
             scope_set = scopes_from_scope_string(self.oauthtoken.scope)
-            return QuayDeferredPermissionUser.for_user(self.oauthtoken.authorized_user, scope_set)
+            return QuayDeferredPermissionUser.for_user(
+                self.oauthtoken.authorized_user, scope_set
+            )
 
         if self.authed_user:
             return QuayDeferredPermissionUser.for_user(self.authed_user)
 
         if self.token:
-            return Identity(self.token.get_code(), 'token')
+            return Identity(self.token.get_code(), "token")
 
         if self.signed_data:
-            identity = Identity(None, 'signed_grant')
-            identity.provides.update(self.signed_data['grants'])
+            identity = Identity(None, "signed_grant")
+            identity.provides.update(self.signed_data["grants"])
             return identity
 
         return None
@@ -226,64 +239,61 @@ class ValidatedAuthContext(AuthContext):
     @property
     def unique_key(self):
         signed_dict = self.to_signed_dict()
-        return '%s-%s' % (signed_dict['entity_kind'], signed_dict.get('entity_reference', '(anon)'))
+        return "%s-%s" % (
+            signed_dict["entity_kind"],
+            signed_dict.get("entity_reference", "(anon)"),
+        )
 
     def to_signed_dict(self):
         """ Serializes the auth context into a dictionary suitable for inclusion in a JWT or other
             form of signed serialization.
         """
-        dict_data = {
-            'version': 2,
-            'entity_kind': self.entity_kind.value,
-        }
+        dict_data = {"version": 2, "entity_kind": self.entity_kind.value}
 
         if self.entity_kind != ContextEntityKind.anonymous:
             handler = CONTEXT_ENTITY_HANDLERS[self.entity_kind]()
-            dict_data.update({
-                'entity_reference': handler.get_serialized_entity_reference(self.entity_reference),
-            })
+            dict_data.update(
+                {
+                    "entity_reference": handler.get_serialized_entity_reference(
+                        self.entity_reference
+                    )
+                }
+            )
 
         # Add legacy information.
         # TODO: Remove this all once the new code is fully deployed.
         if self.token:
-            dict_data.update({
-                'kind': 'token',
-                'token': self.token.code,
-            })
+            dict_data.update({"kind": "token", "token": self.token.code})
 
         if self.oauthtoken:
-            dict_data.update({
-                'kind': 'oauth',
-                'oauth': self.oauthtoken.uuid,
-                'user': self.authed_user.username,
-            })
+            dict_data.update(
+                {
+                    "kind": "oauth",
+                    "oauth": self.oauthtoken.uuid,
+                    "user": self.authed_user.username,
+                }
+            )
 
         if self.user or self.robot:
-            dict_data.update({
-                'kind': 'user',
-                'user': self.authed_user.username,
-            })
+            dict_data.update({"kind": "user", "user": self.authed_user.username})
 
         if self.appspecifictoken:
-            dict_data.update({
-                'kind': 'user',
-                'user': self.authed_user.username,
-            })
+            dict_data.update({"kind": "user", "user": self.authed_user.username})
 
         if self.is_anonymous:
-            dict_data.update({
-                'kind': 'anonymous',
-            })
+            dict_data.update({"kind": "anonymous"})
 
         # End of legacy information.
         return dict_data
 
+
 class SignedAuthContext(AuthContext):
     """ SignedAuthContext represents an auth context loaded from a signed token of some kind,
         such as a JWT. Unlike ValidatedAuthContext, SignedAuthContext operates lazily, only loading
         the actual {user, robot, token, etc} when requested. This allows registry operations that
         only need to check if *some* entity is present to do so, without hitting the database.
     """
+
     def __init__(self, kind, signed_data, v1_dict_format):
         self.kind = kind
         self.signed_data = signed_data
@@ -296,19 +306,22 @@ class SignedAuthContext(AuthContext):
             return self._get_validated().unique_key
 
         signed_dict = self.signed_data
-        return '%s-%s' % (signed_dict['entity_kind'], signed_dict.get('entity_reference', '(anon)'))
+        return "%s-%s" % (
+            signed_dict["entity_kind"],
+            signed_dict.get("entity_reference", "(anon)"),
+        )
 
     @classmethod
     def build_from_signed_dict(cls, dict_data, v1_dict_format=False):
         if not v1_dict_format:
-            entity_kind = ContextEntityKind(dict_data.get('entity_kind', 'anonymous'))
+            entity_kind = ContextEntityKind(dict_data.get("entity_kind", "anonymous"))
             return SignedAuthContext(entity_kind, dict_data, v1_dict_format)
 
         # Legacy handling.
         # TODO: Remove this all once the new code is fully deployed.
-        kind_string = dict_data.get('kind', 'anonymous')
-        if kind_string == 'oauth':
-            kind_string = 'oauthtoken'
+        kind_string = dict_data.get("kind", "anonymous")
+        if kind_string == "oauth":
+            kind_string = "oauthtoken"
 
         kind = ContextEntityKind(kind_string)
         return SignedAuthContext(kind, dict_data, v1_dict_format)
@@ -322,54 +335,65 @@ class SignedAuthContext(AuthContext):
         if self.kind == ContextEntityKind.anonymous:
             return ValidatedAuthContext()
 
-        serialized_entity_reference = self.signed_data['entity_reference']
+        serialized_entity_reference = self.signed_data["entity_reference"]
         handler = CONTEXT_ENTITY_HANDLERS[self.kind]()
-        entity_reference = handler.deserialize_entity_reference(serialized_entity_reference)
+        entity_reference = handler.deserialize_entity_reference(
+            serialized_entity_reference
+        )
         if entity_reference is None:
-            logger.debug('Could not deserialize entity reference `%s` under kind `%s`',
-                         serialized_entity_reference, self.kind)
+            logger.debug(
+                "Could not deserialize entity reference `%s` under kind `%s`",
+                serialized_entity_reference,
+                self.kind,
+            )
             return ValidatedAuthContext()
 
         return ValidatedAuthContext(**{self.kind.value: entity_reference})
 
         # Legacy handling.
         # TODO: Remove this all once the new code is fully deployed.
-        kind_string = self.signed_data.get('kind', 'anonymous')
-        if kind_string == 'oauth':
-            kind_string = 'oauthtoken'
+        kind_string = self.signed_data.get("kind", "anonymous")
+        if kind_string == "oauth":
+            kind_string = "oauthtoken"
 
         kind = ContextEntityKind(kind_string)
         if kind == ContextEntityKind.anonymous:
             return ValidatedAuthContext()
 
         if kind == ContextEntityKind.user or kind == ContextEntityKind.robot:
-            user = model.user.get_user(self.signed_data.get('user', ''))
+            user = model.user.get_user(self.signed_data.get("user", ""))
             if not user:
                 return None
 
-            return ValidatedAuthContext(robot=user) if user.robot else ValidatedAuthContext(user=user)
+            return (
+                ValidatedAuthContext(robot=user)
+                if user.robot
+                else ValidatedAuthContext(user=user)
+            )
 
         if kind == ContextEntityKind.token:
-            token = model.token.load_token_data(self.signed_data.get('token'))
+            token = model.token.load_token_data(self.signed_data.get("token"))
             if not token:
                 return None
 
             return ValidatedAuthContext(token=token)
 
         if kind == ContextEntityKind.oauthtoken:
-            user = model.user.get_user(self.signed_data.get('user', ''))
+            user = model.user.get_user(self.signed_data.get("user", ""))
             if not user:
                 return None
 
-            token_uuid = self.signed_data.get('oauth', '')
+            token_uuid = self.signed_data.get("oauth", "")
             oauthtoken = model.oauth.lookup_access_token_for_user(user, token_uuid)
             if not oauthtoken:
                 return None
 
             return ValidatedAuthContext(oauthtoken=oauthtoken)
 
-        raise Exception('Unknown auth context kind `%s` when deserializing %s' % (kind,
-                                                                                  self.signed_data))
+        raise Exception(
+            "Unknown auth context kind `%s` when deserializing %s"
+            % (kind, self.signed_data)
+        )
         # End of legacy handling.
 
     @property


@@ -8,12 +8,13 @@ from auth.validateresult import ValidateResult, AuthKind
 logger = logging.getLogger(__name__)
 
+
 def has_basic_auth(username):
     """ Returns true if a basic auth header exists with a username and password pair that validates
         against the internal authentication system. Returns True on full success and False on any
         failure (missing header, invalid header, invalid credentials, etc).
     """
-    auth_header = request.headers.get('authorization', '')
+    auth_header = request.headers.get("authorization", "")
     result = validate_basic_auth(auth_header)
     return result.has_nonrobot_user and result.context.user.username == username
@@ -25,13 +26,13 @@ def validate_basic_auth(auth_header):
     if not auth_header:
         return ValidateResult(AuthKind.basic, missing=True)
 
-    logger.debug('Attempt to process basic auth header')
+    logger.debug("Attempt to process basic auth header")
 
     # Parse the basic auth header.
     assert isinstance(auth_header, basestring)
     credentials, err = _parse_basic_auth_header(auth_header)
     if err is not None:
-        logger.debug('Got invalid basic auth header: %s', auth_header)
+        logger.debug("Got invalid basic auth header: %s", auth_header)
         return ValidateResult(AuthKind.basic, missing=True)
 
     auth_username, auth_password_or_token = credentials
@@ -42,17 +43,19 @@ def validate_basic_auth(auth_header):
 def _parse_basic_auth_header(auth):
     """ Parses the given basic auth header, returning the credentials found inside.
     """
-    normalized = [part.strip() for part in auth.split(' ') if part]
-    if normalized[0].lower() != 'basic' or len(normalized) != 2:
-        return None, 'Invalid basic auth header'
+    normalized = [part.strip() for part in auth.split(" ") if part]
+    if normalized[0].lower() != "basic" or len(normalized) != 2:
+        return None, "Invalid basic auth header"
 
     try:
-        credentials = [part.decode('utf-8') for part in b64decode(normalized[1]).split(':', 1)]
+        credentials = [
+            part.decode("utf-8") for part in b64decode(normalized[1]).split(":", 1)
+        ]
     except (TypeError, UnicodeDecodeError, ValueError):
-        logger.exception('Exception when parsing basic auth header: %s', auth)
-        return None, 'Could not parse basic auth header'
+        logger.exception("Exception when parsing basic auth header: %s", auth)
+        return None, "Could not parse basic auth header"
 
     if len(credentials) != 2:
-        return None, 'Unexpected number of credentials found in basic auth header'
+        return None, "Unexpected number of credentials found in basic auth header"
 
     return credentials, None


@@ -4,21 +4,26 @@ from enum import Enum
 from data import model
 
-from auth.credential_consts import (ACCESS_TOKEN_USERNAME, OAUTH_TOKEN_USERNAME,
-                                    APP_SPECIFIC_TOKEN_USERNAME)
+from auth.credential_consts import (
+    ACCESS_TOKEN_USERNAME,
+    OAUTH_TOKEN_USERNAME,
+    APP_SPECIFIC_TOKEN_USERNAME,
+)
 
 class ContextEntityKind(Enum):
     """ Defines the various kinds of entities in an auth context. Note that the string values of
         these fields *must* match the names of the fields in the ValidatedAuthContext class, as
        we fill them in directly based on the string names here.
    """
-    anonymous = 'anonymous'
-    user = 'user'
-    robot = 'robot'
-    token = 'token'
-    oauthtoken = 'oauthtoken'
-    appspecifictoken = 'appspecifictoken'
-    signed_data = 'signed_data'
+
+    anonymous = "anonymous"
+    user = "user"
+    robot = "robot"
+    token = "token"
+    oauthtoken = "oauthtoken"
+    appspecifictoken = "appspecifictoken"
+    signed_data = "signed_data"
 
 @add_metaclass(ABCMeta)
@@ -87,9 +92,7 @@ class UserEntityHandler(ContextEntityHandler):
         return "user %s" % entity_reference.username
 
     def analytics_id_and_public_metadata(self, entity_reference):
-        return entity_reference.username, {
-            'username': entity_reference.username,
-        }
+        return entity_reference.username, {"username": entity_reference.username}
 
 
 class RobotEntityHandler(ContextEntityHandler):
@@ -106,10 +109,10 @@ class RobotEntityHandler(ContextEntityHandler):
         return "robot %s" % entity_reference.username
 
     def analytics_id_and_public_metadata(self, entity_reference):
-        return entity_reference.username, {
-            'username': entity_reference.username,
-            'is_robot': True,
-        }
+        return (
+            entity_reference.username,
+            {"username": entity_reference.username, "is_robot": True},
+        )
 
 
 class TokenEntityHandler(ContextEntityHandler):
@@ -126,9 +129,10 @@ class TokenEntityHandler(ContextEntityHandler):
         return "token %s" % entity_reference.friendly_name
 
     def analytics_id_and_public_metadata(self, entity_reference):
-        return 'token:%s' % entity_reference.id, {
-            'token': entity_reference.friendly_name,
-        }
+        return (
+            "token:%s" % entity_reference.id,
+            {"token": entity_reference.friendly_name},
+        )
 
 
 class OAuthTokenEntityHandler(ContextEntityHandler):
@@ -145,12 +149,15 @@ class OAuthTokenEntityHandler(ContextEntityHandler):
         return "oauthtoken for user %s" % entity_reference.authorized_user.username
 
     def analytics_id_and_public_metadata(self, entity_reference):
-        return 'oauthtoken:%s' % entity_reference.id, {
-            'oauth_token_id': entity_reference.id,
-            'oauth_token_application_id': entity_reference.application.client_id,
-            'oauth_token_application': entity_reference.application.name,
-            'username': entity_reference.authorized_user.username,
-        }
+        return (
+            "oauthtoken:%s" % entity_reference.id,
+            {
+                "oauth_token_id": entity_reference.id,
+                "oauth_token_application_id": entity_reference.application.client_id,
+                "oauth_token_application": entity_reference.application.name,
+                "username": entity_reference.authorized_user.username,
+            },
+        )
 
 
 class AppSpecificTokenEntityHandler(ContextEntityHandler):
@@ -168,11 +175,14 @@ class AppSpecificTokenEntityHandler(ContextEntityHandler):
         return "app specific token %s for user %s" % tpl
 
     def analytics_id_and_public_metadata(self, entity_reference):
-        return 'appspecifictoken:%s' % entity_reference.id, {
-            'app_specific_token': entity_reference.uuid,
-            'app_specific_token_title': entity_reference.title,
-            'username': entity_reference.user.username,
-        }
+        return (
+            "appspecifictoken:%s" % entity_reference.id,
+            {
+                "app_specific_token": entity_reference.uuid,
+                "app_specific_token_title": entity_reference.title,
+                "username": entity_reference.user.username,
+            },
+        )
 
 
 class SignedDataEntityHandler(ContextEntityHandler):
@@ -189,7 +199,7 @@ class SignedDataEntityHandler(ContextEntityHandler):
         return "signed"
 
     def analytics_id_and_public_metadata(self, entity_reference):
-        return 'signed', {'signed': entity_reference}
+        return "signed", {"signed": entity_reference}
 
 
 CONTEXT_ENTITY_HANDLERS = {


@@ -7,6 +7,7 @@ from auth.validateresult import AuthKind, ValidateResult
 logger = logging.getLogger(__name__)
 
+
 def validate_session_cookie(auth_header_unusued=None):
     """ Attempts to load a user from a session cookie. """
     if current_user.is_anonymous:
@@ -16,22 +17,30 @@ def validate_session_cookie(auth_header_unusued=None):
         # Attempt to parse the user uuid to make sure the cookie has the right value type
         UUID(current_user.get_id())
     except ValueError:
-        logger.debug('Got non-UUID for session cookie user: %s', current_user.get_id())
-        return ValidateResult(AuthKind.cookie, error_message='Invalid session cookie format')
+        logger.debug("Got non-UUID for session cookie user: %s", current_user.get_id())
+        return ValidateResult(
+            AuthKind.cookie, error_message="Invalid session cookie format"
+        )
 
-    logger.debug('Loading user from cookie: %s', current_user.get_id())
+    logger.debug("Loading user from cookie: %s", current_user.get_id())
     db_user = current_user.db_user()
     if db_user is None:
-        return ValidateResult(AuthKind.cookie, error_message='Could not find matching user')
+        return ValidateResult(
+            AuthKind.cookie, error_message="Could not find matching user"
+        )
 
     # Don't allow disabled users to login.
     if not db_user.enabled:
-        logger.debug('User %s in session cookie is disabled', db_user.username)
-        return ValidateResult(AuthKind.cookie, error_message='User account is disabled')
+        logger.debug("User %s in session cookie is disabled", db_user.username)
+        return ValidateResult(AuthKind.cookie, error_message="User account is disabled")
 
     # Don't allow organizations to "login".
     if db_user.organization:
-        logger.debug('User %s in session cookie is in-fact organization', db_user.username)
-        return ValidateResult(AuthKind.cookie, error_message='Cannot login to organization')
+        logger.debug(
+            "User %s in session cookie is in-fact organization", db_user.username
+        )
+        return ValidateResult(
+            AuthKind.cookie, error_message="Cannot login to organization"
+        )
 
     return ValidateResult(AuthKind.cookie, user=db_user)


@@ -1,3 +1,3 @@
-ACCESS_TOKEN_USERNAME = '$token'
-OAUTH_TOKEN_USERNAME = '$oauthtoken'
-APP_SPECIFIC_TOKEN_USERNAME = '$app'
+ACCESS_TOKEN_USERNAME = "$token"
+OAUTH_TOKEN_USERNAME = "$oauthtoken"
+APP_SPECIFIC_TOKEN_USERNAME = "$app"


@@ -7,8 +7,11 @@ import features
 from app import authentication
 from auth.oauth import validate_oauth_token
 from auth.validateresult import ValidateResult, AuthKind
-from auth.credential_consts import (ACCESS_TOKEN_USERNAME, OAUTH_TOKEN_USERNAME,
-                                    APP_SPECIFIC_TOKEN_USERNAME)
+from auth.credential_consts import (
+    ACCESS_TOKEN_USERNAME,
+    OAUTH_TOKEN_USERNAME,
+    APP_SPECIFIC_TOKEN_USERNAME,
+)
 from data import model
 from util.names import parse_robot_username
@@ -16,8 +19,8 @@ logger = logging.getLogger(__name__)
 
 class CredentialKind(Enum):
-    user = 'user'
-    robot = 'robot'
+    user = "user"
+    robot = "robot"
     token = ACCESS_TOKEN_USERNAME
     oauth_token = OAUTH_TOKEN_USERNAME
     app_specific_token = APP_SPECIFIC_TOKEN_USERNAME
@@ -27,36 +30,61 @@ def validate_credentials(auth_username, auth_password_or_token):
     """ Validates a pair of auth username and password/token credentials. """
     # Check for access tokens.
     if auth_username == ACCESS_TOKEN_USERNAME:
-        logger.debug('Found credentials for access token')
+        logger.debug("Found credentials for access token")
         try:
             token = model.token.load_token_data(auth_password_or_token)
-            logger.debug('Successfully validated credentials for access token %s', token.id)
-            return ValidateResult(AuthKind.credentials, token=token), CredentialKind.token
+            logger.debug(
+                "Successfully validated credentials for access token %s", token.id
+            )
+            return (
+                ValidateResult(AuthKind.credentials, token=token),
+                CredentialKind.token,
+            )
         except model.DataModelException:
-            logger.warning('Failed to validate credentials for access token %s', auth_password_or_token)
-            return (ValidateResult(AuthKind.credentials, error_message='Invalid access token'),
-                    CredentialKind.token)
+            logger.warning(
+                "Failed to validate credentials for access token %s",
+                auth_password_or_token,
+            )
+            return (
+                ValidateResult(
+                    AuthKind.credentials, error_message="Invalid access token"
+                ),
+                CredentialKind.token,
+            )
 
     # Check for App Specific tokens.
     if features.APP_SPECIFIC_TOKENS and auth_username == APP_SPECIFIC_TOKEN_USERNAME:
-        logger.debug('Found credentials for app specific auth token')
+        logger.debug("Found credentials for app specific auth token")
         token = model.appspecifictoken.access_valid_token(auth_password_or_token)
         if token is None:
-            logger.debug('Failed to validate credentials for app specific token: %s',
-                         auth_password_or_token)
-            return (ValidateResult(AuthKind.credentials, error_message='Invalid token'),
-                    CredentialKind.app_specific_token)
+            logger.debug(
+                "Failed to validate credentials for app specific token: %s",
+                auth_password_or_token,
+            )
+            return (
+                ValidateResult(AuthKind.credentials, error_message="Invalid token"),
+                CredentialKind.app_specific_token,
+            )
 
         if not token.user.enabled:
-            logger.debug('Tried to use an app specific token for a disabled user: %s',
-                         token.uuid)
-            return (ValidateResult(AuthKind.credentials,
-                                   error_message='This user has been disabled. Please contact your administrator.'),
-                    CredentialKind.app_specific_token)
+            logger.debug(
+                "Tried to use an app specific token for a disabled user: %s", token.uuid
+            )
+            return (
+                ValidateResult(
+                    AuthKind.credentials,
+                    error_message="This user has been disabled. Please contact your administrator.",
+                ),
+                CredentialKind.app_specific_token,
+            )
 
-        logger.debug('Successfully validated credentials for app specific token %s', token.id)
-        return (ValidateResult(AuthKind.credentials, appspecifictoken=token),
-                CredentialKind.app_specific_token)
+        logger.debug(
+            "Successfully validated credentials for app specific token %s", token.id
+        )
+        return (
+            ValidateResult(AuthKind.credentials, appspecifictoken=token),
+            CredentialKind.app_specific_token,
+        )
 
     # Check for OAuth tokens.
     if auth_username == OAUTH_TOKEN_USERNAME:
@@ -65,21 +93,42 @@ def validate_credentials(auth_username, auth_password_or_token):
     # Check for robots and users.
     is_robot = parse_robot_username(auth_username)
     if is_robot:
-        logger.debug('Found credentials header for robot %s', auth_username)
+        logger.debug("Found credentials header for robot %s", auth_username)
         try:
             robot = model.user.verify_robot(auth_username, auth_password_or_token)
-            logger.debug('Successfully validated credentials for robot %s', auth_username)
-            return ValidateResult(AuthKind.credentials, robot=robot), CredentialKind.robot
+            logger.debug(
+                "Successfully validated credentials for robot %s", auth_username
+            )
+            return (
+                ValidateResult(AuthKind.credentials, robot=robot),
+                CredentialKind.robot,
+            )
         except model.InvalidRobotException as ire:
-            logger.warning('Failed to validate credentials for robot %s: %s', auth_username, ire)
-            return ValidateResult(AuthKind.credentials, error_message=str(ire)), CredentialKind.robot
+            logger.warning(
+                "Failed to validate credentials for robot %s: %s", auth_username, ire
+            )
+            return (
+                ValidateResult(AuthKind.credentials, error_message=str(ire)),
+                CredentialKind.robot,
+            )
 
     # Otherwise, treat as a standard user.
-    (authenticated, err) = authentication.verify_and_link_user(auth_username, auth_password_or_token,
-                                                               basic_auth=True)
+    (authenticated, err) = authentication.verify_and_link_user(
+        auth_username, auth_password_or_token, basic_auth=True
+    )
     if authenticated:
-        logger.debug('Successfully validated credentials for user %s', authenticated.username)
-        return ValidateResult(AuthKind.credentials, user=authenticated), CredentialKind.user
+        logger.debug(
+            "Successfully validated credentials for user %s", authenticated.username
+        )
+        return (
+            ValidateResult(AuthKind.credentials, user=authenticated),
+            CredentialKind.user,
+        )
     else:
-        logger.warning('Failed to validate credentials for user %s: %s', auth_username, err)
-        return ValidateResult(AuthKind.credentials, error_message=err), CredentialKind.user
+        logger.warning(
+            "Failed to validate credentials for user %s: %s", auth_username, err
+        )
+        return (
+            ValidateResult(AuthKind.credentials, error_message=err),
+            CredentialKind.user,
+        )


@@ -14,15 +14,17 @@ from util.http import abort
 logger = logging.getLogger(__name__)
 
+
 def _auth_decorator(pass_result=False, handlers=None):
     """ Builds an auth decorator that runs the given handlers and, if any return successfully,
         sets up the auth context. The wrapped function will be invoked *regardless of success or
         failure of the auth handler(s)*
     """
+
     def processor(func):
         @wraps(func)
         def wrapper(*args, **kwargs):
-            auth_header = request.headers.get('authorization', '')
+            auth_header = request.headers.get("authorization", "")
             result = None
 
             for handler in handlers:
@@ -33,32 +35,42 @@ def _auth_decorator(pass_result=False, handlers=None):
                 # Check for a valid result.
                 if result.auth_valid:
-                    logger.debug('Found valid auth result: %s', result.tuple())
+                    logger.debug("Found valid auth result: %s", result.tuple())
 
                     # Set the various pieces of the auth context.
                     result.apply_to_context()
 
                     # Log the metric.
-                    metric_queue.authentication_count.Inc(labelvalues=[result.kind, True])
+                    metric_queue.authentication_count.Inc(
+                        labelvalues=[result.kind, True]
+                    )
                     break
 
                 # Otherwise, report the error.
                 if result.error_message is not None:
                     # Log the failure.
-                    metric_queue.authentication_count.Inc(labelvalues=[result.kind, False])
+                    metric_queue.authentication_count.Inc(
+                        labelvalues=[result.kind, False]
+                    )
                     break
 
             if pass_result:
-                kwargs['auth_result'] = result
+                kwargs["auth_result"] = result
 
             return func(*args, **kwargs)
 
         return wrapper
 
     return processor
 
-process_oauth = _auth_decorator(handlers=[validate_bearer_auth, validate_session_cookie])
+
+process_oauth = _auth_decorator(
+    handlers=[validate_bearer_auth, validate_session_cookie]
+)
 process_auth = _auth_decorator(handlers=[validate_signed_grant, validate_basic_auth])
-process_auth_or_cookie = _auth_decorator(handlers=[validate_basic_auth, validate_session_cookie])
+process_auth_or_cookie = _auth_decorator(
+    handlers=[validate_basic_auth, validate_session_cookie]
+)
 process_basic_auth = _auth_decorator(handlers=[validate_basic_auth], pass_result=True)
 process_basic_auth_no_pass = _auth_decorator(handlers=[validate_basic_auth])
@@ -67,6 +79,7 @@ def require_session_login(func):
     """ Decorates a function and ensures that a valid session cookie exists or a 401 is raised. If
         a valid session cookie does exist, the authenticated user and identity are also set.
     """
+
     @wraps(func)
     def wrapper(*args, **kwargs):
         result = validate_session_cookie()
@@ -77,7 +90,8 @@ def require_session_login(func):
         elif not result.missing:
             metric_queue.authentication_count.Inc(labelvalues=[result.kind, False])
 
-        abort(401, message='Method requires login and no valid login could be loaded.')
+        abort(401, message="Method requires login and no valid login could be loaded.")
+
     return wrapper
@@ -86,11 +100,15 @@ def extract_namespace_repo_from_session(func):
         and passes them into the decorated function as the first and second arguments. If the
         session doesn't exist or does not contain these arugments, a 400 error is raised.
     """
+
     @wraps(func)
     def wrapper(*args, **kwargs):
-        if 'namespace' not in session or 'repository' not in session:
-            logger.error('Unable to load namespace or repository from session: %s', session)
-            abort(400, message='Missing namespace in request')
+        if "namespace" not in session or "repository" not in session:
+            logger.error(
+                "Unable to load namespace or repository from session: %s", session
+            )
+            abort(400, message="Missing namespace in request")
 
-        return func(session['namespace'], session['repository'], *args, **kwargs)
+        return func(session["namespace"], session["repository"], *args, **kwargs)
 
     return wrapper
@@ -8,6 +8,7 @@ from data import model

logger = logging.getLogger(__name__)


def validate_bearer_auth(auth_header):
    """ Validates an OAuth token found inside a basic auth `Bearer` token, returning whether it
        points to a valid OAuth token.
@@ -15,9 +16,9 @@ def validate_bearer_auth(auth_header):
    if not auth_header:
        return ValidateResult(AuthKind.oauth, missing=True)

    normalized = [part.strip() for part in auth_header.split(" ") if part]
    if normalized[0].lower() != "bearer" or len(normalized) != 2:
        logger.debug("Got invalid bearer token format: %s", auth_header)
        return ValidateResult(AuthKind.oauth, missing=True)

    (_, oauth_token) = normalized
@@ -29,20 +30,25 @@ def validate_oauth_token(token):
    """
    validated = model.oauth.validate_access_token(token)
    if not validated:
        logger.warning("OAuth access token could not be validated: %s", token)
        return ValidateResult(
            AuthKind.oauth, error_message="OAuth access token could not be validated"
        )

    if validated.expires_at <= datetime.utcnow():
        logger.warning("OAuth access with an expired token: %s", token)
        return ValidateResult(
            AuthKind.oauth, error_message="OAuth access token has expired"
        )

    # Don't allow disabled users to login.
    if not validated.authorized_user.enabled:
        return ValidateResult(
            AuthKind.oauth,
            error_message="Granter of the oauth access token is disabled",
        )

    # We have a valid token
    scope_set = scopes_from_scope_string(validated.scope)
    logger.debug("Successfully validated oauth access token with scope: %s", scope_set)

    return ValidateResult(AuthKind.oauth, oauthtoken=validated)
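To make the control flow above concrete, a few example inputs and the results they produce (the header values are made up):

validate_bearer_auth(None)                     # -> ValidateResult(AuthKind.oauth, missing=True)
validate_bearer_auth("Basic abc123")           # scheme is not "bearer" -> missing=True
validate_bearer_auth("Bearer one two")         # three parts, not two -> missing=True
validate_bearer_auth("Bearer sometokenvalue")  # well-formed -> handed to validate_oauth_token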
@@ -14,60 +14,59 @@ from data import model

logger = logging.getLogger(__name__)


_ResourceNeed = namedtuple("resource", ["type", "namespace", "name", "role"])
_RepositoryNeed = partial(_ResourceNeed, "repository")
_NamespaceWideNeed = namedtuple("namespacewide", ["type", "namespace", "role"])
_OrganizationNeed = partial(_NamespaceWideNeed, "organization")
_OrganizationRepoNeed = partial(_NamespaceWideNeed, "organizationrepo")
_TeamTypeNeed = namedtuple("teamwideneed", ["type", "orgname", "teamname", "role"])
_TeamNeed = partial(_TeamTypeNeed, "orgteam")
_UserTypeNeed = namedtuple("userspecificneed", ["type", "username", "role"])
_UserNeed = partial(_UserTypeNeed, "user")
_SuperUserNeed = partial(namedtuple("superuserneed", ["type"]), "superuser")


REPO_ROLES = [None, "read", "write", "admin"]
TEAM_ROLES = [None, "member", "creator", "admin"]
USER_ROLES = [None, "read", "admin"]

TEAM_ORGWIDE_REPO_ROLES = {"admin": "admin", "creator": None, "member": None}

SCOPE_MAX_REPO_ROLES = defaultdict(lambda: None)
SCOPE_MAX_REPO_ROLES.update(
    {
        scopes.READ_REPO: "read",
        scopes.WRITE_REPO: "write",
        scopes.ADMIN_REPO: "admin",
        scopes.DIRECT_LOGIN: "admin",
    }
)

SCOPE_MAX_TEAM_ROLES = defaultdict(lambda: None)
SCOPE_MAX_TEAM_ROLES.update(
    {
        scopes.CREATE_REPO: "creator",
        scopes.DIRECT_LOGIN: "admin",
        scopes.ORG_ADMIN: "admin",
    }
)

SCOPE_MAX_USER_ROLES = defaultdict(lambda: None)
SCOPE_MAX_USER_ROLES.update(
    {scopes.READ_USER: "read", scopes.DIRECT_LOGIN: "admin", scopes.ADMIN_USER: "admin"}
)


def repository_read_grant(namespace, repository):
    return _RepositoryNeed(namespace, repository, "read")


def repository_write_grant(namespace, repository):
    return _RepositoryNeed(namespace, repository, "write")


def repository_admin_grant(namespace, repository):
    return _RepositoryNeed(namespace, repository, "admin")
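A standalone sketch (reusing only the two definitions above) of why these need tuples work as permission grants: partials over namedtuples compare by value, so equal grants are interchangeable members of a provides set.

from collections import namedtuple
from functools import partial

_ResourceNeed = namedtuple("resource", ["type", "namespace", "name", "role"])
_RepositoryNeed = partial(_ResourceNeed, "repository")

grant_a = _RepositoryNeed("devtable", "simple", "read")
grant_b = _RepositoryNeed("devtable", "simple", "read")

assert grant_a == grant_b  # value equality, safe to use in provides sets
assert grant_a == ("repository", "devtable", "simple", "read")  # namedtuples equal plain tuples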

class QuayDeferredPermissionUser(Identity):
@@ -84,21 +83,27 @@ class QuayDeferredPermissionUser(Identity):
    @staticmethod
    def for_id(uuid, auth_scopes=None):
        auth_scopes = auth_scopes if auth_scopes is not None else {scopes.DIRECT_LOGIN}
        return QuayDeferredPermissionUser(uuid, "user_uuid", auth_scopes)

    @staticmethod
    def for_user(user, auth_scopes=None):
        auth_scopes = auth_scopes if auth_scopes is not None else {scopes.DIRECT_LOGIN}
        return QuayDeferredPermissionUser(
            user.uuid, "user_uuid", auth_scopes, user=user
        )

    def _translate_role_for_scopes(self, cardinality, max_roles, role):
        if self._scope_set is None:
            return role

        max_for_scopes = max(
            {cardinality.index(max_roles[scope]) for scope in self._scope_set}
        )

        if max_for_scopes < cardinality.index(role):
            logger.debug(
                "Translated permission %s -> %s", role, cardinality[max_for_scopes]
            )
            return cardinality[max_for_scopes]
        else:
            return role
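A worked, self-contained example of the downgrade logic in _translate_role_for_scopes, using plain dicts in place of Quay's scope objects (the scope names below are stand-ins):

REPO_ROLES = [None, "read", "write", "admin"]

def translate(cardinality, max_roles, role, scope_set):
    # Highest role index that any held scope allows.
    max_for_scopes = max({cardinality.index(max_roles[scope]) for scope in scope_set})
    # If the requested role outranks what the scopes permit, downgrade it.
    if max_for_scopes < cardinality.index(role):
        return cardinality[max_for_scopes]
    return role

max_roles = {"repo:read": "read"}  # hypothetical token holding only a read scope

assert translate(REPO_ROLES, max_roles, "admin", {"repo:read"}) == "read"  # downgraded
assert translate(REPO_ROLES, max_roles, "read", {"repo:read"}) == "read"   # unchanged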
@@ -118,23 +123,31 @@ class QuayDeferredPermissionUser(Identity):
        """
        # Add the user specific permissions, only for non-oauth permission
        user_grant = _UserNeed(
            user_object.username, self._user_role_for_scopes("admin")
        )
        logger.debug("User permission: {0}".format(user_grant))
        self.provides.add(user_grant)

        # Every user is the admin of their own 'org'
        user_namespace = _OrganizationNeed(
            user_object.username, self._team_role_for_scopes("admin")
        )
        logger.debug("User namespace permission: {0}".format(user_namespace))
        self.provides.add(user_namespace)

        # Org repo roles can differ for scopes
        user_repos = _OrganizationRepoNeed(
            user_object.username, self._repo_role_for_scopes("admin")
        )
        logger.debug("User namespace repo permission: {0}".format(user_repos))
        self.provides.add(user_repos)

        if (
            scopes.SUPERUSER in self._scope_set
            or scopes.DIRECT_LOGIN in self._scope_set
        ) and superusers.is_superuser(user_object.username):
            logger.debug("Adding superuser to user: %s", user_object.username)
            self.provides.add(_SuperUserNeed())
    def _populate_namespace_wide_provides(self, user_object, namespace_filter):
@@ -142,40 +155,59 @@ class QuayDeferredPermissionUser(Identity):
            This method does *not* add any provides for specific repositories.
        """

        for team in model.permission.get_org_wide_permissions(
            user_object, org_filter=namespace_filter
        ):
            team_org_grant = _OrganizationNeed(
                team.organization.username, self._team_role_for_scopes(team.role.name)
            )
            logger.debug(
                "Organization team added permission: {0}".format(team_org_grant)
            )
            self.provides.add(team_org_grant)

            team_repo_role = TEAM_ORGWIDE_REPO_ROLES[team.role.name]
            org_repo_grant = _OrganizationRepoNeed(
                team.organization.username, self._repo_role_for_scopes(team_repo_role)
            )
            logger.debug(
                "Organization team added repo permission: {0}".format(org_repo_grant)
            )
            self.provides.add(org_repo_grant)

            team_grant = _TeamNeed(
                team.organization.username,
                team.name,
                self._team_role_for_scopes(team.role.name),
            )
            logger.debug("Team added permission: {0}".format(team_grant))
            self.provides.add(team_grant)

    def _populate_repository_provides(
        self, user_object, namespace_filter, repository_name
    ):
        """ Populates the repository-specific provides for a particular user and repository. """

        if namespace_filter and repository_name:
            permissions = model.permission.get_user_repository_permissions(
                user_object, namespace_filter, repository_name
            )
        else:
            permissions = model.permission.get_all_user_repository_permissions(
                user_object
            )

        for perm in permissions:
            repo_grant = _RepositoryNeed(
                perm.repository.namespace_user.username,
                perm.repository.name,
                self._repo_role_for_scopes(perm.role.name),
            )
            logger.debug("User added permission: {0}".format(repo_grant))
            self.provides.add(repo_grant)

    def can(self, permission):
        logger.debug("Loading user permissions after deferring for: %s", self.id)
        user_object = self._user_object or model.user.get_user_by_uuid(self.id)
        if user_object is None:
            return super(QuayDeferredPermissionUser, self).can(permission)
@@ -195,7 +227,7 @@ class QuayDeferredPermissionUser(Identity):
        perm_repository = None

        if perm_namespace and perm_repo_name:
            perm_repository = "%s/%s" % (perm_namespace, perm_repo_name)

        if not perm_namespace and not perm_repo_name:
            # Nothing more to load, so just check directly.
@@ -203,7 +235,9 @@ class QuayDeferredPermissionUser(Identity):
        # Lazy-load the repository-specific permissions.
        if perm_repository and perm_repository not in self._repositories_loaded:
            self._populate_repository_provides(
                user_object, perm_namespace, perm_repo_name
            )
            self._repositories_loaded.add(perm_repository)

        # If we now have permission, no need to load any more permissions.
@@ -220,61 +254,68 @@ class QuayDeferredPermissionUser(Identity):

class QuayPermission(Permission):
    """ Base for all permissions in Quay. """

    namespace = None
    repo_name = None


class ModifyRepositoryPermission(QuayPermission):
    def __init__(self, namespace, name):
        admin_need = _RepositoryNeed(namespace, name, "admin")
        write_need = _RepositoryNeed(namespace, name, "write")
        org_admin_need = _OrganizationRepoNeed(namespace, "admin")
        org_write_need = _OrganizationRepoNeed(namespace, "write")

        self.namespace = namespace
        self.repo_name = name

        super(ModifyRepositoryPermission, self).__init__(
            admin_need, write_need, org_admin_need, org_write_need
        )


class ReadRepositoryPermission(QuayPermission):
    def __init__(self, namespace, name):
        admin_need = _RepositoryNeed(namespace, name, "admin")
        write_need = _RepositoryNeed(namespace, name, "write")
        read_need = _RepositoryNeed(namespace, name, "read")
        org_admin_need = _OrganizationRepoNeed(namespace, "admin")
        org_write_need = _OrganizationRepoNeed(namespace, "write")
        org_read_need = _OrganizationRepoNeed(namespace, "read")

        self.namespace = namespace
        self.repo_name = name

        super(ReadRepositoryPermission, self).__init__(
            admin_need,
            write_need,
            read_need,
            org_admin_need,
            org_read_need,
            org_write_need,
        )


class AdministerRepositoryPermission(QuayPermission):
    def __init__(self, namespace, name):
        admin_need = _RepositoryNeed(namespace, name, "admin")
        org_admin_need = _OrganizationRepoNeed(namespace, "admin")

        self.namespace = namespace
        self.repo_name = name

        super(AdministerRepositoryPermission, self).__init__(admin_need, org_admin_need)


class CreateRepositoryPermission(QuayPermission):
    def __init__(self, namespace):
        admin_org = _OrganizationNeed(namespace, "admin")
        create_repo_org = _OrganizationNeed(namespace, "creator")

        self.namespace = namespace

        super(CreateRepositoryPermission, self).__init__(admin_org, create_repo_org)

class SuperUserPermission(QuayPermission):
    def __init__(self):
@@ -284,20 +325,20 @@ class SuperUserPermission(QuayPermission):

class UserAdminPermission(QuayPermission):
    def __init__(self, username):
        user_admin = _UserNeed(username, "admin")
        super(UserAdminPermission, self).__init__(user_admin)


class UserReadPermission(QuayPermission):
    def __init__(self, username):
        user_admin = _UserNeed(username, "admin")
        user_read = _UserNeed(username, "read")
        super(UserReadPermission, self).__init__(user_read, user_admin)


class AdministerOrganizationPermission(QuayPermission):
    def __init__(self, org_name):
        admin_org = _OrganizationNeed(org_name, "admin")

        self.namespace = org_name
@@ -306,27 +347,29 @@ class AdministerOrganizationPermission(QuayPermission):

class OrganizationMemberPermission(QuayPermission):
    def __init__(self, org_name):
        admin_org = _OrganizationNeed(org_name, "admin")
        repo_creator_org = _OrganizationNeed(org_name, "creator")
        org_member = _OrganizationNeed(org_name, "member")

        self.namespace = org_name

        super(OrganizationMemberPermission, self).__init__(
            admin_org, org_member, repo_creator_org
        )


class ViewTeamPermission(QuayPermission):
    def __init__(self, org_name, team_name):
        team_admin = _TeamNeed(org_name, team_name, "admin")
        team_creator = _TeamNeed(org_name, team_name, "creator")
        team_member = _TeamNeed(org_name, team_name, "member")
        admin_org = _OrganizationNeed(org_name, "admin")

        self.namespace = org_name

        super(ViewTeamPermission, self).__init__(
            team_admin, team_creator, team_member, admin_org
        )

class AlwaysFailPermission(QuayPermission):
@@ -336,29 +379,34 @@ class AlwaysFailPermission(QuayPermission):

@identity_loaded.connect_via(app)
def on_identity_loaded(sender, identity):
    logger.debug("Identity loaded: %s" % identity)

    # We have verified an identity, load in all of the permissions
    if isinstance(identity, QuayDeferredPermissionUser):
        logger.debug("Deferring permissions for user with uuid: %s", identity.id)

    elif identity.auth_type == "user_uuid":
        logger.debug(
            "Switching username permission to deferred object with uuid: %s",
            identity.id,
        )
        switch_to_deferred = QuayDeferredPermissionUser.for_id(identity.id)
        identity_changed.send(app, identity=switch_to_deferred)

    elif identity.auth_type == "token":
        logger.debug("Loading permissions for token: %s", identity.id)
        token_data = model.token.load_token_data(identity.id)

        repo_grant = _RepositoryNeed(
            token_data.repository.namespace_user.username,
            token_data.repository.name,
            token_data.role.name,
        )
        logger.debug("Delegate token added permission: %s", repo_grant)
        identity.provides.add(repo_grant)

    elif identity.auth_type == "signed_grant" or identity.auth_type == "signed_jwt":
        logger.debug("Loaded %s identity for: %s", identity.auth_type, identity.id)

    else:
        logger.error("Unknown identity auth type: %s", identity.auth_type)
@@ -9,49 +9,43 @@ from flask_principal import identity_changed, Identity

from app import app, get_app_url, instance_keys, metric_queue
from auth.auth_context import set_authenticated_context
from auth.auth_context_type import SignedAuthContext
from auth.permissions import (
    repository_read_grant,
    repository_write_grant,
    repository_admin_grant,
)
from util.http import abort
from util.names import parse_namespace_repository
from util.security.registry_jwt import (
    ANONYMOUS_SUB,
    decode_bearer_header,
    InvalidBearerTokenException,
)

logger = logging.getLogger(__name__)


ACCESS_SCHEMA = {
    "type": "array",
    "description": "List of access granted to the subject",
    "items": {
        "type": "object",
        "required": ["type", "name", "actions"],
        "properties": {
            "type": {
                "type": "string",
                "description": "We only allow repository permissions",
                "enum": ["repository"],
            },
            "name": {
                "type": "string",
                "description": "The name of the repository for which we are receiving access",
            },
            "actions": {
                "type": "array",
                "description": "List of specific verbs which can be performed against repository",
                "items": {"type": "string", "enum": ["push", "pull", "*"]},
            },
        },
    },
}
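Illustrative only: a token "access" claim that satisfies ACCESS_SCHEMA above. jsonschema.validate raises ValidationError on mismatch, which is exactly how identity_from_bearer_token (below) rejects malformed grants.

from jsonschema import validate

access = [
    {
        "type": "repository",
        "name": "devtable/simple",
        "actions": ["push", "pull"],
    }
]

validate(access, ACCESS_SCHEMA)  # passes; an action outside push/pull/* would raise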
@@ -65,19 +59,19 @@ class InvalidJWTException(Exception):

def get_auth_headers(repository=None, scopes=None):
    """ Returns a dictionary of headers for auth responses. """
    headers = {}
    realm_auth_path = url_for("v2.generate_registry_jwt")
    authenticate = 'Bearer realm="{0}{1}",service="{2}"'.format(
        get_app_url(), realm_auth_path, app.config["SERVER_HOSTNAME"]
    )
    if repository:
        scopes_string = "repository:{0}".format(repository)
        if scopes:
            scopes_string += ":" + ",".join(scopes)

        authenticate += ',scope="{0}"'.format(scopes_string)

    headers["WWW-Authenticate"] = authenticate
    headers["Docker-Distribution-API-Version"] = "registry/2.0"
    return headers
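A standalone sketch of the challenge header this assembles, with placeholder hostname and realm (real values come from get_app_url() and SERVER_HOSTNAME):

realm = "https://quay.example/v2/auth"   # placeholder
service = "quay.example"                 # placeholder
repository = "devtable/simple"
scopes = ["pull", "push"]

authenticate = 'Bearer realm="{0}",service="{1}"'.format(realm, service)
authenticate += ',scope="repository:{0}:{1}"'.format(repository, ",".join(scopes))

# -> Bearer realm="https://quay.example/v2/auth",service="quay.example",
#    scope="repository:devtable/simple:pull,push"
print(authenticate)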
@@ -86,79 +80,95 @@ def identity_from_bearer_token(bearer_header):
        identity could not be loaded. Expects tokens and grants in the format of the Docker registry
        v2 auth spec: https://docs.docker.com/registry/spec/auth/token/
    """
    logger.debug("Validating auth header: %s", bearer_header)

    try:
        payload = decode_bearer_header(
            bearer_header, instance_keys, app.config, metric_queue=metric_queue
        )
    except InvalidBearerTokenException as bte:
        logger.exception("Invalid bearer token: %s", bte)
        raise InvalidJWTException(bte)

    loaded_identity = Identity(payload["sub"], "signed_jwt")

    # Process the grants from the payload
    if "access" in payload:
        try:
            validate(payload["access"], ACCESS_SCHEMA)
        except ValidationError:
            logger.exception("We should not be minting invalid credentials")
            raise InvalidJWTException(
                "Token contained invalid or malformed access grants"
            )

        lib_namespace = app.config["LIBRARY_NAMESPACE"]
        for grant in payload["access"]:
            namespace, repo_name = parse_namespace_repository(
                grant["name"], lib_namespace
            )

            if "*" in grant["actions"]:
                loaded_identity.provides.add(
                    repository_admin_grant(namespace, repo_name)
                )
            elif "push" in grant["actions"]:
                loaded_identity.provides.add(
                    repository_write_grant(namespace, repo_name)
                )
            elif "pull" in grant["actions"]:
                loaded_identity.provides.add(
                    repository_read_grant(namespace, repo_name)
                )

    default_context = {"kind": "anonymous"}

    if payload["sub"] != ANONYMOUS_SUB:
        default_context = {"kind": "user", "user": payload["sub"]}

    return loaded_identity, payload.get("context", default_context)

def process_registry_jwt_auth(scopes=None):
    """ Processes the registry JWT auth token found in the authorization header. If none found,
        no error is returned. If an invalid token is found, raises a 401.
    """

    def inner(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            logger.debug("Called with params: %s, %s", args, kwargs)
            auth = request.headers.get("authorization", "").strip()
            if auth:
                try:
                    extracted_identity, context_dict = identity_from_bearer_token(auth)
                    identity_changed.send(app, identity=extracted_identity)
                    logger.debug("Identity changed to %s", extracted_identity.id)

                    auth_context = SignedAuthContext.build_from_signed_dict(
                        context_dict
                    )
                    if auth_context is not None:
                        logger.debug("Auth context set to %s", auth_context.signed_data)
                        set_authenticated_context(auth_context)

                except InvalidJWTException as ije:
                    repository = None
                    if "namespace_name" in kwargs and "repo_name" in kwargs:
                        repository = (
                            kwargs["namespace_name"] + "/" + kwargs["repo_name"]
                        )

                    abort(
                        401,
                        message=ije.message,
                        headers=get_auth_headers(repository=repository, scopes=scopes),
                    )
            else:
                logger.debug("No auth header.")

            return func(*args, **kwargs)

        return wrapper

    return inner
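A hedged usage sketch of how a v2 endpoint would typically apply this decorator (the blueprint and route below are illustrative, not from this commit):

from flask import Blueprint

v2_example = Blueprint("v2_example", __name__)

@v2_example.route("/v2/<namespace_name>/<repo_name>/tags/list")
@process_registry_jwt_auth(scopes=["pull"])
def list_tags(namespace_name, repo_name):
    # A valid bearer token (if present) has already populated the identity and
    # auth context; an invalid one has already aborted with a 401 carrying the
    # WWW-Authenticate challenge from get_auth_headers.
    return "..."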
@@ -2,85 +2,132 @@ from collections import namedtuple
import features
import re

Scope = namedtuple("scope", ["scope", "icon", "dangerous", "title", "description"])

READ_REPO = Scope(
    scope="repo:read",
    icon="fa-hdd-o",
    dangerous=False,
    title="View all visible repositories",
    description=(
        "This application will be able to view and pull all repositories "
        "visible to the granting user or robot account"
    ),
)

WRITE_REPO = Scope(
    scope="repo:write",
    icon="fa-hdd-o",
    dangerous=False,
    title="Read/Write to any accessible repositories",
    description=(
        "This application will be able to view, push and pull to all "
        "repositories to which the granting user or robot account has "
        "write access"
    ),
)

ADMIN_REPO = Scope(
    scope="repo:admin",
    icon="fa-hdd-o",
    dangerous=False,
    title="Administer Repositories",
    description=(
        "This application will have administrator access to all "
        "repositories to which the granting user or robot account has "
        "access"
    ),
)

CREATE_REPO = Scope(
    scope="repo:create",
    icon="fa-plus",
    dangerous=False,
    title="Create Repositories",
    description=(
        "This application will be able to create repositories in to any "
        "namespaces that the granting user or robot account is allowed "
        "to create repositories"
    ),
)

READ_USER = Scope(
    scope="user:read",
    icon="fa-user",
    dangerous=False,
    title="Read User Information",
    description=(
        "This application will be able to read user information such as "
        "username and email address."
    ),
)

ADMIN_USER = Scope(
    scope="user:admin",
    icon="fa-gear",
    dangerous=True,
    title="Administer User",
    description=(
        "This application will be able to administer your account "
        "including creating robots and granting them permissions "
        "to your repositories. You should have absolute trust in the "
        "requesting application before granting this permission."
    ),
)

ORG_ADMIN = Scope(
    scope="org:admin",
    icon="fa-gear",
    dangerous=True,
    title="Administer Organization",
    description=(
        "This application will be able to administer your organizations "
        "including creating robots, creating teams, adjusting team "
        "membership, and changing billing settings. You should have "
        "absolute trust in the requesting application before granting this "
        "permission."
    ),
)

DIRECT_LOGIN = Scope(
    scope="direct_user_login",
    icon="fa-exclamation-triangle",
    dangerous=True,
    title="Full Access",
    description=(
        "This scope should not be available to OAuth applications. "
        "Never approve a request for this scope!"
    ),
)

SUPERUSER = Scope(
    scope="super:user",
    icon="fa-street-view",
    dangerous=True,
    title="Super User Access",
    description=(
        "This application will be able to administer your installation "
        "including managing users, managing organizations and other "
        "features found in the superuser panel. You should have "
        "absolute trust in the requesting application before granting this "
        "permission."
    ),
)
ALL_SCOPES = {
    scope.scope: scope
    for scope in (
        READ_REPO,
        WRITE_REPO,
        ADMIN_REPO,
        CREATE_REPO,
        READ_USER,
        ORG_ADMIN,
        SUPERUSER,
        ADMIN_USER,
    )
}

IMPLIED_SCOPES = {
    ADMIN_REPO: {ADMIN_REPO, WRITE_REPO, READ_REPO},
@@ -97,19 +144,19 @@ IMPLIED_SCOPES = {

def app_scopes(app_config):
    scopes_from_config = dict(ALL_SCOPES)
    if not app_config.get("FEATURE_SUPER_USERS", False):
        del scopes_from_config[SUPERUSER.scope]
    return scopes_from_config


def scopes_from_scope_string(scopes):
    if not scopes:
        scopes = ""

    # Note: The scopes string should be space separated according to the spec:
    # https://tools.ietf.org/html/rfc6749#section-3.3
    # However, we also support commas for backwards compatibility with existing callers to our code.
    scope_set = {ALL_SCOPES.get(scope, None) for scope in re.split(" |,", scopes)}
    return scope_set if not None in scope_set else set()
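Worked examples of the all-or-nothing parse above (assuming ALL_SCOPES as defined in this file): a single unrecognized scope empties the whole set, since it maps to None.

scopes_from_scope_string("repo:read repo:write")   # -> {READ_REPO, WRITE_REPO}
scopes_from_scope_string("repo:read,repo:write")   # commas accepted for backwards compatibility
scopes_from_scope_string("repo:read bogus:scope")  # -> set(), the unknown scope poisons the result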
@@ -135,12 +182,14 @@ def get_scope_information(scopes_string):
    scopes = scopes_from_scope_string(scopes_string)
    scope_info = []
    for scope in scopes:
        scope_info.append(
            {
                "title": scope.title,
                "scope": scope.scope,
                "description": scope.description,
                "icon": scope.icon,
                "dangerous": scope.dangerous,
            }
        )

    return scope_info
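Illustrative output for a single-scope string; every field is read straight off the Scope namedtuples defined earlier in this file:

get_scope_information("repo:read")
# [
#     {
#         "title": "View all visible repositories",
#         "scope": "repo:read",
#         "description": "This application will be able to view and pull all "
#                        "repositories visible to the granting user or robot account",
#         "icon": "fa-hdd-o",
#         "dangerous": False,
#     }
# ]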
@@ -8,18 +8,16 @@ from auth.validateresult import AuthKind, ValidateResult

logger = logging.getLogger(__name__)

# The prefix for all signatures of signed grants.
SIGNATURE_PREFIX = "sigv2="


def generate_signed_token(grants, user_context):
    """ Generates a signed session token with the given grants and user context. """
    ser = SecureCookieSessionInterface().get_signing_serializer(app)
    data_to_sign = {"grants": grants, "user_context": user_context}

    encrypted = ser.dumps(data_to_sign)
    return "{0}{1}".format(SIGNATURE_PREFIX, encrypted)

def validate_signed_grant(auth_header):
@@ -30,14 +28,14 @@ def validate_signed_grant(auth_header):
        return ValidateResult(AuthKind.signed_grant, missing=True)

    # Try to parse the token from the header.
    normalized = [part.strip() for part in auth_header.split(" ") if part]
    if normalized[0].lower() != "token" or len(normalized) != 2:
        logger.debug("Not a token: %s", auth_header)
        return ValidateResult(AuthKind.signed_grant, missing=True)

    # Check that it starts with the expected prefix.
    if not normalized[1].startswith(SIGNATURE_PREFIX):
        logger.debug("Not a signed grant token: %s", auth_header)
        return ValidateResult(AuthKind.signed_grant, missing=True)
    # Decrypt the grant.
@@ -45,11 +43,14 @@ def validate_signed_grant(auth_header):
    ser = SecureCookieSessionInterface().get_signing_serializer(app)

    try:
        token_data = ser.loads(
            encrypted, max_age=app.config["SIGNED_GRANT_EXPIRATION_SEC"]
        )
    except BadSignature:
        logger.warning("Signed grant could not be validated: %s", encrypted)
        return ValidateResult(
            AuthKind.signed_grant, error_message="Signed grant could not be validated"
        )

    logger.debug("Successfully validated signed grant with data: %s", token_data)
    return ValidateResult(AuthKind.signed_grant, signed_data=token_data)
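A hedged round-trip sketch (requires an application context with a configured secret key; the grant payload below is made up):

token = generate_signed_token({"repo": "read"}, {"user": "devtable"})
# token looks like "sigv2=<signed blob>": the prefix plus an itsdangerous-signed payload.

result = validate_signed_grant("token " + token)
assert result.signed_data == {
    "grants": {"repo": "read"},
    "user_context": {"user": "devtable"},
}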
@@ -1,27 +1,37 @@
import pytest

from auth.auth_context_type import (
    SignedAuthContext,
    ValidatedAuthContext,
    ContextEntityKind,
)
from data import model, database

from test.fixtures import *


def get_oauth_token(_):
    return database.OAuthAccessToken.get()


@pytest.mark.parametrize(
    "kind, entity_reference, loader",
    [
        (ContextEntityKind.anonymous, None, None),
        (
            ContextEntityKind.appspecifictoken,
            "%s%s" % ("a" * 60, "b" * 60),
            model.appspecifictoken.access_valid_token,
        ),
        (ContextEntityKind.oauthtoken, None, get_oauth_token),
        (ContextEntityKind.robot, "devtable+dtrobot", model.user.lookup_robot),
        (ContextEntityKind.user, "devtable", model.user.get_user),
    ],
)
@pytest.mark.parametrize("v1_dict_format", [(True), (False)])
def test_signed_auth_context(
    kind, entity_reference, loader, v1_dict_format, initialized_db
):
    if kind == ContextEntityKind.anonymous:
        validated = ValidatedAuthContext()
        assert validated.is_anonymous
@@ -33,15 +43,19 @@ def test_signed_auth_context(kind, entity_reference, loader, v1_dict_format, ini
    assert validated.entity_kind == kind
    assert validated.unique_key

    signed = SignedAuthContext.build_from_signed_dict(
        validated.to_signed_dict(), v1_dict_format=v1_dict_format
    )

    if not v1_dict_format:
        # Under legacy V1 format, we don't track the app specific token, merely its associated user.
        assert signed.entity_kind == kind
        assert signed.description == validated.description
        assert signed.credential_username == validated.credential_username
        assert (
            signed.analytics_id_and_public_metadata()
            == validated.analytics_id_and_public_metadata()
        )
        assert signed.unique_key == validated.unique_key

    assert signed.is_anonymous == validated.is_anonymous
@@ -5,8 +5,11 @@ import pytest
from base64 import b64encode

from auth.basic import validate_basic_auth
from auth.credentials import (
    ACCESS_TOKEN_USERNAME,
    OAUTH_TOKEN_USERNAME,
    APP_SPECIFIC_TOKEN_USERNAME,
)
from auth.validateresult import AuthKind, ValidateResult
from data import model
@@ -16,66 +19,100 @@ from test.fixtures import *

def _token(username, password):
    assert isinstance(username, basestring)
    assert isinstance(password, basestring)
    return "basic " + b64encode("%s:%s" % (username, password))


@pytest.mark.parametrize(
    "token, expected_result",
    [
        ("", ValidateResult(AuthKind.basic, missing=True)),
        ("someinvalidtoken", ValidateResult(AuthKind.basic, missing=True)),
        ("somefoobartoken", ValidateResult(AuthKind.basic, missing=True)),
        ("basic ", ValidateResult(AuthKind.basic, missing=True)),
        ("basic some token", ValidateResult(AuthKind.basic, missing=True)),
        ("basic sometoken", ValidateResult(AuthKind.basic, missing=True)),
        (
            _token(APP_SPECIFIC_TOKEN_USERNAME, "invalid"),
            ValidateResult(AuthKind.basic, error_message="Invalid token"),
        ),
        (
            _token(ACCESS_TOKEN_USERNAME, "invalid"),
            ValidateResult(AuthKind.basic, error_message="Invalid access token"),
        ),
        (
            _token(OAUTH_TOKEN_USERNAME, "invalid"),
            ValidateResult(
                AuthKind.basic,
                error_message="OAuth access token could not be validated",
            ),
        ),
        (
            _token("devtable", "invalid"),
            ValidateResult(
                AuthKind.basic, error_message="Invalid Username or Password"
            ),
        ),
        (
            _token("devtable+somebot", "invalid"),
            ValidateResult(
                AuthKind.basic,
                error_message="Could not find robot with username: devtable+somebot",
            ),
        ),
        (
            _token("disabled", "password"),
            ValidateResult(
                AuthKind.basic,
                error_message="This user has been disabled. Please contact your administrator.",
            ),
        ),
    ],
)
def test_validate_basic_auth_token(token, expected_result, app):
    result = validate_basic_auth(token)
    assert result == expected_result

def test_valid_user(app):
    token = _token("devtable", "password")
    result = validate_basic_auth(token)
    assert result == ValidateResult(
        AuthKind.basic, user=model.user.get_user("devtable")
    )


def test_valid_robot(app):
    robot, password = model.user.create_robot(
        "somerobot", model.user.get_user("devtable")
    )
    token = _token(robot.username, password)
    result = validate_basic_auth(token)
    assert result == ValidateResult(AuthKind.basic, robot=robot)


def test_valid_token(app):
    access_token = model.token.create_delegate_token("devtable", "simple", "sometoken")
    token = _token(ACCESS_TOKEN_USERNAME, access_token.get_code())
    result = validate_basic_auth(token)
    assert result == ValidateResult(AuthKind.basic, token=access_token)


def test_valid_oauth(app):
    user = model.user.get_user("devtable")
    app = model.oauth.list_applications_for_org(
        model.user.get_user_or_org("buynlarge")
    )[0]
    oauth_token, code = model.oauth.create_access_token_for_testing(
        user, app.client_id, "repo:read"
    )
    token = _token(OAUTH_TOKEN_USERNAME, code)
    result = validate_basic_auth(token)
    assert result == ValidateResult(AuthKind.basic, oauthtoken=oauth_token)


def test_valid_app_specific_token(app):
    user = model.user.get_user("devtable")
    app_specific_token = model.appspecifictoken.create_token(user, "some token")
    full_token = model.appspecifictoken.get_full_token_string(app_specific_token)
    token = _token(APP_SPECIFIC_TOKEN_USERNAME, full_token)
    result = validate_basic_auth(token)
@@ -83,16 +120,17 @@ def test_valid_app_specific_token(app):

def test_invalid_unicode(app):
    token = "\xebOH"
    header = "basic " + b64encode(token)
    result = validate_basic_auth(header)
    assert result == ValidateResult(AuthKind.basic, missing=True)


def test_invalid_unicode_2(app):
    token = "“4JPCOLIVMAY32Q3XGVPHC4CBF8SKII5FWNYMASOFDIVSXTC5I5NBU”"
    header = "basic " + b64encode("devtable+somerobot:%s" % token)
    result = validate_basic_auth(header)
    assert result == ValidateResult(
        AuthKind.basic,
        error_message="Could not find robot with username: devtable+somerobot and supplied password.",
    )
@@ -14,20 +14,20 @@ def test_anonymous_cookie(app):

def test_invalidformatted_cookie(app):
    # "Login" with a non-UUID reference.
    someuser = model.user.get_user("devtable")
    login_user(LoginWrappedDBUser("somenonuuid", someuser))

    # Ensure we get an invalid session cookie format error.
    result = validate_session_cookie()
    assert result.authed_user is None
    assert result.context.identity is None
    assert not result.has_nonrobot_user
    assert result.error_message == "Invalid session cookie format"


def test_disabled_user(app):
    # "Login" with a disabled user.
    someuser = model.user.get_user("disabled")
    login_user(LoginWrappedDBUser(someuser.uuid, someuser))

    # Ensure we get an invalid session cookie format error.
@@ -35,12 +35,12 @@ def test_disabled_user(app):
    assert result.authed_user is None
    assert result.context.identity is None
    assert not result.has_nonrobot_user
    assert result.error_message == "User account is disabled"


def test_valid_user(app):
    # Login with a valid user.
    someuser = model.user.get_user("devtable")
    login_user(LoginWrappedDBUser(someuser.uuid, someuser))

    result = validate_session_cookie()
@@ -52,7 +52,7 @@ def test_valid_user(app):

def test_valid_organization(app):
    # "Login" with a valid organization.
    someorg = model.user.get_namespace_user("buynlarge")
    someorg.uuid = str(uuid.uuid4())
    someorg.verified = True
    someorg.save()
@@ -63,4 +63,4 @@ def test_valid_organization(app):
    assert result.authed_user is None
    assert result.context.identity is None
    assert not result.has_nonrobot_user
    assert result.error_message == "Cannot login to organization"

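For context on the mechanical change running through this whole commit: black normalizes string literals to double quotes and reflows calls that exceed the line length. A before/after sketch with hypothetical names, not taken from this diff:

# Before black:
err = 'User account is disabled'
result = validate(first_argument, second_argument, third_argument, keyword_argument=value)

# After black:
err = "User account is disabled"
result = validate(
    first_argument, second_argument, third_argument, keyword_argument=value
)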
View file

@@ -1,147 +1,184 @@
# -*- coding: utf-8 -*-
from auth.credentials import validate_credentials, CredentialKind
from auth.credential_consts import (
    ACCESS_TOKEN_USERNAME,
    OAUTH_TOKEN_USERNAME,
    APP_SPECIFIC_TOKEN_USERNAME,
)
from auth.validateresult import AuthKind, ValidateResult
from data import model
from test.fixtures import *


def test_valid_user(app):
    result, kind = validate_credentials("devtable", "password")
    assert kind == CredentialKind.user
    assert result == ValidateResult(
        AuthKind.credentials, user=model.user.get_user("devtable")
    )


def test_valid_robot(app):
    robot, password = model.user.create_robot(
        "somerobot", model.user.get_user("devtable")
    )
    result, kind = validate_credentials(robot.username, password)
    assert kind == CredentialKind.robot
    assert result == ValidateResult(AuthKind.credentials, robot=robot)


def test_valid_robot_for_disabled_user(app):
    user = model.user.get_user("devtable")
    user.enabled = False
    user.save()

    robot, password = model.user.create_robot("somerobot", user)
    result, kind = validate_credentials(robot.username, password)
    assert kind == CredentialKind.robot

    err = "This user has been disabled. Please contact your administrator."
    assert result == ValidateResult(AuthKind.credentials, error_message=err)


def test_valid_token(app):
    access_token = model.token.create_delegate_token("devtable", "simple", "sometoken")
    result, kind = validate_credentials(ACCESS_TOKEN_USERNAME, access_token.get_code())
    assert kind == CredentialKind.token
    assert result == ValidateResult(AuthKind.credentials, token=access_token)


def test_valid_oauth(app):
    user = model.user.get_user("devtable")
    app = model.oauth.list_applications_for_org(
        model.user.get_user_or_org("buynlarge")
    )[0]
    oauth_token, code = model.oauth.create_access_token_for_testing(
        user, app.client_id, "repo:read"
    )
    result, kind = validate_credentials(OAUTH_TOKEN_USERNAME, code)
    assert kind == CredentialKind.oauth_token
    assert result == ValidateResult(AuthKind.oauth, oauthtoken=oauth_token)


def test_invalid_user(app):
    result, kind = validate_credentials("devtable", "somepassword")
    assert kind == CredentialKind.user
    assert result == ValidateResult(
        AuthKind.credentials, error_message="Invalid Username or Password"
    )


def test_valid_app_specific_token(app):
    user = model.user.get_user("devtable")
    app_specific_token = model.appspecifictoken.create_token(user, "some token")
    full_token = model.appspecifictoken.get_full_token_string(app_specific_token)
    result, kind = validate_credentials(APP_SPECIFIC_TOKEN_USERNAME, full_token)
    assert kind == CredentialKind.app_specific_token
    assert result == ValidateResult(
        AuthKind.credentials, appspecifictoken=app_specific_token
    )


def test_valid_app_specific_token_for_disabled_user(app):
    user = model.user.get_user("devtable")
    user.enabled = False
    user.save()

    app_specific_token = model.appspecifictoken.create_token(user, "some token")
    full_token = model.appspecifictoken.get_full_token_string(app_specific_token)
    result, kind = validate_credentials(APP_SPECIFIC_TOKEN_USERNAME, full_token)
    assert kind == CredentialKind.app_specific_token

    err = "This user has been disabled. Please contact your administrator."
    assert result == ValidateResult(AuthKind.credentials, error_message=err)


def test_invalid_app_specific_token(app):
    result, kind = validate_credentials(APP_SPECIFIC_TOKEN_USERNAME, "somecode")
    assert kind == CredentialKind.app_specific_token
    assert result == ValidateResult(AuthKind.credentials, error_message="Invalid token")


def test_invalid_app_specific_token_code(app):
    user = model.user.get_user("devtable")
    app_specific_token = model.appspecifictoken.create_token(user, "some token")
    full_token = app_specific_token.token_name + "something"
    result, kind = validate_credentials(APP_SPECIFIC_TOKEN_USERNAME, full_token)
    assert kind == CredentialKind.app_specific_token
    assert result == ValidateResult(AuthKind.credentials, error_message="Invalid token")


def test_unicode(app):
    result, kind = validate_credentials("someusername", "some₪code")
    assert kind == CredentialKind.user
    assert not result.auth_valid
    assert result == ValidateResult(
        AuthKind.credentials, error_message="Invalid Username or Password"
    )


def test_unicode_robot(app):
    robot, _ = model.user.create_robot("somerobot", model.user.get_user("devtable"))
    result, kind = validate_credentials(robot.username, "some₪code")
    assert kind == CredentialKind.robot
    assert not result.auth_valid

    msg = (
        "Could not find robot with username: devtable+somerobot and supplied password."
    )
    assert result == ValidateResult(AuthKind.credentials, error_message=msg)


def test_invalid_user(app):
    result, kind = validate_credentials("someinvaliduser", "password")
    assert kind == CredentialKind.user
    assert not result.authed_user
    assert not result.auth_valid


def test_invalid_user_password(app):
    result, kind = validate_credentials("devtable", "somepassword")
    assert kind == CredentialKind.user
    assert not result.authed_user
    assert not result.auth_valid


def test_invalid_robot(app):
    result, kind = validate_credentials("devtable+doesnotexist", "password")
    assert kind == CredentialKind.robot
    assert not result.authed_user
    assert not result.auth_valid


def test_invalid_robot_token(app):
    robot, _ = model.user.create_robot("somerobot", model.user.get_user("devtable"))
    result, kind = validate_credentials(robot.username, "invalidpassword")
    assert kind == CredentialKind.robot
    assert not result.authed_user
    assert not result.auth_valid


def test_invalid_unicode_robot(app):
    token = "“4JPCOLIVMAY32Q3XGVPHC4CBF8SKII5FWNYMASOFDIVSXTC5I5NBU”"
    result, kind = validate_credentials("devtable+somerobot", token)
    assert kind == CredentialKind.robot
    assert not result.auth_valid
    msg = "Could not find robot with username: devtable+somerobot"
    assert result == ValidateResult(AuthKind.credentials, error_message=msg)


def test_invalid_unicode_robot_2(app):
    user = model.user.get_user("devtable")
    robot, password = model.user.create_robot("somerobot", user)

    token = "“4JPCOLIVMAY32Q3XGVPHC4CBF8SKII5FWNYMASOFDIVSXTC5I5NBU”"
    result, kind = validate_credentials("devtable+somerobot", token)
    assert kind == CredentialKind.robot
    assert not result.auth_valid
    msg = (
        "Could not find robot with username: devtable+somerobot and supplied password."
    )
    assert result == ValidateResult(AuthKind.credentials, error_message=msg)
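A sketch of the dispatch these tests exercise (an assumption about the shape of validate_credentials, not Quay's actual code): sentinel usernames imported above select the credential kind, and a "+" in the username marks a robot account.

def classify_credentials(username):
    # Sentinel usernames mean the "password" field carries a token.
    if username == ACCESS_TOKEN_USERNAME:
        return CredentialKind.token
    if username == OAUTH_TOKEN_USERNAME:
        return CredentialKind.oauth_token
    if username == APP_SPECIFIC_TOKEN_USERNAME:
        return CredentialKind.app_specific_token
    if "+" in username:
        # Robot accounts are namespaced as "user+robot", e.g. devtable+somerobot.
        return CredentialKind.robot
    return CredentialKind.user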
View file

@@ -7,7 +7,10 @@ from werkzeug.exceptions import HTTPException
from app import LoginWrappedDBUser
from auth.auth_context import get_authenticated_user
from auth.decorators import (
    extract_namespace_repo_from_session,
    require_session_login,
    process_auth_or_cookie,
)
from data import model
from test.fixtures import *
@@ -30,14 +33,14 @@ def test_extract_namespace_repo_from_session_present(app):
    # Add the namespace and repository to the session.
    session.clear()
    session["namespace"] = "foo"
    session["repository"] = "bar"

    # Call the decorated method.
    extract_namespace_repo_from_session(somefunc)()

    assert encountered[0] == "foo"
    assert encountered[1] == "bar"


def test_require_session_login_missing(app):
@@ -53,7 +56,7 @@ def test_require_session_login_valid_user(app):
        pass

    # Login as a valid user.
    someuser = model.user.get_user("devtable")
    login_user(LoginWrappedDBUser(someuser.uuid, someuser))

    # Call the function.
@@ -68,7 +71,7 @@ def test_require_session_login_invalid_user(app):
        pass

    # "Login" as a disabled user.
    someuser = model.user.get_user("disabled")
    login_user(LoginWrappedDBUser(someuser.uuid, someuser))

    # Call the function.
@@ -95,7 +98,7 @@ def test_process_auth_or_cookie_valid_user(app):
        pass

    # Login as a valid user.
    someuser = model.user.get_user("devtable")
    login_user(LoginWrappedDBUser(someuser.uuid, someuser))

    # Call the function.
View file

@@ -6,50 +6,63 @@ from data import model
from test.fixtures import *


@pytest.mark.parametrize(
    "header, expected_result",
    [
        ("", ValidateResult(AuthKind.oauth, missing=True)),
        ("somerandomtoken", ValidateResult(AuthKind.oauth, missing=True)),
        ("bearer some random token", ValidateResult(AuthKind.oauth, missing=True)),
        (
            "bearer invalidtoken",
            ValidateResult(
                AuthKind.oauth,
                error_message="OAuth access token could not be validated",
            ),
        ),
    ],
)
def test_bearer(header, expected_result, app):
    assert validate_bearer_auth(header) == expected_result


def test_valid_oauth(app):
    user = model.user.get_user("devtable")
    app = model.oauth.list_applications_for_org(
        model.user.get_user_or_org("buynlarge")
    )[0]
    token_string = "%s%s" % ("a" * 20, "b" * 20)
    oauth_token, _ = model.oauth.create_access_token_for_testing(
        user, app.client_id, "repo:read", access_token=token_string
    )
    result = validate_bearer_auth("bearer " + token_string)
    assert result.context.oauthtoken == oauth_token
    assert result.authed_user == user
    assert result.auth_valid


def test_disabled_user_oauth(app):
    user = model.user.get_user("disabled")
    token_string = "%s%s" % ("a" * 20, "b" * 20)
    oauth_token, _ = model.oauth.create_access_token_for_testing(
        user, "deadbeef", "repo:admin", access_token=token_string
    )

    result = validate_bearer_auth("bearer " + token_string)
    assert result.context.oauthtoken is None
    assert result.authed_user is None
    assert not result.auth_valid
    assert result.error_message == "Granter of the oauth access token is disabled"


def test_expired_token(app):
    user = model.user.get_user("devtable")
    token_string = "%s%s" % ("a" * 20, "b" * 20)
    oauth_token, _ = model.oauth.create_access_token_for_testing(
        user, "deadbeef", "repo:admin", access_token=token_string, expires_in=-1000
    )

    result = validate_bearer_auth("bearer " + token_string)
    assert result.context.oauthtoken is None
    assert result.authed_user is None
    assert not result.auth_valid
    assert result.error_message == "OAuth access token has expired"
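A sketch of the header parsing these cases imply (an assumption consistent with the parametrized expectations above; not Quay's actual implementation): exactly two space-separated parts with a "bearer" scheme, otherwise the header is treated as missing.

def parse_bearer(header):
    parts = header.split(" ")
    if len(parts) != 2 or parts[0].lower() != "bearer":
        return None  # "missing" per the ValidateResult expectations above
    return parts[1]  # the token itself, validated separately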
View file

@@ -6,8 +6,9 @@ from data import model
from test.fixtures import *

SUPER_USERNAME = "devtable"
UNSUPER_USERNAME = "freshuser"


@pytest.fixture()
def superuser(initialized_db):
View file

@@ -14,24 +14,20 @@ from initdb import setup_database_for_testing, finished_database_for_testing
from util.morecollections import AttrDict
from util.security.registry_jwt import ANONYMOUS_SUB, build_context_and_subject

TEST_AUDIENCE = app.config["SERVER_HOSTNAME"]
TEST_USER = AttrDict({"username": "joeuser", "uuid": "foobar", "enabled": True})
MAX_SIGNED_S = 3660
TOKEN_VALIDITY_LIFETIME_S = 60 * 60  # 1 hour
ANONYMOUS_SUB = "(anonymous)"
SERVICE_NAME = "quay"

# This import has to come below any references to "app".
from test.fixtures import *


def _access(typ="repository", name="somens/somerepo", actions=None):
    actions = [] if actions is None else actions
    return [{"type": typ, "name": name, "actions": actions}]


def _delete_field(token_data, field_name):
@@ -39,19 +35,28 @@ def _delete_field(token_data, field_name):
    return token_data


def _token_data(
    access=[],
    context=None,
    audience=TEST_AUDIENCE,
    user=TEST_USER,
    iat=None,
    exp=None,
    nbf=None,
    iss=None,
    subject=None,
):
    if subject is None:
        _, subject = build_context_and_subject(ValidatedAuthContext(user=user))
    return {
        "iss": iss or instance_keys.service_name,
        "aud": audience,
        "nbf": nbf if nbf is not None else int(time.time()),
        "iat": iat if iat is not None else int(time.time()),
        "exp": exp if exp is not None else int(time.time() + TOKEN_VALIDITY_LIFETIME_S),
        "sub": subject,
        "access": access,
        "context": context,
    }

@@ -62,13 +67,15 @@ def _token(token_data, key_id=None, private_key=None, skip_header=False, alg=None):
    if alg == "none":
        private_key = None

    token_headers = {"kid": key_id}

    if skip_header:
        token_headers = {}

    token_data = jwt.encode(
        token_data, private_key, alg or "RS256", headers=token_headers
    )
    return "Bearer {0}".format(token_data)


def _parse_token(token):
@@ -78,65 +85,92 @@ def _parse_token(token):

def test_accepted_token(initialized_db):
    token = _token(_token_data())
    identity = _parse_token(token)
    assert identity.id == TEST_USER.username, "should be %s, but was %s" % (
        TEST_USER.username,
        identity.id,
    )
    assert len(identity.provides) == 0

    anon_token = _token(_token_data(user=None))
    anon_identity = _parse_token(anon_token)
    assert anon_identity.id == ANONYMOUS_SUB, "should be %s, but was %s" % (
        ANONYMOUS_SUB,
        anon_identity.id,
    )
    assert len(identity.provides) == 0


@pytest.mark.parametrize(
    "access",
    [
        (_access(actions=["pull", "push"])),
        (_access(actions=["pull", "*"])),
        (_access(actions=["*", "push"])),
        (_access(actions=["*"])),
        (_access(actions=["pull", "*", "push"])),
    ],
)
def test_token_with_access(access, initialized_db):
    token = _token(_token_data(access=access))
    identity = _parse_token(token)
    assert identity.id == TEST_USER.username, "should be %s, but was %s" % (
        TEST_USER.username,
        identity.id,
    )
    assert len(identity.provides) == 1

    role = list(identity.provides)[0][3]
    if "*" in access[0]["actions"]:
        assert role == "admin"
    elif "push" in access[0]["actions"]:
        assert role == "write"
    elif "pull" in access[0]["actions"]:
        assert role == "read"


@pytest.mark.parametrize(
    "token",
    [
        pytest.param(
            _token(
                _token_data(
                    access=[
                        {
                            "toipe": "repository",
                            "namesies": "somens/somerepo",
                            "akshuns": ["pull", "push", "*"],
                        }
                    ]
                )
            ),
            id="bad access",
        ),
        pytest.param(_token(_token_data(audience="someotherapp")), id="bad aud"),
        pytest.param(_token(_delete_field(_token_data(), "aud")), id="no aud"),
        pytest.param(_token(_token_data(nbf=int(time.time()) + 600)), id="future nbf"),
        pytest.param(_token(_delete_field(_token_data(), "nbf")), id="no nbf"),
        pytest.param(_token(_token_data(iat=int(time.time()) + 600)), id="future iat"),
        pytest.param(_token(_delete_field(_token_data(), "iat")), id="no iat"),
        pytest.param(
            _token(_token_data(exp=int(time.time()) + MAX_SIGNED_S * 2)),
            id="exp too long",
        ),
        pytest.param(_token(_token_data(exp=int(time.time()) - 60)), id="expired"),
        pytest.param(_token(_delete_field(_token_data(), "exp")), id="no exp"),
        pytest.param(_token(_delete_field(_token_data(), "sub")), id="no sub"),
        pytest.param(_token(_token_data(iss="badissuer")), id="bad iss"),
        pytest.param(_token(_delete_field(_token_data(), "iss")), id="no iss"),
        pytest.param(_token(_token_data(), skip_header=True), id="no header"),
        pytest.param(_token(_token_data(), key_id="someunknownkey"), id="bad key"),
        pytest.param(_token(_token_data(), key_id="kid7"), id="bad key :: kid7"),
        pytest.param(
            _token(_token_data(), alg="none", private_key=None), id="none alg"
        ),
        pytest.param("some random token", id="random token"),
        pytest.param("Bearer: sometokenhere", id="extra bearer"),
        pytest.param("\nBearer: dGVzdA", id="leading newline"),
    ],
)
def test_invalid_jwt(token, initialized_db):
    with pytest.raises(InvalidJWTException):
        _parse_token(token)

@@ -146,34 +180,39 @@ def test_mixing_keys_e2e(initialized_db):
    token_data = _token_data()

    # Create a new key for testing.
    p, key = model.service_keys.generate_service_key(
        instance_keys.service_name, None, kid="newkey", name="newkey", metadata={}
    )
    private_key = p.exportKey("PEM")

    # Test first with the new valid, but unapproved key.
    unapproved_key_token = _token(token_data, key_id="newkey", private_key=private_key)
    with pytest.raises(InvalidJWTException):
        _parse_token(unapproved_key_token)

    # Approve the key and try again.
    admin_user = model.user.get_user("devtable")
    model.service_keys.approve_service_key(
        key.kid, ServiceKeyApprovalType.SUPERUSER, approver=admin_user
    )

    valid_token = _token(token_data, key_id="newkey", private_key=private_key)

    identity = _parse_token(valid_token)
    assert identity.id == TEST_USER.username
    assert len(identity.provides) == 0

    # Try using a different private key with the existing key ID.
    bad_private_token = _token(
        token_data, key_id="newkey", private_key=instance_keys.local_private_key
    )
    with pytest.raises(InvalidJWTException):
        _parse_token(bad_private_token)

    # Try using a different key ID with the existing private key.
    kid_mismatch_token = _token(
        token_data, key_id=instance_keys.local_key_id, private_key=private_key
    )
    with pytest.raises(InvalidJWTException):
        _parse_token(kid_mismatch_token)

@@ -181,7 +220,7 @@ def test_mixing_keys_e2e(initialized_db):
    key.delete_instance(recursive=True)

    # Ensure it still works (via the cache.)
    deleted_key_token = _token(token_data, key_id="newkey", private_key=private_key)
    identity = _parse_token(deleted_key_token)
    assert identity.id == TEST_USER.username
    assert len(identity.provides) == 0

@@ -194,10 +233,7 @@ def test_mixing_keys_e2e(initialized_db):
        _parse_token(deleted_key_token)


@pytest.mark.parametrize("token", [u"someunicodetoken✡", u"\xc9\xad\xbd"])
def test_unicode_token(token):
    with pytest.raises(InvalidJWTException):
        _parse_token(token)
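For readers unfamiliar with the PyJWT calls above: a self-contained round trip through jwt.encode/jwt.decode. The symmetric key and claim values here are hypothetical stand-ins; the tests above sign with the registry's RS256 service keys instead.

import time
import jwt  # PyJWT

SECRET = "not-a-real-key"  # illustration only

claims = {
    "iss": "quay",
    "aud": "registry.example.com",
    "iat": int(time.time()),
    "exp": int(time.time()) + 3600,
    "sub": "joeuser",
}
encoded = jwt.encode(claims, SECRET, algorithm="HS256")
# decode() verifies the signature, expiry, and (when given) the audience.
decoded = jwt.decode(
    encoded, SECRET, algorithms=["HS256"], audience="registry.example.com"
)
assert decoded["sub"] == "joeuser"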
View file

@@ -1,34 +1,35 @@
import pytest

from auth.scopes import (
    scopes_from_scope_string,
    validate_scope_string,
    ALL_SCOPES,
    is_subset_string,
)


@pytest.mark.parametrize(
    "scopes_string, expected",
    [
        # Valid single scopes.
        ("repo:read", ["repo:read"]),
        ("repo:admin", ["repo:admin"]),
        # Invalid scopes.
        ("not:valid", []),
        ("repo:admins", []),
        # Valid scope strings.
        ("repo:read repo:admin", ["repo:read", "repo:admin"]),
        ("repo:read,repo:admin", ["repo:read", "repo:admin"]),
        ("repo:read,repo:admin repo:write", ["repo:read", "repo:admin", "repo:write"]),
        # Partially invalid scopes.
        ("repo:read,not:valid", []),
        ("repo:read repo:admins", []),
        # Invalid scope strings.
        ("repo:read|repo:admin", []),
        # Mixture of delimiters.
        ("repo:read, repo:admin", []),
    ],
)
def test_parsing(scopes_string, expected):
    expected_scope_set = {ALL_SCOPES[scope_name] for scope_name in expected}
    parsed_scope_set = scopes_from_scope_string(scopes_string)
@@ -36,15 +37,19 @@ def test_parsing(scopes_string, expected):
    assert validate_scope_string(scopes_string) == bool(expected)


@pytest.mark.parametrize(
    "superset, subset, result",
    [
        ("repo:read", "repo:read", True),
        ("repo:read repo:admin", "repo:read", True),
        ("repo:read,repo:admin", "repo:read", True),
        ("repo:read,repo:admin", "repo:admin", True),
        ("repo:read,repo:admin", "repo:admin repo:read", True),
        ("", "repo:read", False),
        ("unknown:tag", "repo:read", False),
        ("repo:read unknown:tag", "repo:read", False),
        ("repo:read,unknown:tag", "repo:read", False),
    ],
)
def test_subset_string(superset, subset, result):
    assert is_subset_string(superset, subset) == result
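A subset check like the one exercised above reduces to set containment once the scope strings are parsed. A sketch under that assumption (not Quay's actual implementation; parse stands in for scopes_from_scope_string, which returns an empty set for any invalid input):

def is_subset_string_sketch(superset, subset, parse):
    superset_scopes = parse(superset)
    subset_scopes = parse(subset)
    # An invalid or empty superset, per the cases above, can never contain a
    # valid subset; otherwise plain set containment decides.
    return bool(subset_scopes) and subset_scopes.issubset(superset_scopes)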
View file

@@ -1,32 +1,47 @@
import pytest

from auth.signedgrant import (
    validate_signed_grant,
    generate_signed_token,
    SIGNATURE_PREFIX,
)
from auth.validateresult import AuthKind, ValidateResult


@pytest.mark.parametrize(
    "header, expected_result",
    [
        pytest.param(
            "", ValidateResult(AuthKind.signed_grant, missing=True), id="Missing"
        ),
        pytest.param(
            "somerandomtoken",
            ValidateResult(AuthKind.signed_grant, missing=True),
            id="Invalid header",
        ),
        pytest.param(
            "token somerandomtoken",
            ValidateResult(AuthKind.signed_grant, missing=True),
            id="Random Token",
        ),
        pytest.param(
            "token " + SIGNATURE_PREFIX + "foo",
            ValidateResult(
                AuthKind.signed_grant,
                error_message="Signed grant could not be validated",
            ),
            id="Invalid token",
        ),
    ],
)
def test_token(header, expected_result):
    assert validate_signed_grant(header) == expected_result


def test_valid_grant():
    header = "token " + generate_signed_token({"a": "b"}, {"c": "d"})
    expected = ValidateResult(
        AuthKind.signed_grant,
        signed_data={"grants": {"a": "b"}, "user_context": {"c": "d"}},
    )
    assert validate_signed_grant(header) == expected

@ -6,34 +6,44 @@ from data import model
from data.database import AppSpecificAuthToken from data.database import AppSpecificAuthToken
from test.fixtures import * from test.fixtures import *
def get_user(): def get_user():
return model.user.get_user('devtable') return model.user.get_user("devtable")
def get_app_specific_token(): def get_app_specific_token():
return AppSpecificAuthToken.get() return AppSpecificAuthToken.get()
def get_robot(): def get_robot():
robot, _ = model.user.create_robot('somebot', get_user()) robot, _ = model.user.create_robot("somebot", get_user())
return robot return robot
def get_token(): def get_token():
return model.token.create_delegate_token('devtable', 'simple', 'sometoken') return model.token.create_delegate_token("devtable", "simple", "sometoken")
def get_oauthtoken(): def get_oauthtoken():
user = model.user.get_user('devtable') user = model.user.get_user("devtable")
return list(model.oauth.list_access_tokens_for_user(user))[0] return list(model.oauth.list_access_tokens_for_user(user))[0]
def get_signeddata():
return {'grants': {'a': 'b'}, 'user_context': {'c': 'd'}}
@pytest.mark.parametrize('get_entity,entity_kind', [ def get_signeddata():
(get_user, 'user'), return {"grants": {"a": "b"}, "user_context": {"c": "d"}}
(get_robot, 'robot'),
(get_token, 'token'),
(get_oauthtoken, 'oauthtoken'), @pytest.mark.parametrize(
(get_signeddata, 'signed_data'), "get_entity,entity_kind",
(get_app_specific_token, 'appspecifictoken'), [
]) (get_user, "user"),
(get_robot, "robot"),
(get_token, "token"),
(get_oauthtoken, "oauthtoken"),
(get_signeddata, "signed_data"),
(get_app_specific_token, "appspecifictoken"),
],
)
def test_apply_context(get_entity, entity_kind, app): def test_apply_context(get_entity, entity_kind, app):
assert get_authenticated_context() is None assert get_authenticated_context() is None
@ -44,17 +54,17 @@ def test_apply_context(get_entity, entity_kind, app):
result = ValidateResult(AuthKind.basic, **args) result = ValidateResult(AuthKind.basic, **args)
result.apply_to_context() result.apply_to_context()
expected_user = entity if entity_kind == 'user' or entity_kind == 'robot' else None expected_user = entity if entity_kind == "user" or entity_kind == "robot" else None
if entity_kind == 'oauthtoken': if entity_kind == "oauthtoken":
expected_user = entity.authorized_user expected_user = entity.authorized_user
if entity_kind == 'appspecifictoken': if entity_kind == "appspecifictoken":
expected_user = entity.user expected_user = entity.user
expected_token = entity if entity_kind == 'token' else None expected_token = entity if entity_kind == "token" else None
expected_oauth = entity if entity_kind == 'oauthtoken' else None expected_oauth = entity if entity_kind == "oauthtoken" else None
expected_appspecifictoken = entity if entity_kind == 'appspecifictoken' else None expected_appspecifictoken = entity if entity_kind == "appspecifictoken" else None
expected_grant = entity if entity_kind == 'signed_data' else None expected_grant = entity if entity_kind == "signed_data" else None
assert get_authenticated_context().authed_user == expected_user assert get_authenticated_context().authed_user == expected_user
assert get_authenticated_context().token == expected_token assert get_authenticated_context().token == expected_token

View file

@@ -3,22 +3,39 @@ from auth.auth_context_type import ValidatedAuthContext, ContextEntityKind


class AuthKind(Enum):
    cookie = "cookie"
    basic = "basic"
    oauth = "oauth"
    signed_grant = "signed_grant"
    credentials = "credentials"


class ValidateResult(object):
    """ A result of validating auth in one form or another. """

    def __init__(
        self,
        kind,
        missing=False,
        user=None,
        token=None,
        oauthtoken=None,
        robot=None,
        appspecifictoken=None,
        signed_data=None,
        error_message=None,
    ):
        self.kind = kind
        self.missing = missing
        self.error_message = error_message
        self.context = ValidatedAuthContext(
            user=user,
            token=token,
            oauthtoken=oauthtoken,
            robot=robot,
            appspecifictoken=appspecifictoken,
            signed_data=signed_data,
        )

    def tuple(self):
        return (self.kind, self.missing, self.error_message, self.context.tuple())

@@ -32,13 +49,18 @@ class ValidateResult(object):

    def with_kind(self, kind):
        """ Returns a copy of this result, but with the kind replaced. """
        result = ValidateResult(
            kind, missing=self.missing, error_message=self.error_message
        )
        result.context = self.context
        return result

    def __repr__(self):
        return "ValidateResult: %s (missing: %s, error: %s)" % (
            self.kind,
            self.missing,
            self.error_message,
        )

    @property
    def authed_user(self):
View file

@@ -6,14 +6,18 @@ from requests.exceptions import RequestException

logger = logging.getLogger(__name__)


class Avatar(object):
    def __init__(self, app=None):
        self.app = app
        self.state = self._init_app(app)

    def _init_app(self, app):
        return AVATAR_CLASSES[app.config.get("AVATAR_KIND", "Gravatar")](
            app.config["PREFERRED_URL_SCHEME"],
            app.config["AVATAR_COLORS"],
            app.config["HTTPCLIENT"],
        )

    def __getattr__(self, name):
        return getattr(self.state, name, None)

@@ -21,17 +25,18 @@ class Avatar(object):
class BaseAvatar(object):
    """ Base class for all avatar implementations. """

    def __init__(self, preferred_url_scheme, colors, http_client):
        self.preferred_url_scheme = preferred_url_scheme
        self.colors = colors
        self.http_client = http_client

    def get_mail_html(self, name, email_or_id, size=16, kind="user"):
        """ Returns the full HTML and CSS for viewing the avatar of the given name and email address,
            with an optional size.
        """
        data = self.get_data(name, email_or_id, kind)
        url = self._get_url(data["hash"], size) if kind != "team" else None
        font_size = size - 6

        if url is not None:
@@ -41,12 +46,21 @@ class BaseAvatar(object):
                response = self.http_client.get(url, timeout=5)
                if response.status_code == 200:
                    return """<img src="%s" width="%s" height="%s" alt="%s"
                              style="vertical-align: middle;">""" % (
                        url,
                        size,
                        size,
                        kind,
                    )
            except RequestException:
                logger.exception("Could not retrieve avatar for user %s", name)

        radius = "50%" if kind == "team" else "0%"
        letter = (
            "&Omega;"
            if kind == "team" and data["name"] == "owners"
            else data["name"].upper()[0]
        )

        return """
<span style="width: %spx; height: %spx; background-color: %s; font-size: %spx;
@@ -54,21 +68,31 @@ class BaseAvatar(object):
       vertical-align: middle; text-align: center; color: white; border-radius: %s">
  %s
</span>
""" % (
            size,
            size,
            data["color"],
            font_size,
            size,
            radius,
            letter,
        )

    def get_data_for_user(self, user):
        return self.get_data(
            user.username, user.email, "robot" if user.robot else "user"
        )

    def get_data_for_team(self, team):
        return self.get_data(team.name, team.name, "team")

    def get_data_for_org(self, org):
        return self.get_data(org.username, org.email, "org")

    def get_data_for_external_user(self, external_user):
        return self.get_data(external_user.username, external_user.email, "user")

    def get_data(self, name, email_or_id, kind="user"):
        """ Computes and returns the full data block for the avatar:
            {
              'name': name,
@@ -87,12 +111,7 @@ class BaseAvatar(object):
        byte_data = hash_value[0:byte_count]
        hash_color = colors[int(byte_data, 16) % len(colors)]

        return {"name": name, "hash": hash_value, "color": hash_color, "kind": kind}

    def _get_url(self, hash_value, size):
        """ Returns the URL for displaying the overlay avatar. """
@@ -101,15 +120,19 @@ class BaseAvatar(object):
class GravatarAvatar(BaseAvatar):
    """ Avatar system that uses gravatar for generating avatars. """

    def _get_url(self, hash_value, size=16):
        return "%s://www.gravatar.com/avatar/%s?d=404&size=%s" % (
            self.preferred_url_scheme,
            hash_value,
            size,
        )


class LocalAvatar(BaseAvatar):
    """ Avatar system that uses the local system for generating avatars. """

    pass


AVATAR_CLASSES = {"gravatar": GravatarAvatar, "local": LocalAvatar}
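On the hash that _get_url interpolates into the gravatar.com URL: Gravatar's documented convention keys avatars on the MD5 hex digest of the trimmed, lowercased email address. A sketch assuming this file's get_data computes something equivalent:

import hashlib

def gravatar_hash(email_or_id):
    # MD5 of the normalized address, per the Gravatar URL convention.
    return hashlib.md5(email_or_id.strip().lower().encode("utf-8")).hexdigest()

# e.g. "https://www.gravatar.com/avatar/%s?d=404&size=16" % gravatar_hash("user@example.com")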
boot.py
View file
@@ -24,44 +24,46 @@ logger = logging.getLogger(__name__)

@lru_cache(maxsize=1)
def get_audience():
    audience = app.config.get("JWTPROXY_AUDIENCE")

    if audience:
        return audience

    scheme = app.config.get("PREFERRED_URL_SCHEME")
    hostname = app.config.get("SERVER_HOSTNAME")

    # hostname includes port, use that
    if ":" in hostname:
        return urlunparse((scheme, hostname, "", "", "", ""))

    # no port, guess based on scheme
    if scheme == "https":
        port = "443"
    else:
        port = "80"

    return urlunparse((scheme, hostname + ":" + port, "", "", "", ""))


def _verify_service_key():
    try:
        with open(app.config["INSTANCE_SERVICE_KEY_KID_LOCATION"]) as f:
            quay_key_id = f.read()

        try:
            get_service_key(quay_key_id, approved_only=False)
            assert os.path.exists(app.config["INSTANCE_SERVICE_KEY_LOCATION"])
            return quay_key_id
        except ServiceKeyDoesNotExist:
            logger.exception(
                "Could not find non-expired existing service key %s; creating a new one",
                quay_key_id,
            )
            return None

        # Found a valid service key, so exiting.
    except IOError:
        logger.exception("Could not load existing service key; creating a new one")
        return None

@@ -69,38 +71,43 @@ def setup_jwt_proxy():
    """
    Creates a service key for quay to use in the jwtproxy and generates the JWT proxy configuration.
    """
    if os.path.exists(os.path.join(CONF_DIR, "jwtproxy_conf.yaml")):
        # Proxy is already setup. Make sure the service key is still valid.
        quay_key_id = _verify_service_key()
        if quay_key_id is not None:
            return

    # Ensure we have an existing key if in read-only mode.
    if app.config.get("REGISTRY_STATE", "normal") == "readonly":
        quay_key_id = _verify_service_key()
        if quay_key_id is None:
            raise Exception("No valid service key found for read-only registry.")
    else:
        # Generate the key for this Quay instance to use.
        minutes_until_expiration = app.config.get(
            "INSTANCE_SERVICE_KEY_EXPIRATION", 120
        )
        expiration = datetime.now() + timedelta(minutes=minutes_until_expiration)
        quay_key, quay_key_id = generate_key(
            app.config["INSTANCE_SERVICE_KEY_SERVICE"],
            get_audience(),
            expiration_date=expiration,
        )

        with open(app.config["INSTANCE_SERVICE_KEY_KID_LOCATION"], mode="w") as f:
            f.truncate(0)
            f.write(quay_key_id)

        with open(app.config["INSTANCE_SERVICE_KEY_LOCATION"], mode="w") as f:
            f.truncate(0)
            f.write(quay_key.exportKey())

    # Generate the JWT proxy configuration.
    audience = get_audience()
    registry = audience + "/keys"
    security_issuer = app.config.get("SECURITY_SCANNER_ISSUER_NAME", "security_scanner")

    with open(os.path.join(CONF_DIR, "jwtproxy_conf.yaml.jnj")) as f:
        template = Template(f.read())
        rendered = template.render(
            conf_dir=CONF_DIR,
@@ -108,16 +115,18 @@ def setup_jwt_proxy():
            registry=registry,
            key_id=quay_key_id,
            security_issuer=security_issuer,
            service_key_location=app.config["INSTANCE_SERVICE_KEY_LOCATION"],
        )

    with open(os.path.join(CONF_DIR, "jwtproxy_conf.yaml"), "w") as f:
        f.write(rendered)


def main():
    if not app.config.get("SETUP_COMPLETE", False):
        raise Exception(
            "Your configuration bundle is either not mounted or setup has not been completed"
        )

    sync_database_with_config(app.config)
    setup_jwt_proxy()
@@ -127,5 +136,5 @@ def main():
    set_region_release(release.SERVICE, release.REGION, release.GIT_HEAD)


if __name__ == "__main__":
    main()
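A minimal standalone round trip through the Jinja2 rendering step boot.py performs above; the template text and variable values are hypothetical stand-ins for jwtproxy_conf.yaml.jnj:

from jinja2 import Template

template = Template("audience: {{ audience }}\nregistry: {{ registry }}\n")
rendered = template.render(
    audience="https://quay.example.com:443",
    registry="https://quay.example.com:443/keys",
)
# `rendered` is the final YAML text that would be written to jwtproxy_conf.yaml.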
View file

@@ -16,6 +16,7 @@ class AsyncWrapper(object):
    """ Wrapper class which will transform a synchronous library to one that can be used with
        trollius coroutines.
    """

    def __init__(self, delegate, loop=None, executor=None):
        self._loop = loop if loop is not None else get_event_loop()
        self._delegate = delegate
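The rest of this wrapper is not shown in the hunk, but the pattern it names is worth sketching: forward attribute access to the delegate, and run any callable on an executor so a trollius coroutine can `yield From(...)` the resulting future. A minimal sketch under those assumptions (the class and names below are illustrative, not taken from this diff):

from functools import partial

from trollius import get_event_loop


class AsyncWrapperSketch(object):
    """ Illustrative sketch of the wrapping pattern described above. """

    def __init__(self, delegate, loop=None, executor=None):
        self._loop = loop if loop is not None else get_event_loop()
        self._delegate = delegate
        self._executor = executor

    def __getattr__(self, attrib):
        delegate_attr = getattr(self._delegate, attrib)
        if not callable(delegate_attr):
            return delegate_attr

        def wrapper(*args, **kwargs):
            # run_in_executor returns a future that a coroutine can yield on.
            return self._loop.run_in_executor(
                self._executor, partial(delegate_attr, *args, **kwargs)
            )

        return wrapper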

View file

@@ -18,80 +18,104 @@ from raven.conf import setup_logging
logger = logging.getLogger(__name__)

BUILD_MANAGERS = {"enterprise": EnterpriseManager, "ephemeral": EphemeralBuilderManager}

EXTERNALLY_MANAGED = "external"

DEFAULT_WEBSOCKET_PORT = 8787
DEFAULT_CONTROLLER_PORT = 8686

LOG_FORMAT = "%(asctime)s [%(process)d] [%(levelname)s] [%(name)s] %(message)s"


def run_build_manager():
    if not features.BUILD_SUPPORT:
        logger.debug("Building is disabled. Please enable the feature flag")
        while True:
            time.sleep(1000)
        return

    if app.config.get("REGISTRY_STATE", "normal") == "readonly":
        logger.debug("Building is disabled while in read-only mode.")
        while True:
            time.sleep(1000)
        return

    build_manager_config = app.config.get("BUILD_MANAGER")
    if build_manager_config is None:
        return

    # If the build system is externally managed, then we just sleep this process.
    if build_manager_config[0] == EXTERNALLY_MANAGED:
        logger.debug("Builds are externally managed.")
        while True:
            time.sleep(1000)
        return

    logger.debug(
        'Asking to start build manager with lifecycle "%s"', build_manager_config[0]
    )
    manager_klass = BUILD_MANAGERS.get(build_manager_config[0])
    if manager_klass is None:
        return

    manager_hostname = os.environ.get(
        "BUILDMAN_HOSTNAME",
        app.config.get("BUILDMAN_HOSTNAME", app.config["SERVER_HOSTNAME"]),
    )
    websocket_port = int(
        os.environ.get(
            "BUILDMAN_WEBSOCKET_PORT",
            app.config.get("BUILDMAN_WEBSOCKET_PORT", DEFAULT_WEBSOCKET_PORT),
        )
    )
    controller_port = int(
        os.environ.get(
            "BUILDMAN_CONTROLLER_PORT",
            app.config.get("BUILDMAN_CONTROLLER_PORT", DEFAULT_CONTROLLER_PORT),
        )
    )

    logger.debug(
        "Will pass buildman hostname %s to builders for websocket connection",
        manager_hostname,
    )

    logger.debug('Starting build manager with lifecycle "%s"', build_manager_config[0])
    ssl_context = None
    if os.environ.get("SSL_CONFIG"):
        logger.debug("Loading SSL cert and key")
        ssl_context = SSLContext()
        ssl_context.load_cert_chain(
            os.path.join(os.environ.get("SSL_CONFIG"), "ssl.cert"),
            os.path.join(os.environ.get("SSL_CONFIG"), "ssl.key"),
        )

    server = BuilderServer(
        app.config["SERVER_HOSTNAME"],
        dockerfile_build_queue,
        build_logs,
        user_files,
        manager_klass,
        build_manager_config[1],
        manager_hostname,
    )
    server.run("0.0.0.0", websocket_port, controller_port, ssl=ssl_context)


if __name__ == "__main__":
    logging.config.fileConfig(logfile_path(debug=True), disable_existing_loggers=False)
    logging.getLogger("peewee").setLevel(logging.WARN)
    logging.getLogger("boto").setLevel(logging.WARN)

    if app.config.get("EXCEPTION_LOG_TYPE", "FakeSentry") == "Sentry":
        buildman_name = "%s:buildman" % socket.gethostname()
        setup_logging(
            SentryHandler(
                app.config.get("SENTRY_DSN", ""),
                name=buildman_name,
                level=logging.ERROR,
            )
        )

    run_build_manager()
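The hostname and port lookups above all share one precedence rule: an environment variable overrides the app config, which overrides the hard-coded default. A minimal sketch of that rule in isolation (the `_resolve` helper is hypothetical, not part of this diff):

import os


def _resolve(key, config, default):
    # Environment variable wins, then the app config, then the default.
    return os.environ.get(key, config.get(key, default))


# Example: with no env var and no config entry, the default is used.
assert _resolve("BUILDMAN_WEBSOCKET_PORT", {}, 8787) == 8787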

View file

@@ -1,7 +1,9 @@
from autobahn.asyncio.wamp import ApplicationSession


class BaseComponent(ApplicationSession):
    """ Base class for all registered component sessions in the server. """

    def __init__(self, config, **kwargs):
        ApplicationSession.__init__(self, config)
        self.server = None

View file

@@ -27,23 +27,27 @@ BUILD_HEARTBEAT_DELAY = datetime.timedelta(seconds=30)
HEARTBEAT_TIMEOUT = 10
INITIAL_TIMEOUT = 25

SUPPORTED_WORKER_VERSIONS = ["0.3"]

# Label which marks a manifest with its source build ID.
INTERNAL_LABEL_BUILD_UUID = "quay.build.uuid"

logger = logging.getLogger(__name__)


class ComponentStatus(object):
    """ ComponentStatus represents the possible states of a component. """

    JOINING = "joining"
    WAITING = "waiting"
    RUNNING = "running"
    BUILDING = "building"
    TIMED_OUT = "timeout"


class BuildComponent(BaseComponent):
    """ An application session component which conducts one (or more) builds. """

    def __init__(self, config, realm=None, token=None, **kwargs):
        self.expected_token = token
        self.builder_realm = realm

@@ -61,34 +65,55 @@ class BuildComponent(BaseComponent):
        BaseComponent.__init__(self, config, **kwargs)

    def kind(self):
        return "builder"

    def onConnect(self):
        self.join(self.builder_realm)

    @trollius.coroutine
    def onJoin(self, details):
        logger.debug(
            "Registering methods and listeners for component %s", self.builder_realm
        )
        yield From(self.register(self._on_ready, u"io.quay.buildworker.ready"))
        yield From(
            self.register(
                self._determine_cache_tag, u"io.quay.buildworker.determinecachetag"
            )
        )
        yield From(self.register(self._ping, u"io.quay.buildworker.ping"))
        yield From(
            self.register(
                self._on_log_message, u"io.quay.builder.logmessagesynchronously"
            )
        )

        yield From(self.subscribe(self._on_heartbeat, u"io.quay.builder.heartbeat"))

        yield From(self._set_status(ComponentStatus.WAITING))

    @trollius.coroutine
    def start_build(self, build_job):
        """ Starts a build. """
        if self._component_status not in (
            ComponentStatus.WAITING,
            ComponentStatus.RUNNING,
        ):
            logger.debug(
                "Could not start build for component %s (build %s, worker version: %s): %s",
                self.builder_realm,
                build_job.repo_build.uuid,
                self._worker_version,
                self._component_status,
            )
            raise Return()

        logger.debug(
            "Starting build for component %s (build %s, worker version: %s)",
            self.builder_realm,
            build_job.repo_build.uuid,
            self._worker_version,
        )

        self._current_job = build_job
        self._build_status = StatusHandler(self.build_logs, build_job.repo_build.uuid)

@@ -97,25 +122,31 @@ class BuildComponent(BaseComponent):
        yield From(self._set_status(ComponentStatus.BUILDING))

        # Send the notification that the build has started.
        build_job.send_notification("build_start")

        # Parse the build configuration.
        try:
            build_config = build_job.build_config
        except BuildJobLoadException as irbe:
            yield From(
                self._build_failure("Could not load build job information", irbe)
            )
            raise Return()

        base_image_information = {}

        # Add the pull robot information, if any.
        if build_job.pull_credentials:
            base_image_information["username"] = build_job.pull_credentials.get(
                "username", ""
            )
            base_image_information["password"] = build_job.pull_credentials.get(
                "password", ""
            )

        # Retrieve the repository's fully qualified name.
        repo = build_job.repo_build.repository
        repository_name = repo.namespace_user.username + "/" + repo.name

        # Parse the build queue item into build arguments.
        #  build_package: URL to the build package to download and untar/unzip.

@@ -131,15 +162,15 @@ class BuildComponent(BaseComponent):
        #  password: The password for pulling the base image (if any).
        context, dockerfile_path = self.extract_dockerfile_args(build_config)
        build_arguments = {
            "build_package": build_job.get_build_package_url(self.user_files),
            "context": context,
            "dockerfile_path": dockerfile_path,
            "repository": repository_name,
            "registry": self.registry_hostname,
            "pull_token": build_job.repo_build.access_token.get_code(),
            "push_token": build_job.repo_build.access_token.get_code(),
            "tag_names": build_config.get("docker_tags", ["latest"]),
            "base_image": base_image_information,
        }

        # If the trigger has a private key, it's using git, thus we should add

@@ -150,39 +181,52 @@ class BuildComponent(BaseComponent):
        # TODO(remove-unenc): Remove legacy field.
        private_key = None
        if (
            build_job.repo_build.trigger is not None
            and build_job.repo_build.trigger.secure_private_key is not None
        ):
            private_key = build_job.repo_build.trigger.secure_private_key.decrypt()

        if (
            ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS)
            and private_key is None
            and build_job.repo_build.trigger is not None
        ):
            private_key = build_job.repo_build.trigger.private_key

        if private_key is not None:
            build_arguments["git"] = {
                "url": build_config["trigger_metadata"].get("git_url", ""),
                "sha": BuildComponent._commit_sha(build_config),
                "private_key": private_key or "",
            }

        # If the build args have no buildpack, mark it as a failure before sending
        # it to a builder instance.
        if not build_arguments["build_package"] and not build_arguments["git"]:
            logger.error(
                "%s: insufficient build args: %s",
                self._current_job.repo_build.uuid,
                build_arguments,
            )
            yield From(
                self._build_failure(
                    "Insufficient build arguments. No buildpack available."
                )
            )
            raise Return()

        # Invoke the build.
        logger.debug("Invoking build: %s", self.builder_realm)
        logger.debug("With Arguments: %s", build_arguments)

        def build_complete_callback(result):
            """ This function is used to execute a coroutine as the callback. """
            trollius.ensure_future(self._build_complete(result))

        self.call("io.quay.builder.build", **build_arguments).add_done_callback(
            build_complete_callback
        )

        # Set the heartbeat for the future. If the builder never receives the build call,
        # then this will cause a timeout after 30 seconds. We know the builder has registered

@@ -191,11 +235,11 @@ class BuildComponent(BaseComponent):
    @staticmethod
    def extract_dockerfile_args(build_config):
        dockerfile_path = build_config.get("build_subdir", "")
        context = build_config.get("context", "")
        if not (dockerfile_path == "" or context == ""):
            # This should not happen and can be removed when we centralize validating build_config
            dockerfile_abspath = slash_join("", dockerfile_path)
            if ".." in os.path.relpath(dockerfile_abspath, context):
                return os.path.split(dockerfile_path)
            dockerfile_path = os.path.relpath(dockerfile_abspath, context)

@@ -205,8 +249,8 @@ class BuildComponent(BaseComponent):
    @staticmethod
    def _commit_sha(build_config):
        """ Determines whether the metadata is using an old schema or not and returns the commit. """
        commit_sha = build_config["trigger_metadata"].get("commit", "")
        old_commit_sha = build_config["trigger_metadata"].get("commit_sha", "")
        return commit_sha or old_commit_sha
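Since `_commit_sha` must accept both trigger-metadata schemas, a small illustration of the two shapes it handles (the SHA value is made up):

# Illustration only: both schema variants resolve to the same commit.
assert BuildComponent._commit_sha({"trigger_metadata": {"commit": "e0f5f2d"}}) == "e0f5f2d"
assert BuildComponent._commit_sha({"trigger_metadata": {"commit_sha": "e0f5f2d"}}) == "e0f5f2d"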
    @staticmethod

@@ -222,8 +266,8 @@ class BuildComponent(BaseComponent):
    def _total_completion(statuses, total_images):
        """ Returns the current amount completion relative to the total completion of a build. """
        percentage_with_sizes = float(len(statuses.values())) / total_images
        sent_bytes = sum([status["current"] for status in statuses.values()])
        total_bytes = sum([status["total"] for status in statuses.values()])
        return float(sent_bytes) / total_bytes * percentage_with_sizes
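A quick worked example of the math above, assuming two of four images have reported size data:

statuses = {
    "img1": {"current": 50, "total": 100},
    "img2": {"current": 25, "total": 100},
}

# (75 sent / 200 total) * (2 of 4 images reporting) = 0.1875
assert BuildComponent._total_completion(statuses, 4) == 0.1875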
    @staticmethod

@@ -233,27 +277,27 @@ class BuildComponent(BaseComponent):
            return

        num_images = 0
        status_completion_key = ""

        if current_phase == "pushing":
            status_completion_key = "push_completion"
            num_images = status_dict["total_commands"]
        elif current_phase == "pulling":
            status_completion_key = "pull_completion"
        elif current_phase == "priming-cache":
            status_completion_key = "cache_completion"
        else:
            return

        if "progressDetail" in docker_data and "id" in docker_data:
            image_id = docker_data["id"]
            detail = docker_data["progressDetail"]

            if "current" in detail and "total" in detail:
                images[image_id] = detail
                status_dict[status_completion_key] = BuildComponent._total_completion(
                    images, max(len(images), num_images)
                )

    @trollius.coroutine
    def _on_log_message(self, phase, json_data):

@@ -270,8 +314,8 @@ class BuildComponent(BaseComponent):
            pass

        # Extract the current status message (if any).
        fully_unwrapped = ""
        keys_to_extract = ["error", "status", "stream"]
        for key in keys_to_extract:
            if key in log_data:
                fully_unwrapped = log_data[key]

@@ -279,7 +323,7 @@ class BuildComponent(BaseComponent):
        # Determine if this is a step string.
        current_step = None
        current_status_string = str(fully_unwrapped.encode("utf-8"))

        if current_status_string and phase == BUILD_PHASE.BUILDING:
            current_step = extract_current_step(current_status_string)

@@ -288,28 +332,42 @@ class BuildComponent(BaseComponent):
        # the pull/push progress, as well as the current step index.
        with self._build_status as status_dict:
            try:
                changed_phase = yield From(
                    self._build_status.set_phase(phase, log_data.get("status_data"))
                )
                if changed_phase:
                    logger.debug(
                        "Build %s has entered a new phase: %s",
                        self.builder_realm,
                        phase,
                    )
                elif self._current_job.repo_build.phase == BUILD_PHASE.CANCELLED:
                    build_id = self._current_job.repo_build.uuid
                    logger.debug(
                        "Trying to move cancelled build into phase: %s with id: %s",
                        phase,
                        build_id,
                    )
                    raise Return(False)
            except InvalidRepositoryBuildException:
                build_id = self._current_job.repo_build.uuid
                logger.warning(
                    "Build %s was not found; repo was probably deleted", build_id
                )
                raise Return(False)

            BuildComponent._process_pushpull_status(
                status_dict, phase, log_data, self._image_info
            )

            # If the current message represents the beginning of a new step, then update the
            # current command index.
            if current_step is not None:
                status_dict["current_command"] = current_step

            # If the json data contains an error, then something went wrong with a push or pull.
            if "error" in log_data:
                yield From(self._build_status.set_error(log_data["error"]))

            if current_step is not None:
                yield From(self._build_status.set_command(current_status_string))

@@ -318,25 +376,36 @@ class BuildComponent(BaseComponent):
        raise Return(True)

    @trollius.coroutine
    def _determine_cache_tag(
        self, command_comments, base_image_name, base_image_tag, base_image_id
    ):
        with self._build_status as status_dict:
            status_dict["total_commands"] = len(command_comments) + 1

        logger.debug(
            "Checking cache on realm %s. Base image: %s:%s (%s)",
            self.builder_realm,
            base_image_name,
            base_image_tag,
            base_image_id,
        )

        tag_found = self._current_job.determine_cached_tag(
            base_image_id, command_comments
        )
        raise Return(tag_found or "")

    @trollius.coroutine
    def _build_failure(self, error_message, exception=None):
        """ Handles and logs a failed build. """
        yield From(
            self._build_status.set_error(
                error_message, {"internal_error": str(exception) if exception else None}
            )
        )

        build_id = self._current_job.repo_build.uuid
        logger.warning("Build %s failed with message: %s", build_id, error_message)

        # Mark that the build has finished (in an error state)
        yield From(self._build_finished(BuildJobResult.ERROR))

@@ -362,60 +431,82 @@ class BuildComponent(BaseComponent):
            try:
                yield From(self._build_status.set_phase(BUILD_PHASE.COMPLETE))
            except InvalidRepositoryBuildException:
                logger.warning(
                    "Build %s was not found; repo was probably deleted", build_id
                )
                raise Return()

            yield From(self._build_finished(BuildJobResult.COMPLETE))

            # Label the pushed manifests with the build metadata.
            manifest_digests = kwargs.get("digests") or []
            repository = registry_model.lookup_repository(
                self._current_job.namespace, self._current_job.repo_name
            )
            if repository is not None:
                for digest in manifest_digests:
                    with UseThenDisconnect(app.config):
                        manifest = registry_model.lookup_manifest_by_digest(
                            repository, digest, require_available=True
                        )
                        if manifest is None:
                            continue

                        registry_model.create_manifest_label(
                            manifest,
                            INTERNAL_LABEL_BUILD_UUID,
                            build_id,
                            "internal",
                            "text/plain",
                        )

            # Send the notification that the build has completed successfully.
            self._current_job.send_notification(
                "build_success",
                image_id=kwargs.get("image_id"),
                manifest_digests=manifest_digests,
            )
        except ApplicationError as aex:
            worker_error = WorkerError(aex.error, aex.kwargs.get("base_error"))

            # Write the error to the log.
            yield From(
                self._build_status.set_error(
                    worker_error.public_message(),
                    worker_error.extra_data(),
                    internal_error=worker_error.is_internal_error(),
                    requeued=self._current_job.has_retries_remaining(),
                )
            )

            # Send the notification that the build has failed.
            self._current_job.send_notification(
                "build_failure", error_message=worker_error.public_message()
            )

            # Mark the build as completed.
            if worker_error.is_internal_error():
                logger.exception(
                    "[BUILD INTERNAL ERROR: Remote] Build ID: %s: %s",
                    build_id,
                    worker_error.public_message(),
                )
                yield From(self._build_finished(BuildJobResult.INCOMPLETE))
            else:
                logger.debug(
                    "Got remote failure exception for build %s: %s", build_id, aex
                )
                yield From(self._build_finished(BuildJobResult.ERROR))

        # Remove the current job.
        self._current_job = None

    @trollius.coroutine
    def _build_finished(self, job_status):
        """ Alerts the parent that a build has completed and sets the status back to running. """
        yield From(
            self.parent_manager.job_completed(self._current_job, job_status, self)
        )

        # Set the component back to a running state.
        yield From(self._set_status(ComponentStatus.RUNNING))

@@ -423,7 +514,7 @@ class BuildComponent(BaseComponent):
    @staticmethod
    def _ping():
        """ Ping pong. """
        return "pong"

    @trollius.coroutine
    def _on_ready(self, token, version):

@@ -431,17 +522,25 @@ class BuildComponent(BaseComponent):
        self._worker_version = version

        if not version in SUPPORTED_WORKER_VERSIONS:
            logger.warning(
                'Build component (token "%s") is running an out-of-date version: %s',
                token,
                version,
            )
            raise Return(False)

        if self._component_status != ComponentStatus.WAITING:
            logger.warning(
                'Build component (token "%s") is already connected', self.expected_token
            )
            raise Return(False)

        if token != self.expected_token:
            logger.warning(
                'Builder token mismatch. Expected: "%s". Found: "%s"',
                self.expected_token,
                token,
            )
            raise Return(False)

        yield From(self._set_status(ComponentStatus.RUNNING))

@@ -449,7 +548,7 @@ class BuildComponent(BaseComponent):
        # Start the heartbeat check and updating loop.
        loop = trollius.get_event_loop()
        loop.create_task(self._heartbeat())
        logger.debug("Build worker %s is connected and ready", self.builder_realm)
        raise Return(True)

    @trollius.coroutine

@@ -464,7 +563,7 @@ class BuildComponent(BaseComponent):
        if self._component_status == ComponentStatus.TIMED_OUT:
            return

        logger.debug("Got heartbeat on realm %s", self.builder_realm)
        self._last_heartbeat = datetime.datetime.utcnow()

    @trollius.coroutine

@@ -477,14 +576,16 @@ class BuildComponent(BaseComponent):
        while True:
            # If the component is no longer running or actively building, nothing more to do.
            if (
                self._component_status != ComponentStatus.RUNNING
                and self._component_status != ComponentStatus.BUILDING
            ):
                raise Return()

            # If there is an active build, write the heartbeat to its status.
            if self._build_status is not None:
                with self._build_status as status_dict:
                    status_dict["heartbeat"] = int(time.time())

            # Mark the build item.
            current_job = self._current_job

@@ -492,17 +593,26 @@ class BuildComponent(BaseComponent):
                yield From(self.parent_manager.job_heartbeat(current_job))

            # Check the heartbeat from the worker.
            logger.debug("Checking heartbeat on realm %s", self.builder_realm)
            if (
                self._last_heartbeat
                and self._last_heartbeat < datetime.datetime.utcnow() - HEARTBEAT_DELTA
            ):
                logger.debug(
                    "Heartbeat on realm %s has expired: %s",
                    self.builder_realm,
                    self._last_heartbeat,
                )

                yield From(self._timeout())
                raise Return()

            logger.debug(
                "Heartbeat on realm %s is valid: %s (%s).",
                self.builder_realm,
                self._last_heartbeat,
                self._component_status,
            )

            yield From(trollius.sleep(HEARTBEAT_TIMEOUT))

@@ -512,19 +622,28 @@ class BuildComponent(BaseComponent):
            raise Return()

        yield From(self._set_status(ComponentStatus.TIMED_OUT))
        logger.warning(
            "Build component with realm %s has timed out", self.builder_realm
        )

        # If we still have a running job, then it has not completed and we need to tell the parent
        # manager.
        if self._current_job is not None:
            yield From(
                self._build_status.set_error(
                    "Build worker timed out",
                    internal_error=True,
                    requeued=self._current_job.has_retries_remaining(),
                )
            )

            build_id = self._current_job.build_uuid
            logger.error("[BUILD INTERNAL ERROR: Timeout] Build ID: %s", build_id)
            yield From(
                self.parent_manager.job_completed(
                    self._current_job, BuildJobResult.INCOMPLETE, self
                )
            )

        # Unregister the current component so that it cannot be invoked again.
        self.parent_manager.build_component_disposed(self, True)

View file

@@ -1,15 +1,16 @@
import re


def extract_current_step(current_status_string):
    """ Attempts to extract the current step numeric identifier from the given status string. Returns the step
        number or None if none.
    """
    # Older format: `Step 12 :`
    # Newer format: `Step 4/13 :`
    step_increment = re.search(r"Step ([0-9]+)/([0-9]+) :", current_status_string)
    if step_increment:
        return int(step_increment.group(1))

    step_increment = re.search(r"Step ([0-9]+) :", current_status_string)
    if step_increment:
        return int(step_increment.group(1))
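For illustration, the two Docker log formats called out in the comments above parse as follows (a small sketch, not part of the diff):

assert extract_current_step("Step 4/13 : ARG somearg=foo") == 4  # newer format
assert extract_current_step("Step 12 : RUN make") == 12  # older format
assert extract_current_step("no step here") is None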

View file

@@ -3,7 +3,9 @@ import pytest
from buildman.component.buildcomponent import BuildComponent


@pytest.mark.parametrize(
    "input,expected_path,expected_file",
    [
        ("", "/", "Dockerfile"),
        ("/", "/", "Dockerfile"),
        ("/Dockerfile", "/", "Dockerfile"),

@@ -14,23 +16,49 @@ from buildman.component.buildcomponent import BuildComponent
        ("/somepath/server.Dockerfile", "/somepath", "server.Dockerfile"),
        ("/somepath/some_other_path", "/somepath/some_other_path", "Dockerfile"),
        ("/somepath/some_other_path/", "/somepath/some_other_path", "Dockerfile"),
        (
            "/somepath/some_other_path/Dockerfile",
            "/somepath/some_other_path",
            "Dockerfile",
        ),
        (
            "/somepath/some_other_path/server.Dockerfile",
            "/somepath/some_other_path",
            "server.Dockerfile",
        ),
    ],
)
def test_path_is_dockerfile(input, expected_path, expected_file):
    actual_path, actual_file = BuildComponent.name_and_path(input)
    assert actual_path == expected_path
    assert actual_file == expected_file


@pytest.mark.parametrize(
    "build_config,context,dockerfile_path",
    [
        ({}, "", ""),
        ({"build_subdir": "/builddir/Dockerfile"}, "", "/builddir/Dockerfile"),
        ({"context": "/builddir"}, "/builddir", ""),
        (
            {"context": "/builddir", "build_subdir": "/builddir/Dockerfile"},
            "/builddir",
            "Dockerfile",
        ),
        (
            {
                "context": "/some_other_dir/Dockerfile",
                "build_subdir": "/builddir/Dockerfile",
            },
            "/builddir",
            "Dockerfile",
        ),
        ({"context": "/", "build_subdir": "Dockerfile"}, "/", "Dockerfile"),
    ],
)
def test_extract_dockerfile_args(build_config, context, dockerfile_path):
    actual_context, actual_dockerfile_path = BuildComponent.extract_dockerfile_args(
        build_config
    )
    assert context == actual_context
    assert dockerfile_path == actual_dockerfile_path

View file

@@ -3,7 +3,9 @@ import pytest
from buildman.component.buildparse import extract_current_step


@pytest.mark.parametrize(
    "input,expected_step",
    [
        ("", None),
        ("Step a :", None),
        ("Step 1 :", 1),

@@ -11,6 +13,7 @@ from buildman.component.buildparse import extract_current_step
        ("Step 1/2 : ", 1),
        ("Step 2/17 : ", 2),
        ("Step 4/13 : ARG somearg=foo", 4),
    ],
)
def test_extract_current_step(input, expected_step):
    assert extract_current_step(input) == expected_step

View file

@@ -1,18 +1,22 @@
from data.database import BUILD_PHASE


class BuildJobResult(object):
    """ Build job result enum """

    INCOMPLETE = "incomplete"
    COMPLETE = "complete"
    ERROR = "error"


class BuildServerStatus(object):
    """ Build server status enum """

    STARTING = "starting"
    RUNNING = "running"
    SHUTDOWN = "shutting_down"
    EXCEPTION = "exception"


RESULT_PHASES = {
    BuildJobResult.INCOMPLETE: BUILD_PHASE.INTERNAL_ERROR,

View file

@@ -15,11 +15,13 @@ logger = logging.getLogger(__name__)
class BuildJobLoadException(Exception):
    """ Exception raised if a build job could not be instantiated for some reason. """

    pass


class BuildJob(object):
    """ Represents a single in-progress build job. """

    def __init__(self, job_item):
        self.job_item = job_item

@@ -28,7 +30,8 @@ class BuildJob(object):
            self.build_notifier = BuildJobNotifier(self.build_uuid)
        except ValueError:
            raise BuildJobLoadException(
                "Could not parse build queue item config with ID %s"
                % self.job_details["build_uuid"]
            )

    @property

@@ -38,8 +41,12 @@ class BuildJob(object):
    def has_retries_remaining(self):
        return self.job_item.retries_remaining > 0

    def send_notification(
        self, kind, error_message=None, image_id=None, manifest_digests=None
    ):
        self.build_notifier.send_notification(
            kind, error_message, image_id, manifest_digests
        )

    @lru_cache(maxsize=1)
    def _load_repo_build(self):

@@ -48,12 +55,13 @@ class BuildJob(object):
            return model.build.get_repository_build(self.build_uuid)
        except model.InvalidRepositoryBuildException:
            raise BuildJobLoadException(
                "Could not load repository build with ID %s" % self.build_uuid
            )

    @property
    def build_uuid(self):
        """ Returns the unique UUID for this build job. """
        return self.job_details["build_uuid"]

    @property
    def namespace(self):

@@ -71,19 +79,21 @@ class BuildJob(object):
    def get_build_package_url(self, user_files):
        """ Returns the URL of the build package for this build, if any or empty string if none. """
        archive_url = self.build_config.get("archive_url", None)
        if archive_url:
            return archive_url

        if not self.repo_build.resource_key:
            return ""

        return user_files.get_file_url(
            self.repo_build.resource_key, "127.0.0.1", requires_cors=False
        )

    @property
    def pull_credentials(self):
        """ Returns the pull credentials for this job, or None if none. """
        return self.job_details.get("pull_credentials")

    @property
    def build_config(self):

@@ -91,13 +101,19 @@ class BuildJob(object):
            return json.loads(self.repo_build.job_config)
        except ValueError:
            raise BuildJobLoadException(
                "Could not parse repository build job config with ID %s"
                % self.job_details["build_uuid"]
            )

    def determine_cached_tag(self, base_image_id=None, cache_comments=None):
        """ Returns the tag to pull to prime the cache or None if none. """
        cached_tag = self._determine_cached_tag_by_tag()
        logger.debug(
            "Determined cached tag %s for %s: %s",
            cached_tag,
            base_image_id,
            cache_comments,
        )
        return cached_tag

    def _determine_cached_tag_by_tag(self):

@@ -105,7 +121,7 @@ class BuildJob(object):
            exists in the repository. This is a fallback for when no comment information is available.
        """
        with UseThenDisconnect(app.config):
            tags = self.build_config.get("docker_tags", ["latest"])
            repository = RepositoryReference.for_repo_obj(self.repo_build.repository)
            matching_tag = registry_model.find_matching_tag(repository, tags)
            if matching_tag is not None:

@@ -134,7 +150,8 @@ class BuildJobNotifier(object):
            return model.build.get_repository_build(self.build_uuid)
        except model.InvalidRepositoryBuildException:
            raise BuildJobLoadException(
                "Could not load repository build with ID %s" % self.build_uuid
            )

    @property
    def build_config(self):

@@ -142,12 +159,15 @@ class BuildJobNotifier(object):
            return json.loads(self.repo_build.job_config)
        except ValueError:
            raise BuildJobLoadException(
                "Could not parse repository build job config with ID %s"
                % self.repo_build.uuid
            )

    def send_notification(
        self, kind, error_message=None, image_id=None, manifest_digests=None
    ):
        with UseThenDisconnect(app.config):
            tags = self.build_config.get("docker_tags", ["latest"])
            trigger = self.repo_build.trigger
            if trigger is not None and trigger.id is not None:
                trigger_kind = trigger.service.name

@@ -155,29 +175,35 @@ class BuildJobNotifier(object):
                trigger_kind = None

            event_data = {
                "build_id": self.repo_build.uuid,
                "build_name": self.repo_build.display_name,
                "docker_tags": tags,
                "trigger_id": trigger.uuid if trigger is not None else None,
                "trigger_kind": trigger_kind,
                "trigger_metadata": self.build_config.get("trigger_metadata", {}),
            }

            if image_id is not None:
                event_data["image_id"] = image_id

            if manifest_digests:
                event_data["manifest_digests"] = manifest_digests

            if error_message is not None:
                event_data["error_message"] = error_message

            # TODO: remove when more endpoints have been converted to using
            # interfaces
            repo = AttrDict(
                {
                    "namespace_name": self.repo_build.repository.namespace_user.username,
                    "name": self.repo_build.repository.name,
                }
            )
            spawn_notification(
                repo,
                kind,
                event_data,
                subpage="build/%s" % self.repo_build.uuid,
                pathargs=["build", self.repo_build.uuid],
            )

View file

@@ -24,10 +24,10 @@ class StatusHandler(object):
        self._build_model = AsyncWrapper(model.build)

        self._status = {
            "total_commands": 0,
            "current_command": None,
            "push_completion": 0.0,
            "pull_completion": 0.0,
        }

        # Write the initial status.

@@ -36,12 +36,18 @@ class StatusHandler(object):
    @coroutine
    def _append_log_message(self, log_message, log_type=None, log_data=None):
        log_data = log_data or {}
        log_data["datetime"] = str(datetime.datetime.now())

        try:
            yield From(
                self._build_logs.append_log_message(
                    self._uuid, log_message, log_type, log_data
                )
            )
        except RedisError:
            logger.exception(
                "Could not save build log for build %s: %s", self._uuid, log_message
            )

    @coroutine
    def append_log(self, log_message, extra_data=None):

@@ -56,16 +62,26 @@ class StatusHandler(object):
            raise Return()

        self._current_command = command
        yield From(
            self._append_log_message(command, self._build_logs.COMMAND, extra_data)
        )

    @coroutine
    def set_error(
        self, error_message, extra_data=None, internal_error=False, requeued=False
    ):
        error_phase = (
            BUILD_PHASE.INTERNAL_ERROR
            if internal_error and requeued
            else BUILD_PHASE.ERROR
        )
        yield From(self.set_phase(error_phase))

        extra_data = extra_data or {}
        extra_data["internal_error"] = internal_error
        yield From(
            self._append_log_message(error_message, self._build_logs.ERROR, extra_data)
        )

    @coroutine
    def set_phase(self, phase, extra_data=None):

@@ -85,4 +101,6 @@ class StatusHandler(object):
        try:
            self._sync_build_logs.set_status(self._uuid, self._status)
        except RedisError:
            logger.exception(
                "Could not set status of build %s to %s", self._uuid, self._status
            )
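Callers elsewhere in this diff use the handler as a context manager (`with self._build_status as status_dict:`), mutating the yielded dict in place. A hedged sketch of that protocol, assuming the exit hook is where the status would be persisted (the class and method bodies here are illustrative only, not copied from the diff):

class StatusDictSketch(object):
    """ Illustrative only: yield the status dict, persist it on exit. """

    def __init__(self):
        self._status = {"total_commands": 0, "current_command": None}

    def __enter__(self):
        return self._status

    def __exit__(self, exc_type, exc_value, traceback):
        # A real handler would push self._status to its log store here.
        pass


handler = StatusDictSketch()
with handler as status_dict:
    status_dict["heartbeat"] = 12345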

View file

@ -1,119 +1,99 @@
class WorkerError(object):
    """ Helper class which represents errors raised by a build worker. """

    def __init__(self, error_code, base_message=None):
        self._error_code = error_code
        self._base_message = base_message

        self._error_handlers = {
            "io.quay.builder.buildpackissue": {
                "message": "Could not load build package",
                "is_internal": True,
            },
            "io.quay.builder.gitfailure": {
                "message": "Could not clone git repository",
                "show_base_error": True,
            },
            "io.quay.builder.gitcheckout": {
                "message": "Could not checkout git ref. If you force pushed recently, "
                + "the commit may be missing.",
                "show_base_error": True,
            },
            "io.quay.builder.cannotextractbuildpack": {
                "message": "Could not extract the contents of the build package"
            },
            "io.quay.builder.cannotpullforcache": {
                "message": "Could not pull cached image",
                "is_internal": True,
            },
            "io.quay.builder.dockerfileissue": {
                "message": "Could not find or parse Dockerfile",
                "show_base_error": True,
            },
            "io.quay.builder.cannotpullbaseimage": {
                "message": "Could not pull base image",
                "show_base_error": True,
            },
            "io.quay.builder.internalerror": {
                "message": "An internal error occurred while building. Please submit a ticket.",
                "is_internal": True,
            },
            "io.quay.builder.buildrunerror": {
                "message": "Could not start the build process",
                "is_internal": True,
            },
            "io.quay.builder.builderror": {
                "message": "A build step failed",
                "show_base_error": True,
            },
            "io.quay.builder.tagissue": {
                "message": "Could not tag built image",
                "is_internal": True,
            },
            "io.quay.builder.pushissue": {
                "message": "Could not push built image",
                "show_base_error": True,
                "is_internal": True,
            },
            "io.quay.builder.dockerconnecterror": {
                "message": "Could not connect to Docker daemon",
                "is_internal": True,
            },
            "io.quay.builder.missingorinvalidargument": {
                "message": "Missing required arguments for builder",
                "is_internal": True,
            },
            "io.quay.builder.cachelookupissue": {
                "message": "Error checking for a cached tag",
                "is_internal": True,
            },
            "io.quay.builder.errorduringphasetransition": {
                "message": "Error during phase transition. If this problem persists "
                + "please contact customer support.",
                "is_internal": True,
            },
            "io.quay.builder.clientrejectedtransition": {
                "message": "Build can not be finished due to user cancellation."
            },
        }

    def is_internal_error(self):
        handler = self._error_handlers.get(self._error_code)
        return handler.get("is_internal", False) if handler else True

    def public_message(self):
        handler = self._error_handlers.get(self._error_code)
        if not handler:
            return "An unknown error occurred"

        message = handler["message"]
        if handler.get("show_base_error", False) and self._base_message:
            message = message + ": " + self._base_message

        return message

    def extra_data(self):
        if self._base_message:
            return {"base_error": self._base_message, "error_code": self._error_code}

        return {"error_code": self._error_code}
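To make the lookup behavior above concrete, a quick usage sketch (the git error text is an invented example): `show_base_error` appends the worker's raw message to the public one, and unknown codes are conservatively treated as internal.

err = WorkerError("io.quay.builder.gitcheckout", base_message="fatal: reference is not a tree")
assert not err.is_internal_error()
assert err.public_message().endswith(": fatal: reference is not a tree")

unknown = WorkerError("some.unknown.code")
assert unknown.is_internal_error()
assert unknown.public_message() == "An unknown error occurred"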

View file

@@ -1,9 +1,18 @@
from trollius import coroutine


class BaseManager(object):
    """ Base for all worker managers. """

    def __init__(
        self,
        register_component,
        unregister_component,
        job_heartbeat_callback,
        job_complete_callback,
        manager_hostname,
        heartbeat_period_sec,
    ):
        self.register_component = register_component
        self.unregister_component = unregister_component
        self.job_heartbeat_callback = job_heartbeat_callback

View file

@@ -5,14 +5,14 @@ from buildman.manager.noop_canceller import NoopCanceller
logger = logging.getLogger(__name__)

CANCELLERS = {"ephemeral": OrchestratorCanceller}


class BuildCanceller(object):
    """ A class to manage cancelling a build """

    def __init__(self, app=None):
        # Guard the config lookup so a missing app falls through to the no-op handler
        # instead of raising an AttributeError.
        self.build_manager_config = app.config.get("BUILD_MANAGER") if app else None
        if app is None or self.build_manager_config is None:
            self.handler = NoopCanceller()
        else:

View file

@@ -7,10 +7,11 @@ from buildman.manager.basemanager import BaseManager
from trollius import From, Return, coroutine

REGISTRATION_REALM = "registration"
RETRY_TIMEOUT = 5
logger = logging.getLogger(__name__)


class DynamicRegistrationComponent(BaseComponent):
    """ Component session that handles dynamic registration of the builder components. """

@@ -18,16 +19,18 @@ class DynamicRegistrationComponent(BaseComponent):
        self.join(REGISTRATION_REALM)

    def onJoin(self, details):
        logger.debug("Registering registration method")
        yield From(
            self.register(self._worker_register, u"io.quay.buildworker.register")
        )

    def _worker_register(self):
        realm = self.parent_manager.add_build_component()
        logger.debug("Registering new build component+worker with realm %s", realm)
        return realm

    def kind(self):
        return "registration"


class EnterpriseManager(BaseManager):

View file

@@ -8,29 +8,32 @@ class EtcdCanceller(object):
""" A class that sends a message to etcd to cancel a build """ """ A class that sends a message to etcd to cancel a build """
def __init__(self, config): def __init__(self, config):
etcd_host = config.get('ETCD_HOST', '127.0.0.1') etcd_host = config.get("ETCD_HOST", "127.0.0.1")
etcd_port = config.get('ETCD_PORT', 2379) etcd_port = config.get("ETCD_PORT", 2379)
etcd_ca_cert = config.get('ETCD_CA_CERT', None) etcd_ca_cert = config.get("ETCD_CA_CERT", None)
etcd_auth = config.get('ETCD_CERT_AND_KEY', None) etcd_auth = config.get("ETCD_CERT_AND_KEY", None)
if etcd_auth is not None: if etcd_auth is not None:
etcd_auth = tuple(etcd_auth) etcd_auth = tuple(etcd_auth)
etcd_protocol = 'http' if etcd_auth is None else 'https' etcd_protocol = "http" if etcd_auth is None else "https"
logger.debug('Connecting to etcd on %s:%s', etcd_host, etcd_port) logger.debug("Connecting to etcd on %s:%s", etcd_host, etcd_port)
self._cancel_prefix = config.get('ETCD_CANCEL_PREFIX', 'cancel/') self._cancel_prefix = config.get("ETCD_CANCEL_PREFIX", "cancel/")
self._etcd_client = etcd.Client( self._etcd_client = etcd.Client(
host=etcd_host, host=etcd_host,
port=etcd_port, port=etcd_port,
cert=etcd_auth, cert=etcd_auth,
ca_cert=etcd_ca_cert, ca_cert=etcd_ca_cert,
protocol=etcd_protocol, protocol=etcd_protocol,
read_timeout=5) read_timeout=5,
)
def try_cancel_build(self, build_uuid): def try_cancel_build(self, build_uuid):
""" Writes etcd message to cancel build_uuid. """ """ Writes etcd message to cancel build_uuid. """
logger.info("Cancelling build %s".format(build_uuid)) logger.info("Cancelling build %s".format(build_uuid))
try: try:
self._etcd_client.write("{}{}".format(self._cancel_prefix, build_uuid), build_uuid, ttl=60) self._etcd_client.write(
"{}{}".format(self._cancel_prefix, build_uuid), build_uuid, ttl=60
)
return True return True
except etcd.EtcdException: except etcd.EtcdException:
logger.exception("Failed to write to etcd client %s", build_uuid) logger.exception("Failed to write to etcd client %s", build_uuid)

View file

@@ -35,12 +35,14 @@
_TAG_RETRY_COUNT = 3  # Number of times to retry adding tags.
_TAG_RETRY_SLEEP = 2  # Number of seconds to wait between tag retries.

ENV = Environment(loader=FileSystemLoader(os.path.join(ROOT_DIR, "buildman/templates")))
TEMPLATE = ENV.get_template("cloudconfig.yaml")
CloudConfigContext().populate_jinja_environment(ENV)


class ExecutorException(Exception):
    """ Exception raised when there is a problem starting or stopping a builder.
    """

    pass

@@ -52,20 +54,24 @@ class BuilderExecutor(object):
        self.executor_config = executor_config
        self.manager_hostname = manager_hostname

        default_websocket_scheme = (
            "wss" if app.config["PREFERRED_URL_SCHEME"] == "https" else "ws"
        )
        self.websocket_scheme = executor_config.get(
            "WEBSOCKET_SCHEME", default_websocket_scheme
        )

    @property
    def name(self):
        """ Name returns the unique name for this executor. """
        return self.executor_config.get("NAME") or self.__class__.__name__

    @property
    def setup_time(self):
        """ Returns the amount of time (in seconds) to wait for the execution to start for the build.
        If None, the manager's default will be used.
        """
        return self.executor_config.get("SETUP_TIME")

    @coroutine
    def start_builder(self, realm, token, build_uuid):
@@ -84,13 +90,13 @@ class BuilderExecutor(object):
        """ Returns true if this executor can be used for builds in the given namespace. """

        # Check for an explicit namespace whitelist.
        namespace_whitelist = self.executor_config.get("NAMESPACE_WHITELIST")
        if namespace_whitelist is not None and namespace in namespace_whitelist:
            return True

        # Check for a staged rollout percentage. If found, we hash the namespace and, if it is found
        # in the first X% of the character space, we allow this executor to be used.
        staged_rollout = self.executor_config.get("STAGED_ROLLOUT")
        if staged_rollout is not None:
            bucket = int(hashlib.sha256(namespace).hexdigest()[-2:], 16)
            return bucket < (256 * staged_rollout)
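The staged-rollout branch above is a deterministic percentage gate: the last two hex digits of the namespace's SHA-256 digest give a stable bucket in [0, 255], and the namespace is admitted when its bucket falls in the first `staged_rollout` fraction of that range. A minimal standalone sketch (Python 3 needs the explicit `.encode()` that the Python 2 code above does not):

import hashlib

def allowed_by_staged_rollout(namespace, staged_rollout):
    # Stable bucket in [0, 255] derived from the namespace name.
    bucket = int(hashlib.sha256(namespace.encode("utf-8")).hexdigest()[-2:], 16)
    # E.g. staged_rollout=0.25 admits buckets 0-63, roughly 25% of namespaces.
    return bucket < (256 * staged_rollout)

assert allowed_by_staged_rollout("somenamespace", 1.0)  # full rollout admits everyone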
@@ -102,16 +108,23 @@ class BuilderExecutor(object):
    def minimum_retry_threshold(self):
        """ Returns the minimum number of retries required for this executor to be used or 0 if
        none. """
        return self.executor_config.get("MINIMUM_RETRY_THRESHOLD", 0)

    def generate_cloud_config(
        self,
        realm,
        token,
        build_uuid,
        coreos_channel,
        manager_hostname,
        quay_username=None,
        quay_password=None,
    ):
        if quay_username is None:
            quay_username = self.executor_config["QUAY_USERNAME"]

        if quay_password is None:
            quay_password = self.executor_config["QUAY_PASSWORD"]

        return TEMPLATE.render(
            realm=realm,

@@ -122,12 +135,14 @@ class BuilderExecutor(object):
            manager_hostname=manager_hostname,
            websocket_scheme=self.websocket_scheme,
            coreos_channel=coreos_channel,
            worker_image=self.executor_config.get(
                "WORKER_IMAGE", "quay.io/coreos/registry-build-worker"
            ),
            worker_tag=self.executor_config["WORKER_TAG"],
            logentries_token=self.executor_config.get("LOGENTRIES_TOKEN", None),
            volume_size=self.executor_config.get("VOLUME_SIZE", "42G"),
            max_lifetime_s=self.executor_config.get("MAX_LIFETIME_S", 10800),
            ssh_authorized_keys=self.executor_config.get("SSH_AUTHORIZED_KEYS", []),
        )
@@ -135,7 +150,10 @@ class EC2Executor(BuilderExecutor):
    """ Implementation of BuilderExecutor which uses libcloud to start machines on a variety of cloud
    providers.
    """

    COREOS_STACK_URL = (
        "http://%s.release.core-os.net/amd64-usr/current/coreos_production_ami_hvm.txt"
    )

    def __init__(self, *args, **kwargs):
        self._loop = get_event_loop()

@@ -144,73 +162,81 @@ class EC2Executor(BuilderExecutor):
    def _get_conn(self):
        """ Creates an ec2 connection which can be used to manage instances.
        """
        return AsyncWrapper(
            boto.ec2.connect_to_region(
                self.executor_config["EC2_REGION"],
                aws_access_key_id=self.executor_config["AWS_ACCESS_KEY"],
                aws_secret_access_key=self.executor_config["AWS_SECRET_KEY"],
            )
        )

    @classmethod
    @cachetools.func.ttl_cache(ttl=ONE_HOUR)
    def _get_coreos_ami(cls, ec2_region, coreos_channel):
        """ Retrieve the CoreOS AMI id from the canonical listing.
        """
        stack_list_string = requests.get(
            EC2Executor.COREOS_STACK_URL % coreos_channel
        ).text
        stack_amis = dict([stack.split("=") for stack in stack_list_string.split("|")])
        return stack_amis[ec2_region]
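The canonical listing fetched here is, as the parsing implies, a single `|`-separated line of `region=ami` pairs; a tiny illustration with invented AMI ids:

stack_list_string = "us-east-1=ami-0aaa1111|us-west-2=ami-0bbb2222"  # hypothetical payload
stack_amis = dict(stack.split("=") for stack in stack_list_string.split("|"))
assert stack_amis["us-west-2"] == "ami-0bbb2222"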
    @coroutine
    @duration_collector_async(metric_queue.builder_time_to_start, ["ec2"])
    def start_builder(self, realm, token, build_uuid):
        region = self.executor_config["EC2_REGION"]
        channel = self.executor_config.get("COREOS_CHANNEL", "stable")

        coreos_ami = self.executor_config.get("COREOS_AMI", None)
        if coreos_ami is None:
            get_ami_callable = partial(self._get_coreos_ami, region, channel)
            coreos_ami = yield From(self._loop.run_in_executor(None, get_ami_callable))

        user_data = self.generate_cloud_config(
            realm, token, build_uuid, channel, self.manager_hostname
        )
        logger.debug("Generated cloud config for build %s: %s", build_uuid, user_data)

        ec2_conn = self._get_conn()

        ssd_root_ebs = boto.ec2.blockdevicemapping.BlockDeviceType(
            size=int(self.executor_config.get("BLOCK_DEVICE_SIZE", 48)),
            volume_type="gp2",
            delete_on_termination=True,
        )
        block_devices = boto.ec2.blockdevicemapping.BlockDeviceMapping()
        block_devices["/dev/xvda"] = ssd_root_ebs

        interfaces = None
        if self.executor_config.get("EC2_VPC_SUBNET_ID", None) is not None:
            interface = boto.ec2.networkinterface.NetworkInterfaceSpecification(
                subnet_id=self.executor_config["EC2_VPC_SUBNET_ID"],
                groups=self.executor_config["EC2_SECURITY_GROUP_IDS"],
                associate_public_ip_address=True,
            )
            interfaces = boto.ec2.networkinterface.NetworkInterfaceCollection(interface)

        try:
            reservation = yield From(
                ec2_conn.run_instances(
                    coreos_ami,
                    instance_type=self.executor_config["EC2_INSTANCE_TYPE"],
                    key_name=self.executor_config.get("EC2_KEY_NAME", None),
                    user_data=user_data,
                    instance_initiated_shutdown_behavior="terminate",
                    block_device_map=block_devices,
                    network_interfaces=interfaces,
                )
            )
        except boto.exception.EC2ResponseError as ec2e:
            logger.exception("Unable to spawn builder instance")
            metric_queue.ephemeral_build_worker_failure.Inc()
            raise ec2e

        if not reservation.instances:
            raise ExecutorException("Unable to spawn builder instance.")
        elif len(reservation.instances) != 1:
            raise ExecutorException("EC2 started wrong number of instances!")

        launched = AsyncWrapper(reservation.instances[0])
@@ -220,47 +246,60 @@ class EC2Executor(BuilderExecutor):
        # Tag the instance with its metadata.
        for i in range(0, _TAG_RETRY_COUNT):
            try:
                yield From(
                    launched.add_tags(
                        {
                            "Name": "Quay Ephemeral Builder",
                            "Realm": realm,
                            "Token": token,
                            "BuildUUID": build_uuid,
                        }
                    )
                )
            except boto.exception.EC2ResponseError as ec2e:
                if ec2e.error_code == "InvalidInstanceID.NotFound":
                    if i < _TAG_RETRY_COUNT - 1:
                        logger.warning(
                            "Failed to write EC2 tags for instance %s for build %s (attempt #%s)",
                            launched.id,
                            build_uuid,
                            i,
                        )
                        yield From(trollius.sleep(_TAG_RETRY_SLEEP))
                        continue

                    raise ExecutorException("Unable to find builder instance.")

                logger.exception("Failed to write EC2 tags (attempt #%s)", i)

        logger.debug("Machine with ID %s started for build %s", launched.id, build_uuid)
        raise Return(launched.id)

    @coroutine
    def stop_builder(self, builder_id):
        try:
            ec2_conn = self._get_conn()
            terminated_instances = yield From(
                ec2_conn.terminate_instances([builder_id])
            )
        except boto.exception.EC2ResponseError as ec2e:
            if ec2e.error_code == "InvalidInstanceID.NotFound":
                logger.debug("Instance %s already terminated", builder_id)
                return

            logger.exception(
                "Exception when trying to terminate instance %s", builder_id
            )
            raise

        if builder_id not in [si.id for si in terminated_instances]:
            raise ExecutorException("Unable to terminate instance: %s" % builder_id)
class PopenExecutor(BuilderExecutor):
    """ Implementation of BuilderExecutor which uses Popen to fork a quay-builder process.
    """

    def __init__(self, executor_config, manager_hostname):
        self._jobs = {}

@@ -268,42 +307,44 @@ class PopenExecutor(BuilderExecutor):
    """ Executor which uses Popen to fork a quay-builder process.
    """

    @coroutine
    @duration_collector_async(metric_queue.builder_time_to_start, ["fork"])
    def start_builder(self, realm, token, build_uuid):
        # Now start a machine for this job, adding the machine id to the etcd information
        logger.debug("Forking process for build")

        ws_host = os.environ.get("BUILDMAN_WS_HOST", "localhost")
        ws_port = os.environ.get("BUILDMAN_WS_PORT", "8787")
        builder_env = {
            "TOKEN": token,
            "REALM": realm,
            "ENDPOINT": "ws://%s:%s" % (ws_host, ws_port),
            "DOCKER_TLS_VERIFY": os.environ.get("DOCKER_TLS_VERIFY", ""),
            "DOCKER_CERT_PATH": os.environ.get("DOCKER_CERT_PATH", ""),
            "DOCKER_HOST": os.environ.get("DOCKER_HOST", ""),
            "PATH": "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
        }

        logpipe = LogPipe(logging.INFO)
        spawned = subprocess.Popen(
            os.environ.get("BUILDER_BINARY_LOCATION", "/usr/local/bin/quay-builder"),
            stdout=logpipe,
            stderr=logpipe,
            env=builder_env,
        )

        builder_id = str(uuid.uuid4())
        self._jobs[builder_id] = (spawned, logpipe)
        logger.debug("Builder spawned with id: %s", builder_id)
        raise Return(builder_id)

    @coroutine
    def stop_builder(self, builder_id):
        if builder_id not in self._jobs:
            raise ExecutorException("Builder id not being tracked by executor.")

        logger.debug("Killing builder with id: %s", builder_id)
        spawned, logpipe = self._jobs[builder_id]

        if spawned.poll() is None:
@@ -314,150 +355,167 @@ class PopenExecutor(BuilderExecutor):
class KubernetesExecutor(BuilderExecutor):
    """ Executes build jobs by creating Kubernetes jobs which run a qemu-kvm virtual
    machine in a pod """

    def __init__(self, *args, **kwargs):
        super(KubernetesExecutor, self).__init__(*args, **kwargs)
        self._loop = get_event_loop()
        self.namespace = self.executor_config.get("BUILDER_NAMESPACE", "builder")
        self.image = self.executor_config.get(
            "BUILDER_VM_CONTAINER_IMAGE", "quay.io/quay/quay-builder-qemu-coreos:stable"
        )

    @coroutine
    def _request(self, method, path, **kwargs):
        request_options = dict(kwargs)

        tls_cert = self.executor_config.get("K8S_API_TLS_CERT")
        tls_key = self.executor_config.get("K8S_API_TLS_KEY")
        tls_ca = self.executor_config.get("K8S_API_TLS_CA")
        service_account_token = self.executor_config.get("SERVICE_ACCOUNT_TOKEN")

        if "timeout" not in request_options:
            request_options["timeout"] = self.executor_config.get("K8S_API_TIMEOUT", 20)

        if service_account_token:
            scheme = "https"
            request_options["headers"] = {
                "Authorization": "Bearer " + service_account_token
            }
            logger.debug("Using service account token for Kubernetes authentication")
        elif tls_cert and tls_key:
            scheme = "https"
            request_options["cert"] = (tls_cert, tls_key)
            logger.debug("Using tls certificate and key for Kubernetes authentication")
            if tls_ca:
                request_options["verify"] = tls_ca
        else:
            scheme = "http"

        server = self.executor_config.get("K8S_API_SERVER", "localhost:8080")
        url = "%s://%s%s" % (scheme, server, path)

        logger.debug("Executor config: %s", self.executor_config)
        logger.debug("Kubernetes request: %s %s: %s", method, url, request_options)
        res = requests.request(method, url, **request_options)
        logger.debug("Kubernetes response: %s: %s", res.status_code, res.text)
        raise Return(res)
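Note the authentication precedence in `_request`: a service account bearer token wins over TLS client certificates, which win over plain HTTP. A minimal sketch of just that decision, using the same config keys (everything else stripped away):

def k8s_auth_options(config):
    # Bearer token first, then TLS client cert (+ optional CA bundle), else plain HTTP.
    if config.get("SERVICE_ACCOUNT_TOKEN"):
        return "https", {"headers": {"Authorization": "Bearer " + config["SERVICE_ACCOUNT_TOKEN"]}}
    if config.get("K8S_API_TLS_CERT") and config.get("K8S_API_TLS_KEY"):
        options = {"cert": (config["K8S_API_TLS_CERT"], config["K8S_API_TLS_KEY"])}
        if config.get("K8S_API_TLS_CA"):
            options["verify"] = config["K8S_API_TLS_CA"]
        return "https", options
    return "http", {}

assert k8s_auth_options({"SERVICE_ACCOUNT_TOKEN": "tok"})[0] == "https"
assert k8s_auth_options({})[0] == "http"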
    def _jobs_path(self):
        return "/apis/batch/v1/namespaces/%s/jobs" % self.namespace

    def _job_path(self, build_uuid):
        return "%s/%s" % (self._jobs_path(), build_uuid)

    def _kubernetes_distribution(self):
        return self.executor_config.get("KUBERNETES_DISTRIBUTION", "basic").lower()

    def _is_basic_kubernetes_distribution(self):
        return self._kubernetes_distribution() == "basic"

    def _is_openshift_kubernetes_distribution(self):
        return self._kubernetes_distribution() == "openshift"

    def _build_job_container_resources(self):
        # Minimum acceptable free resources for this container to "fit" in a quota
        # These may be lower than the absolute limits if the cluster is knowingly
        # oversubscribed by some amount.
        container_requests = {
            "memory": self.executor_config.get("CONTAINER_MEMORY_REQUEST", "3968Mi")
        }

        container_limits = {
            "memory": self.executor_config.get("CONTAINER_MEMORY_LIMITS", "5120Mi"),
            "cpu": self.executor_config.get("CONTAINER_CPU_LIMITS", "1000m"),
        }

        resources = {"requests": container_requests}
        if self._is_openshift_kubernetes_distribution():
            resources["requests"]["cpu"] = self.executor_config.get(
                "CONTAINER_CPU_REQUEST", "500m"
            )
            resources["limits"] = container_limits

        return resources
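With no config overrides, the helper above yields (values per the defaults in the code) just a memory request on a basic distribution, and on OpenShift a CPU request plus hard limits as well. As a sketch of the two shapes:

# Hypothetical: defaults only, no executor_config overrides.
basic_resources = {"requests": {"memory": "3968Mi"}}
openshift_resources = {
    "requests": {"memory": "3968Mi", "cpu": "500m"},
    "limits": {"memory": "5120Mi", "cpu": "1000m"},
}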
    def _build_job_containers(self, user_data):
        vm_memory_limit = self.executor_config.get("VM_MEMORY_LIMIT", "4G")
        vm_volume_size = self.executor_config.get("VOLUME_SIZE", "32G")

        container = {
            "name": "builder",
            "imagePullPolicy": "IfNotPresent",
            "image": self.image,
            "securityContext": {"privileged": True},
            "env": [
                {"name": "USERDATA", "value": user_data},
                {"name": "VM_MEMORY", "value": vm_memory_limit},
                {"name": "VM_VOLUME_SIZE", "value": vm_volume_size},
            ],
            "resources": self._build_job_container_resources(),
        }

        if self._is_basic_kubernetes_distribution():
            container["volumeMounts"] = [
                {
                    "name": "secrets-mask",
                    "mountPath": "/var/run/secrets/kubernetes.io/serviceaccount",
                }
            ]

        return container
    def _job_resource(self, build_uuid, user_data, coreos_channel="stable"):
        image_pull_secret_name = self.executor_config.get(
            "IMAGE_PULL_SECRET_NAME", "builder"
        )
        service_account = self.executor_config.get(
            "SERVICE_ACCOUNT_NAME", "quay-builder-sa"
        )
        node_selector_label_key = self.executor_config.get(
            "NODE_SELECTOR_LABEL_KEY", "beta.kubernetes.io/instance-type"
        )
        node_selector_label_value = self.executor_config.get(
            "NODE_SELECTOR_LABEL_VALUE", ""
        )

        node_selector = {node_selector_label_key: node_selector_label_value}

        release_sha = release.GIT_HEAD or "none"
        if " " in release_sha:
            release_sha = "HEAD"

        job_resource = {
            "apiVersion": "batch/v1",
            "kind": "Job",
            "metadata": {
                "namespace": self.namespace,
                "generateName": build_uuid + "-",
                "labels": {
                    "build": build_uuid,
                    "time": datetime.datetime.now().strftime("%Y-%m-%d-%H"),
                    "manager": socket.gethostname(),
                    "quay-sha": release_sha,
                },
            },
            "spec": {
                "activeDeadlineSeconds": self.executor_config.get(
                    "MAXIMUM_JOB_TIME", 7200
                ),
                "template": {
                    "metadata": {
                        "labels": {
                            "build": build_uuid,
                            "time": datetime.datetime.now().strftime("%Y-%m-%d-%H"),
                            "manager": socket.gethostname(),
                            "quay-sha": release_sha,
                        }
                    },
                    "spec": {
                        "imagePullSecrets": [{"name": image_pull_secret_name}],
                        "restartPolicy": "Never",
                        "dnsPolicy": "Default",
                        "containers": [self._build_job_containers(user_data)],
                    },
                },
            },
@@ -465,17 +523,19 @@ class KubernetesExecutor(BuilderExecutor):
        if self._is_openshift_kubernetes_distribution():
            # Setting `automountServiceAccountToken` to false will prevent automounting API credentials for a service account.
            job_resource["spec"]["template"]["spec"][
                "automountServiceAccountToken"
            ] = False

            # Use dedicated service account that has no authorization to any resources.
            job_resource["spec"]["template"]["spec"]["serviceAccount"] = service_account

            # Setting `enableServiceLinks` to false prevents information about other services from being injected into pod's
            # environment variables. Pod has no visibility into other services on the cluster.
            job_resource["spec"]["template"]["spec"]["enableServiceLinks"] = False

        if node_selector_label_value.strip() != "":
            job_resource["spec"]["template"]["spec"]["nodeSelector"] = node_selector

        if self._is_basic_kubernetes_distribution():
            # This volume is a hack to mask the token for the namespace's
@@ -486,43 +546,55 @@ class KubernetesExecutor(BuilderExecutor):
            #
            # https://github.com/kubernetes/kubernetes/issues/16779
            #
            job_resource["spec"]["template"]["spec"]["volumes"] = [
                {"name": "secrets-mask", "emptyDir": {"medium": "Memory"}}
            ]

        return job_resource

    @coroutine
    @duration_collector_async(metric_queue.builder_time_to_start, ["k8s"])
    def start_builder(self, realm, token, build_uuid):
        # generate resource
        channel = self.executor_config.get("COREOS_CHANNEL", "stable")
        user_data = self.generate_cloud_config(
            realm, token, build_uuid, channel, self.manager_hostname
        )
        resource = self._job_resource(build_uuid, user_data, channel)
        logger.debug(
            "Using Kubernetes Distribution: %s", self._kubernetes_distribution()
        )
        logger.debug("Generated kubernetes resource:\n%s", resource)

        # schedule
        create_job = yield From(self._request("POST", self._jobs_path(), json=resource))
        if int(create_job.status_code / 100) != 2:
            raise ExecutorException(
                "Failed to create job: %s: %s: %s"
                % (build_uuid, create_job.status_code, create_job.text)
            )

        job = create_job.json()
        raise Return(job["metadata"]["name"])

    @coroutine
    def stop_builder(self, builder_id):
        pods_path = "/api/v1/namespaces/%s/pods" % self.namespace

        # Delete the job itself. Catch Exception rather than using a bare except,
        # which would also swallow KeyboardInterrupt and SystemExit.
        try:
            yield From(self._request("DELETE", self._job_path(builder_id)))
        except Exception:
            logger.exception("Failed to send delete job call for job %s", builder_id)

        # Delete the pod(s) for the job.
        selectorString = "job-name=%s" % builder_id
        try:
            yield From(
                self._request(
                    "DELETE", pods_path, params=dict(labelSelector=selectorString)
                )
            )
        except Exception:
            logger.exception("Failed to send delete pod call for job %s", builder_id)
@@ -530,6 +602,7 @@ class KubernetesExecutor(BuilderExecutor):
class LogPipe(threading.Thread):
    """ Adapted from http://codereview.stackexchange.com/a/17959
    """

    def __init__(self, level):
        """Setup the object with a logger and a loglevel
        and start the thread

@@ -549,8 +622,8 @@ class LogPipe(threading.Thread):
    def run(self):
        """Run the thread, logging everything.
        """
        for line in iter(self.pipe_reader.readline, ""):
            logging.log(self.level, line.strip("\n"))

        self.pipe_reader.close()
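Only fragments of LogPipe appear in this hunk. The underlying trick: `subprocess.Popen` accepts any object with a `fileno()`, so a thread can hand out the write end of an `os.pipe()` and relay the read end into `logging`. A self-contained sketch of the same idea (hypothetical `MiniLogPipe`, not the class as committed):

import logging
import os
import subprocess
import threading

logging.basicConfig(level=logging.INFO)

class MiniLogPipe(threading.Thread):
    """Relay everything written to a pipe into the logging module."""

    def __init__(self, level):
        super(MiniLogPipe, self).__init__()
        self.daemon = True
        self.level = level
        self.fd_read, self.fd_write = os.pipe()
        self.pipe_reader = os.fdopen(self.fd_read)
        self.start()

    def fileno(self):
        # Popen(stdout=...) only needs this to hand the fd to the child.
        return self.fd_write

    def run(self):
        for line in iter(self.pipe_reader.readline, ""):
            logging.log(self.level, line.rstrip("\n"))
        self.pipe_reader.close()

    def close(self):
        os.close(self.fd_write)

pipe = MiniLogPipe(logging.INFO)
proc = subprocess.Popen(["echo", "hello from the child"], stdout=pipe)
proc.wait()
pipe.close()  # EOF for the reader thread once the child is done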

View file

@@ -1,5 +1,6 @@
class NoopCanceller(object):
    """ A class that can not cancel a build """

    def __init__(self, config=None):
        pass

View file

@@ -7,20 +7,23 @@ from util import slash_join
logger = logging.getLogger(__name__)

CANCEL_PREFIX = "cancel/"


class OrchestratorCanceller(object):
    """ An asynchronous way to cancel a build with any Orchestrator. """

    def __init__(self, config):
        self._orchestrator = orchestrator_from_config(config, canceller_only=True)

    def try_cancel_build(self, build_uuid):
        logger.info("Cancelling build %s", build_uuid)
        cancel_key = slash_join(CANCEL_PREFIX, build_uuid)
        try:
            self._orchestrator.set_key_sync(cancel_key, build_uuid, expiration=60)
            return True
        except OrchestratorError:
            logger.exception(
                "Failed to write cancel action to redis with uuid %s", build_uuid
            )
            return False
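The cancellation contract is simply an expiring key: the canceller writes `cancel/<build_uuid>` with a 60-second TTL, and whichever component watches that prefix reacts to it. A toy in-memory stand-in for the orchestrator (hypothetical; the real `set_key_sync` is backed by Redis/etcd) showing the shape of that contract:

import time

class ToyOrchestrator(object):
    """In-memory stand-in: keys silently disappear after their TTL."""

    def __init__(self):
        self._store = {}

    def set_key_sync(self, key, value, expiration=None):
        expires_at = time.time() + expiration if expiration else None
        self._store[key] = (value, expires_at)

    def get_key(self, key):
        value, expires_at = self._store[key]
        if expires_at is not None and time.time() > expires_at:
            del self._store[key]
            raise KeyError(key)
        return value

orch = ToyOrchestrator()
orch.set_key_sync("cancel/deadbeef", "deadbeef", expiration=60)
assert orch.get_key("cancel/deadbeef") == "deadbeef"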

View file

@@ -9,8 +9,7 @@ from trollius import coroutine, get_event_loop, From, Future, Return
from app import metric_queue
from buildman.asyncutil import AsyncWrapper
from buildman.component.buildcomponent import BuildComponent
from buildman.manager.ephemeral import EphemeralBuilderManager, REALM_PREFIX, JOB_PREFIX
from buildman.manager.executor import BuilderExecutor, ExecutorException
from buildman.orchestrator import KeyEvent, KeyChange
from buildman.server import BuildJobResult

@@ -18,8 +17,8 @@ from util import slash_join
from util.metrics.metricqueue import duration_collector_async

BUILD_UUID = "deadbeef-dead-beef-dead-deadbeefdead"
REALM_ID = "1234-realm"


def async_test(f):
@@ -28,6 +27,7 @@ def async_test(f):
        future = coro(*args, **kwargs)
        loop = get_event_loop()
        loop.run_until_complete(future)

    return wrapper

@@ -36,7 +36,9 @@ class TestExecutor(BuilderExecutor):
    job_stopped = None

    @coroutine
    @duration_collector_async(
        metric_queue.builder_time_to_start, labelvalues=["testlabel"]
    )
    def start_builder(self, realm, token, build_uuid):
        self.job_started = str(uuid.uuid4())
        raise Return(self.job_started)

@@ -48,9 +50,11 @@ class TestExecutor(BuilderExecutor):
class BadExecutor(BuilderExecutor):
    @coroutine
    @duration_collector_async(
        metric_queue.builder_time_to_start, labelvalues=["testlabel"]
    )
    def start_builder(self, realm, token, build_uuid):
        raise ExecutorException("raised on purpose!")
class EphemeralBuilderTestCase(unittest.TestCase):

@@ -64,6 +68,7 @@ class EphemeralBuilderTestCase(unittest.TestCase):
            new_future = Future()
            new_future.set_result(result)
            return new_future

        return inner

    def setUp(self):

@@ -74,15 +79,12 @@ class EphemeralBuilderTestCase(unittest.TestCase):
    @coroutine
    def _register_component(self, realm_spec, build_component, token):
        raise Return("hello")

    def _create_build_job(self, namespace="namespace", retries=3):
        mock_job = Mock()
        mock_job.job_details = {"build_uuid": BUILD_UUID}
        mock_job.job_item = {"body": json.dumps(mock_job.job_details), "id": 1}

        mock_job.namespace = namespace
        mock_job.retries_remaining = retries

@@ -103,21 +105,26 @@ class TestEphemeralLifecycle(EphemeralBuilderTestCase):
            new_future = Future()
            new_future.set_result(result)
            return new_future

        return inner

    def _create_mock_executor(self, *args, **kwargs):
        self.test_executor = Mock(spec=BuilderExecutor)
        self.test_executor.start_builder = Mock(
            side_effect=self._create_completed_future("123")
        )
        self.test_executor.stop_builder = Mock(
            side_effect=self._create_completed_future()
        )
        self.test_executor.setup_time = 60
        self.test_executor.name = "MockExecutor"
        self.test_executor.minimum_retry_threshold = 0
        return self.test_executor
    def setUp(self):
        super(TestEphemeralLifecycle, self).setUp()

        EphemeralBuilderManager.EXECUTORS["test"] = self._create_mock_executor

        self.register_component_callback = Mock()
        self.unregister_component_callback = Mock()

@@ -129,14 +136,13 @@ class TestEphemeralLifecycle(EphemeralBuilderTestCase):
            self.unregister_component_callback,
            self.job_heartbeat_callback,
            self.job_complete_callback,
            "127.0.0.1",
            30,
        )

        self.manager.initialize(
            {"EXECUTOR": "test", "ORCHESTRATOR": {"MEM_CONFIG": None}}
        )

        # Ensure that the realm and building callbacks have been registered
        callback_keys = [key for key in self.manager._orchestrator.callbacks]

@@ -144,13 +150,12 @@ class TestEphemeralLifecycle(EphemeralBuilderTestCase):
        self.assertIn(JOB_PREFIX, callback_keys)

        self.mock_job = self._create_build_job()
        self.mock_job_key = slash_join("building", BUILD_UUID)
    def tearDown(self):
        super(TestEphemeralLifecycle, self).tearDown()
        self.manager.shutdown()

    @coroutine
    def _setup_job_for_managers(self):
        test_component = Mock(spec=BuildComponent)

@@ -170,19 +175,25 @@ class TestEphemeralLifecycle(EphemeralBuilderTestCase):
        realm_for_build = self._find_realm_key(self.manager._orchestrator, BUILD_UUID)

        raw_realm_data = yield From(
            self.manager._orchestrator.get_key(slash_join("realm", realm_for_build))
        )
        realm_data = json.loads(raw_realm_data)
        realm_data["realm"] = REALM_ID

        # Right now the job is not registered with any managers because etcd has not accepted the job
        self.assertEqual(self.register_component_callback.call_count, 0)

        # Fire off a realm changed with the same data.
        yield From(
            self.manager._realm_callback(
                KeyChange(
                    KeyEvent.CREATE,
                    slash_join(REALM_PREFIX, REALM_ID),
                    json.dumps(realm_data),
                )
            )
        )

        # Ensure that we have at least one component node.
        self.assertEqual(self.register_component_callback.call_count, 1)

@@ -198,13 +209,12 @@ class TestEphemeralLifecycle(EphemeralBuilderTestCase):
        for key, value in iteritems(orchestrator.state):
            if key.startswith(REALM_PREFIX):
                parsed_value = json.loads(value)
                body = json.loads(parsed_value["job_queue_item"]["body"])
                if body["build_uuid"] == build_uuid:
                    return parsed_value["realm"]
                continue

        raise KeyError

    @async_test
    def test_schedule_and_complete(self):
        # Test that a job is properly registered with all of the managers
@@ -216,7 +226,11 @@ class TestEphemeralLifecycle(EphemeralBuilderTestCase):
        self.assertIsNotNone(self.manager._build_uuid_to_info.get(BUILD_UUID))

        # Finish the job
        yield From(
            self.manager.job_completed(
                self.mock_job, BuildJobResult.COMPLETE, test_component
            )
        )

        # Ensure that the executor kills the job.
        self.assertEqual(self.test_executor.stop_builder.call_count, 1)

@@ -230,13 +244,22 @@ class TestEphemeralLifecycle(EphemeralBuilderTestCase):
        # Prepare a job to be taken by another manager
        test_component = yield From(self._setup_job_for_managers())

        yield From(
            self.manager._realm_callback(
                KeyChange(
                    KeyEvent.DELETE,
                    slash_join(REALM_PREFIX, REALM_ID),
                    json.dumps(
                        {
                            "realm": REALM_ID,
                            "token": "beef",
                            "execution_id": "123",
                            "job_queue_item": self.mock_job.job_item,
                        }
                    ),
                )
            )
        )

        self.unregister_component_callback.assert_called_once_with(test_component)

@@ -248,11 +271,20 @@ class TestEphemeralLifecycle(EphemeralBuilderTestCase):
        self.assertIsNotNone(self.manager._build_uuid_to_info.get(BUILD_UUID))

        # Delete the job once it has "completed".
        yield From(
            self.manager._job_callback(
                KeyChange(
                    KeyEvent.DELETE,
                    self.mock_job_key,
                    json.dumps(
                        {
                            "had_heartbeat": False,
                            "job_queue_item": self.mock_job.job_item,
                        }
                    ),
                )
            )
        )

        # Ensure the job was removed from the info, but stop was not called.
        self.assertIsNone(self.manager._build_uuid_to_info.get(BUILD_UUID))
@@ -265,11 +297,20 @@ class TestEphemeralLifecycle(EphemeralBuilderTestCase):
        self.assertIn(JOB_PREFIX, callback_keys)

        # Send a signal to the callback that the job has been created.
        yield From(
            self.manager._job_callback(
                KeyChange(
                    KeyEvent.CREATE,
                    self.mock_job_key,
                    json.dumps(
                        {
                            "had_heartbeat": False,
                            "job_queue_item": self.mock_job.job_item,
                        }
                    ),
                )
            )
        )

        # Ensure the create does nothing.
        self.assertEqual(self.test_executor.stop_builder.call_count, 0)

@@ -281,11 +322,20 @@ class TestEphemeralLifecycle(EphemeralBuilderTestCase):
        self.assertIn(JOB_PREFIX, callback_keys)

        # Send a signal to the callback that a worker has expired
        yield From(
            self.manager._job_callback(
                KeyChange(
                    KeyEvent.EXPIRE,
                    self.mock_job_key,
                    json.dumps(
                        {
                            "had_heartbeat": True,
                            "job_queue_item": self.mock_job.job_item,
                        }
                    ),
                )
            )
        )

        # Since the realm was never registered, expiration should do nothing.
        self.assertEqual(self.test_executor.stop_builder.call_count, 0)
@ -298,13 +348,22 @@ class TestEphemeralLifecycle(EphemeralBuilderTestCase):
callback_keys = [key for key in self.manager._orchestrator.callbacks] callback_keys = [key for key in self.manager._orchestrator.callbacks]
self.assertIn(JOB_PREFIX, callback_keys) self.assertIn(JOB_PREFIX, callback_keys)
yield From(self.manager._job_callback( yield From(
KeyChange(KeyEvent.EXPIRE, self.manager._job_callback(
KeyChange(
KeyEvent.EXPIRE,
self.mock_job_key, self.mock_job_key,
json.dumps({'had_heartbeat': True, json.dumps(
'job_queue_item': self.mock_job.job_item})))) {
"had_heartbeat": True,
"job_queue_item": self.mock_job.job_item,
}
),
)
)
)
self.test_executor.stop_builder.assert_called_once_with('123') self.test_executor.stop_builder.assert_called_once_with("123")
self.assertEqual(self.test_executor.stop_builder.call_count, 1) self.assertEqual(self.test_executor.stop_builder.call_count, 1)
@async_test @async_test
@ -316,11 +375,20 @@ class TestEphemeralLifecycle(EphemeralBuilderTestCase):
self.assertIn(JOB_PREFIX, callback_keys) self.assertIn(JOB_PREFIX, callback_keys)
# Send a signal to the callback that a worker has expired # Send a signal to the callback that a worker has expired
yield From(self.manager._job_callback( yield From(
KeyChange(KeyEvent.DELETE, self.manager._job_callback(
KeyChange(
KeyEvent.DELETE,
self.mock_job_key, self.mock_job_key,
json.dumps({'had_heartbeat': False, json.dumps(
'job_queue_item': self.mock_job.job_item})))) {
"had_heartbeat": False,
"job_queue_item": self.mock_job.job_item,
}
),
)
)
)
self.assertEqual(self.test_executor.stop_builder.call_count, 0) self.assertEqual(self.test_executor.stop_builder.call_count, 0)
self.assertEqual(self.job_complete_callback.call_count, 0) self.assertEqual(self.job_complete_callback.call_count, 0)
@ -335,25 +403,36 @@ class TestEphemeralLifecycle(EphemeralBuilderTestCase):
self.assertIn(JOB_PREFIX, callback_keys) self.assertIn(JOB_PREFIX, callback_keys)
# Send a signal to the callback that a worker has expired # Send a signal to the callback that a worker has expired
yield From(self.manager._job_callback( yield From(
KeyChange(KeyEvent.EXPIRE, self.manager._job_callback(
KeyChange(
KeyEvent.EXPIRE,
self.mock_job_key, self.mock_job_key,
json.dumps({'had_heartbeat': False, json.dumps(
'job_queue_item': self.mock_job.job_item})))) {
"had_heartbeat": False,
"job_queue_item": self.mock_job.job_item,
}
),
)
)
)
self.test_executor.stop_builder.assert_called_once_with('123') self.test_executor.stop_builder.assert_called_once_with("123")
self.assertEqual(self.test_executor.stop_builder.call_count, 1) self.assertEqual(self.test_executor.stop_builder.call_count, 1)
# Ensure the job was marked as incomplete, with update_phase set to True (so the DB record and # Ensure the job was marked as incomplete, with update_phase set to True (so the DB record and
# logs are updated as well) # logs are updated as well)
yield From(self.job_complete_callback.assert_called_once_with(ANY, BuildJobResult.INCOMPLETE, yield From(
'MockExecutor', self.job_complete_callback.assert_called_once_with(
update_phase=True)) ANY, BuildJobResult.INCOMPLETE, "MockExecutor", update_phase=True
)
)
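A minimal illustration of the assertion pattern above, assuming a plain Mock stands in for the completion callback; mock.ANY matches any value in its argument position, which is how the test ignores the exact job object:

    from mock import ANY, Mock  # the stdlib's unittest.mock on Python 3

    job_complete_callback = Mock()
    job_complete_callback("job-object", "incomplete", "MockExecutor", update_phase=True)

    # Passes because ANY matches the first positional argument.
    job_complete_callback.assert_called_once_with(
        ANY, "incomplete", "MockExecutor", update_phase=True
    )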
@async_test @async_test
def test_change_worker(self): def test_change_worker(self):
# Send a signal to the callback that a worker key has been changed # Send a signal to the callback that a worker key has been changed
self.manager._job_callback(KeyChange(KeyEvent.SET, self.mock_job_key, 'value')) self.manager._job_callback(KeyChange(KeyEvent.SET, self.mock_job_key, "value"))
self.assertEqual(self.test_executor.stop_builder.call_count, 0) self.assertEqual(self.test_executor.stop_builder.call_count, 0)
@async_test @async_test
@ -361,18 +440,25 @@ class TestEphemeralLifecycle(EphemeralBuilderTestCase):
test_component = yield From(self._setup_job_for_managers()) test_component = yield From(self._setup_job_for_managers())
# Send a signal to the callback that a realm has expired # Send a signal to the callback that a realm has expired
yield From(self.manager._realm_callback(KeyChange( yield From(
self.manager._realm_callback(
KeyChange(
KeyEvent.EXPIRE, KeyEvent.EXPIRE,
self.mock_job_key, self.mock_job_key,
json.dumps({ json.dumps(
'realm': REALM_ID, {
'execution_id': 'foobar', "realm": REALM_ID,
'executor_name': 'MockExecutor', "execution_id": "foobar",
'job_queue_item': {'body': '{"build_uuid": "fakeid"}'}, "executor_name": "MockExecutor",
})))) "job_queue_item": {"body": '{"build_uuid": "fakeid"}'},
}
),
)
)
)
# Ensure that the cleanup code for the executor was called. # Ensure that the cleanup code for the executor was called.
self.test_executor.stop_builder.assert_called_once_with('foobar') self.test_executor.stop_builder.assert_called_once_with("foobar")
self.assertEqual(self.test_executor.stop_builder.call_count, 1) self.assertEqual(self.test_executor.stop_builder.call_count, 1)
@ -396,7 +482,7 @@ class TestEphemeral(EphemeralBuilderTestCase):
unregister_component_callback, unregister_component_callback,
job_heartbeat_callback, job_heartbeat_callback,
job_complete_callback, job_complete_callback,
'127.0.0.1', "127.0.0.1",
30, 30,
) )
@ -405,124 +491,132 @@ class TestEphemeral(EphemeralBuilderTestCase):
self.manager.shutdown() self.manager.shutdown()
def test_verify_executor_oldconfig(self): def test_verify_executor_oldconfig(self):
EphemeralBuilderManager.EXECUTORS['test'] = TestExecutor EphemeralBuilderManager.EXECUTORS["test"] = TestExecutor
self.manager.initialize({ self.manager.initialize(
'EXECUTOR': 'test', {
'EXECUTOR_CONFIG': dict(MINIMUM_RETRY_THRESHOLD=42), "EXECUTOR": "test",
'ORCHESTRATOR': {'MEM_CONFIG': None}, "EXECUTOR_CONFIG": dict(MINIMUM_RETRY_THRESHOLD=42),
}) "ORCHESTRATOR": {"MEM_CONFIG": None},
}
)
# Ensure that we have a single test executor. # Ensure that we have a single test executor.
self.assertEqual(1, len(self.manager.registered_executors)) self.assertEqual(1, len(self.manager.registered_executors))
self.assertEqual(42, self.manager.registered_executors[0].minimum_retry_threshold) self.assertEqual(
self.assertEqual('TestExecutor', self.manager.registered_executors[0].name) 42, self.manager.registered_executors[0].minimum_retry_threshold
)
self.assertEqual("TestExecutor", self.manager.registered_executors[0].name)
def test_verify_executor_newconfig(self): def test_verify_executor_newconfig(self):
EphemeralBuilderManager.EXECUTORS['test'] = TestExecutor EphemeralBuilderManager.EXECUTORS["test"] = TestExecutor
self.manager.initialize({ self.manager.initialize(
'EXECUTORS': [{ {
'EXECUTOR': 'test', "EXECUTORS": [{"EXECUTOR": "test", "MINIMUM_RETRY_THRESHOLD": 42}],
'MINIMUM_RETRY_THRESHOLD': 42 "ORCHESTRATOR": {"MEM_CONFIG": None},
}], }
'ORCHESTRATOR': {'MEM_CONFIG': None}, )
})
# Ensure that we have a single test executor. # Ensure that we have a single test executor.
self.assertEqual(1, len(self.manager.registered_executors)) self.assertEqual(1, len(self.manager.registered_executors))
self.assertEqual(42, self.manager.registered_executors[0].minimum_retry_threshold) self.assertEqual(
42, self.manager.registered_executors[0].minimum_retry_threshold
)
def test_multiple_executors_samename(self): def test_multiple_executors_samename(self):
EphemeralBuilderManager.EXECUTORS['test'] = TestExecutor EphemeralBuilderManager.EXECUTORS["test"] = TestExecutor
EphemeralBuilderManager.EXECUTORS['anotherexecutor'] = TestExecutor EphemeralBuilderManager.EXECUTORS["anotherexecutor"] = TestExecutor
with self.assertRaises(Exception): with self.assertRaises(Exception):
self.manager.initialize({ self.manager.initialize(
'EXECUTORS': [
{ {
'NAME': 'primary', "EXECUTORS": [
'EXECUTOR': 'test', {
'MINIMUM_RETRY_THRESHOLD': 42 "NAME": "primary",
"EXECUTOR": "test",
"MINIMUM_RETRY_THRESHOLD": 42,
}, },
{ {
'NAME': 'primary', "NAME": "primary",
'EXECUTOR': 'anotherexecutor', "EXECUTOR": "anotherexecutor",
'MINIMUM_RETRY_THRESHOLD': 24 "MINIMUM_RETRY_THRESHOLD": 24,
}, },
], ],
'ORCHESTRATOR': {'MEM_CONFIG': None}, "ORCHESTRATOR": {"MEM_CONFIG": None},
}) }
)
def test_verify_multiple_executors(self): def test_verify_multiple_executors(self):
EphemeralBuilderManager.EXECUTORS['test'] = TestExecutor EphemeralBuilderManager.EXECUTORS["test"] = TestExecutor
EphemeralBuilderManager.EXECUTORS['anotherexecutor'] = TestExecutor EphemeralBuilderManager.EXECUTORS["anotherexecutor"] = TestExecutor
self.manager.initialize({ self.manager.initialize(
'EXECUTORS': [
{ {
'NAME': 'primary', "EXECUTORS": [
'EXECUTOR': 'test', {
'MINIMUM_RETRY_THRESHOLD': 42 "NAME": "primary",
"EXECUTOR": "test",
"MINIMUM_RETRY_THRESHOLD": 42,
}, },
{ {
'NAME': 'secondary', "NAME": "secondary",
'EXECUTOR': 'anotherexecutor', "EXECUTOR": "anotherexecutor",
'MINIMUM_RETRY_THRESHOLD': 24 "MINIMUM_RETRY_THRESHOLD": 24,
}, },
], ],
'ORCHESTRATOR': {'MEM_CONFIG': None}, "ORCHESTRATOR": {"MEM_CONFIG": None},
}) }
)
# Ensure that we have two test executors. # Ensure that we have two test executors.
self.assertEqual(2, len(self.manager.registered_executors)) self.assertEqual(2, len(self.manager.registered_executors))
self.assertEqual(42, self.manager.registered_executors[0].minimum_retry_threshold) self.assertEqual(
self.assertEqual(24, self.manager.registered_executors[1].minimum_retry_threshold) 42, self.manager.registered_executors[0].minimum_retry_threshold
)
self.assertEqual(
24, self.manager.registered_executors[1].minimum_retry_threshold
)
def test_skip_invalid_executor(self): def test_skip_invalid_executor(self):
self.manager.initialize({ self.manager.initialize(
'EXECUTORS': [
{ {
'EXECUTOR': 'unknown', "EXECUTORS": [{"EXECUTOR": "unknown", "MINIMUM_RETRY_THRESHOLD": 42}],
'MINIMUM_RETRY_THRESHOLD': 42 "ORCHESTRATOR": {"MEM_CONFIG": None},
}, }
], )
'ORCHESTRATOR': {'MEM_CONFIG': None},
})
self.assertEqual(0, len(self.manager.registered_executors)) self.assertEqual(0, len(self.manager.registered_executors))
@async_test @async_test
def test_schedule_job_namespace_filter(self): def test_schedule_job_namespace_filter(self):
EphemeralBuilderManager.EXECUTORS['test'] = TestExecutor EphemeralBuilderManager.EXECUTORS["test"] = TestExecutor
self.manager.initialize({ self.manager.initialize(
'EXECUTORS': [{ {
'EXECUTOR': 'test', "EXECUTORS": [
'NAMESPACE_WHITELIST': ['something'], {"EXECUTOR": "test", "NAMESPACE_WHITELIST": ["something"]}
}], ],
'ORCHESTRATOR': {'MEM_CONFIG': None}, "ORCHESTRATOR": {"MEM_CONFIG": None},
}) }
)
# Try with a build job in an invalid namespace. # Try with a build job in an invalid namespace.
build_job = self._create_build_job(namespace='somethingelse') build_job = self._create_build_job(namespace="somethingelse")
result = yield From(self.manager.schedule(build_job)) result = yield From(self.manager.schedule(build_job))
self.assertFalse(result[0]) self.assertFalse(result[0])
# Try with a valid namespace. # Try with a valid namespace.
build_job = self._create_build_job(namespace='something') build_job = self._create_build_job(namespace="something")
result = yield From(self.manager.schedule(build_job)) result = yield From(self.manager.schedule(build_job))
self.assertTrue(result[0]) self.assertTrue(result[0])
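A hedged sketch of the namespace filter this test implies; the real scheduler may differ, but the observable behavior is a membership check against NAMESPACE_WHITELIST:

    def namespace_allowed(executor_config, namespace):
        # Assumption: a missing whitelist means any namespace is accepted.
        whitelist = executor_config.get("NAMESPACE_WHITELIST")
        return whitelist is None or namespace in whitelist

    assert not namespace_allowed({"NAMESPACE_WHITELIST": ["something"]}, "somethingelse")
    assert namespace_allowed({"NAMESPACE_WHITELIST": ["something"]}, "something")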
@async_test @async_test
def test_schedule_job_retries_filter(self): def test_schedule_job_retries_filter(self):
EphemeralBuilderManager.EXECUTORS['test'] = TestExecutor EphemeralBuilderManager.EXECUTORS["test"] = TestExecutor
self.manager.initialize({ self.manager.initialize(
'EXECUTORS': [{ {
'EXECUTOR': 'test', "EXECUTORS": [{"EXECUTOR": "test", "MINIMUM_RETRY_THRESHOLD": 2}],
'MINIMUM_RETRY_THRESHOLD': 2, "ORCHESTRATOR": {"MEM_CONFIG": None},
}], }
'ORCHESTRATOR': {'MEM_CONFIG': None}, )
})
# Try with a build job that has too few retries. # Try with a build job that has too few retries.
build_job = self._create_build_job(retries=1) build_job = self._create_build_job(retries=1)
@ -536,29 +630,31 @@ class TestEphemeral(EphemeralBuilderTestCase):
@async_test @async_test
def test_schedule_job_executor_fallback(self): def test_schedule_job_executor_fallback(self):
EphemeralBuilderManager.EXECUTORS['primary'] = TestExecutor EphemeralBuilderManager.EXECUTORS["primary"] = TestExecutor
EphemeralBuilderManager.EXECUTORS['secondary'] = TestExecutor EphemeralBuilderManager.EXECUTORS["secondary"] = TestExecutor
self.manager.initialize({ self.manager.initialize(
'EXECUTORS': [
{ {
'NAME': 'primary', "EXECUTORS": [
'EXECUTOR': 'primary', {
'NAMESPACE_WHITELIST': ['something'], "NAME": "primary",
'MINIMUM_RETRY_THRESHOLD': 3, "EXECUTOR": "primary",
"NAMESPACE_WHITELIST": ["something"],
"MINIMUM_RETRY_THRESHOLD": 3,
}, },
{ {
'NAME': 'secondary', "NAME": "secondary",
'EXECUTOR': 'secondary', "EXECUTOR": "secondary",
'MINIMUM_RETRY_THRESHOLD': 2, "MINIMUM_RETRY_THRESHOLD": 2,
}, },
], ],
'ALLOWED_WORKER_COUNT': 5, "ALLOWED_WORKER_COUNT": 5,
'ORCHESTRATOR': {'MEM_CONFIG': None}, "ORCHESTRATOR": {"MEM_CONFIG": None},
}) }
)
# Try a job not matching the primary's namespace filter. Should schedule on secondary. # Try a job not matching the primary's namespace filter. Should schedule on secondary.
build_job = self._create_build_job(namespace='somethingelse') build_job = self._create_build_job(namespace="somethingelse")
result = yield From(self.manager.schedule(build_job)) result = yield From(self.manager.schedule(build_job))
self.assertTrue(result[0]) self.assertTrue(result[0])
@ -569,7 +665,7 @@ class TestEphemeral(EphemeralBuilderTestCase):
self.manager.registered_executors[1].job_started = None self.manager.registered_executors[1].job_started = None
# Try a job not matching the primary's retry minimum. Should schedule on secondary. # Try a job not matching the primary's retry minimum. Should schedule on secondary.
build_job = self._create_build_job(namespace='something', retries=2) build_job = self._create_build_job(namespace="something", retries=2)
result = yield From(self.manager.schedule(build_job)) result = yield From(self.manager.schedule(build_job))
self.assertTrue(result[0]) self.assertTrue(result[0])
@ -580,7 +676,7 @@ class TestEphemeral(EphemeralBuilderTestCase):
self.manager.registered_executors[1].job_started = None self.manager.registered_executors[1].job_started = None
# Try a job matching the primary. Should schedule on the primary. # Try a job matching the primary. Should schedule on the primary.
build_job = self._create_build_job(namespace='something', retries=3) build_job = self._create_build_job(namespace="something", retries=3)
result = yield From(self.manager.schedule(build_job)) result = yield From(self.manager.schedule(build_job))
self.assertTrue(result[0]) self.assertTrue(result[0])
@ -591,7 +687,7 @@ class TestEphemeral(EphemeralBuilderTestCase):
self.manager.registered_executors[1].job_started = None self.manager.registered_executors[1].job_started = None
# Try a job not matching either's restrictions. # Try a job not matching either's restrictions.
build_job = self._create_build_job(namespace='somethingelse', retries=1) build_job = self._create_build_job(namespace="somethingelse", retries=1)
result = yield From(self.manager.schedule(build_job)) result = yield From(self.manager.schedule(build_job))
self.assertFalse(result[0]) self.assertFalse(result[0])
@ -601,27 +697,27 @@ class TestEphemeral(EphemeralBuilderTestCase):
self.manager.registered_executors[0].job_started = None self.manager.registered_executors[0].job_started = None
self.manager.registered_executors[1].job_started = None self.manager.registered_executors[1].job_started = None
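The fallback order these assertions check can be condensed into a sketch; the assumption (consistent with the tests, though not confirmed by the diff) is that executors are consulted in registration order and the first one passing both filters takes the job:

    def pick_executor(executors, namespace, retries):
        for config in executors:
            whitelist = config.get("NAMESPACE_WHITELIST")
            if whitelist is not None and namespace not in whitelist:
                continue  # namespace filter rejected the job
            if retries < config.get("MINIMUM_RETRY_THRESHOLD", 0):
                continue  # too few retries remaining for this executor
            return config["NAME"]
        return None  # no executor accepted the job

    executors = [
        {"NAME": "primary", "NAMESPACE_WHITELIST": ["something"], "MINIMUM_RETRY_THRESHOLD": 3},
        {"NAME": "secondary", "MINIMUM_RETRY_THRESHOLD": 2},
    ]
    assert pick_executor(executors, "somethingelse", 3) == "secondary"
    assert pick_executor(executors, "something", 2) == "secondary"
    assert pick_executor(executors, "something", 3) == "primary"
    assert pick_executor(executors, "somethingelse", 1) is None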
@async_test @async_test
def test_schedule_job_single_executor(self): def test_schedule_job_single_executor(self):
EphemeralBuilderManager.EXECUTORS['test'] = TestExecutor EphemeralBuilderManager.EXECUTORS["test"] = TestExecutor
self.manager.initialize({ self.manager.initialize(
'EXECUTOR': 'test', {
'EXECUTOR_CONFIG': {}, "EXECUTOR": "test",
'ALLOWED_WORKER_COUNT': 5, "EXECUTOR_CONFIG": {},
'ORCHESTRATOR': {'MEM_CONFIG': None}, "ALLOWED_WORKER_COUNT": 5,
}) "ORCHESTRATOR": {"MEM_CONFIG": None},
}
)
build_job = self._create_build_job(namespace='something', retries=3) build_job = self._create_build_job(namespace="something", retries=3)
result = yield From(self.manager.schedule(build_job)) result = yield From(self.manager.schedule(build_job))
self.assertTrue(result[0]) self.assertTrue(result[0])
self.assertIsNotNone(self.manager.registered_executors[0].job_started) self.assertIsNotNone(self.manager.registered_executors[0].job_started)
self.manager.registered_executors[0].job_started = None self.manager.registered_executors[0].job_started = None
build_job = self._create_build_job(namespace="something", retries=0)
build_job = self._create_build_job(namespace='something', retries=0)
result = yield From(self.manager.schedule(build_job)) result = yield From(self.manager.schedule(build_job))
self.assertTrue(result[0]) self.assertTrue(result[0])
@ -630,30 +726,34 @@ class TestEphemeral(EphemeralBuilderTestCase):
@async_test @async_test
def test_executor_exception(self): def test_executor_exception(self):
EphemeralBuilderManager.EXECUTORS['bad'] = BadExecutor EphemeralBuilderManager.EXECUTORS["bad"] = BadExecutor
self.manager.initialize({ self.manager.initialize(
'EXECUTOR': 'bad', {
'EXECUTOR_CONFIG': {}, "EXECUTOR": "bad",
'ORCHESTRATOR': {'MEM_CONFIG': None}, "EXECUTOR_CONFIG": {},
}) "ORCHESTRATOR": {"MEM_CONFIG": None},
}
)
build_job = self._create_build_job(namespace='something', retries=3) build_job = self._create_build_job(namespace="something", retries=3)
result = yield From(self.manager.schedule(build_job)) result = yield From(self.manager.schedule(build_job))
self.assertFalse(result[0]) self.assertFalse(result[0])
@async_test @async_test
def test_schedule_and_stop(self): def test_schedule_and_stop(self):
EphemeralBuilderManager.EXECUTORS['test'] = TestExecutor EphemeralBuilderManager.EXECUTORS["test"] = TestExecutor
self.manager.initialize({ self.manager.initialize(
'EXECUTOR': 'test', {
'EXECUTOR_CONFIG': {}, "EXECUTOR": "test",
'ORCHESTRATOR': {'MEM_CONFIG': None}, "EXECUTOR_CONFIG": {},
}) "ORCHESTRATOR": {"MEM_CONFIG": None},
}
)
# Start the build job. # Start the build job.
build_job = self._create_build_job(namespace='something', retries=3) build_job = self._create_build_job(namespace="something", retries=3)
result = yield From(self.manager.schedule(build_job)) result = yield From(self.manager.schedule(build_job))
self.assertTrue(result[0]) self.assertTrue(result[0])
@ -661,19 +761,23 @@ class TestEphemeral(EphemeralBuilderTestCase):
self.assertIsNotNone(executor.job_started) self.assertIsNotNone(executor.job_started)
# Register the realm so the build information is added. # Register the realm so the build information is added.
yield From(self.manager._register_realm({ yield From(
'realm': str(uuid.uuid4()), self.manager._register_realm(
'token': str(uuid.uuid4()), {
'execution_id': executor.job_started, "realm": str(uuid.uuid4()),
'executor_name': 'TestExecutor', "token": str(uuid.uuid4()),
'build_uuid': build_job.build_uuid, "execution_id": executor.job_started,
'job_queue_item': build_job.job_item, "executor_name": "TestExecutor",
})) "build_uuid": build_job.build_uuid,
"job_queue_item": build_job.job_item,
}
)
)
# Stop the build job. # Stop the build job.
yield From(self.manager.kill_builder_executor(build_job.build_uuid)) yield From(self.manager.kill_builder_executor(build_job.build_uuid))
self.assertEqual(executor.job_stopped, executor.job_started) self.assertEqual(executor.job_stopped, executor.job_started)
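Reduced to its field names, the realm document registered in this test looks like the following; the concrete values here are placeholders:

    import uuid

    realm_doc = {
        "realm": str(uuid.uuid4()),        # realm identifier for the build
        "token": str(uuid.uuid4()),        # auth token handed to the worker
        "execution_id": "exec-123",        # placeholder executor-assigned id
        "executor_name": "TestExecutor",
        "build_uuid": "build-456",         # placeholder build id
        "job_queue_item": {"body": '{"build_uuid": "build-456"}'},
    }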
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()


@ -2,4 +2,3 @@ import buildtrigger.bitbuckethandler
import buildtrigger.customhandler import buildtrigger.customhandler
import buildtrigger.githubhandler import buildtrigger.githubhandler
import buildtrigger.gitlabhandler import buildtrigger.gitlabhandler


@ -9,158 +9,155 @@ from data import model
from buildtrigger.triggerutil import get_trigger_config, InvalidServiceException from buildtrigger.triggerutil import get_trigger_config, InvalidServiceException
NAMESPACES_SCHEMA = { NAMESPACES_SCHEMA = {
'type': 'array', "type": "array",
'items': { "items": {
'type': 'object', "type": "object",
'properties': { "properties": {
'personal': { "personal": {
'type': 'boolean', "type": "boolean",
'description': 'True if the namespace is the user\'s personal namespace', "description": "True if the namespace is the user's personal namespace",
}, },
'score': { "score": {
'type': 'number', "type": "number",
'description': 'Score of the relevance of the namespace', "description": "Score of the relevance of the namespace",
}, },
'avatar_url': { "avatar_url": {
'type': ['string', 'null'], "type": ["string", "null"],
'description': 'URL of the avatar for this namespace', "description": "URL of the avatar for this namespace",
}, },
'url': { "url": {
'type': 'string', "type": "string",
'description': 'URL of the website to view the namespace', "description": "URL of the website to view the namespace",
}, },
'id': { "id": {
'type': 'string', "type": "string",
'description': 'Trigger-internal ID of the namespace', "description": "Trigger-internal ID of the namespace",
}, },
'title': { "title": {
'type': 'string', "type": "string",
'description': 'Human-readable title of the namespace', "description": "Human-readable title of the namespace",
}, },
}, },
'required': ['personal', 'score', 'avatar_url', 'id', 'title'], "required": ["personal", "score", "avatar_url", "id", "title"],
}, },
} }
BUILD_SOURCES_SCHEMA = { BUILD_SOURCES_SCHEMA = {
'type': 'array', "type": "array",
'items': { "items": {
'type': 'object', "type": "object",
'properties': { "properties": {
'name': { "name": {
'type': 'string', "type": "string",
'description': 'The name of the repository, without its namespace', "description": "The name of the repository, without its namespace",
}, },
'full_name': { "full_name": {
'type': 'string', "type": "string",
'description': 'The name of the repository, with its namespace', "description": "The name of the repository, with its namespace",
}, },
'description': { "description": {
'type': 'string', "type": "string",
'description': 'The description of the repository. May be an empty string', "description": "The description of the repository. May be an empty string",
}, },
'last_updated': { "last_updated": {
'type': 'number', "type": "number",
'description': 'The date/time when the repository was last updated, since epoch in UTC', "description": "The date/time when the repository was last updated, since epoch in UTC",
}, },
'url': { "url": {
'type': 'string', "type": "string",
'description': 'The URL at which to view the repository in the browser', "description": "The URL at which to view the repository in the browser",
}, },
'has_admin_permissions': { "has_admin_permissions": {
'type': 'boolean', "type": "boolean",
'description': 'True if the current user has admin permissions on the repository', "description": "True if the current user has admin permissions on the repository",
}, },
'private': { "private": {
'type': 'boolean', "type": "boolean",
'description': 'True if the repository is private', "description": "True if the repository is private",
}, },
}, },
'required': ['name', 'full_name', 'description', 'last_updated', "required": [
'has_admin_permissions', 'private'], "name",
"full_name",
"description",
"last_updated",
"has_admin_permissions",
"private",
],
}, },
} }
METADATA_SCHEMA = { METADATA_SCHEMA = {
'type': 'object', "type": "object",
'properties': { "properties": {
'commit': { "commit": {
'type': 'string', "type": "string",
'description': 'first 7 characters of the SHA-1 identifier for a git commit', "description": "first 7 characters of the SHA-1 identifier for a git commit",
'pattern': '^([A-Fa-f0-9]{7,})$', "pattern": "^([A-Fa-f0-9]{7,})$",
}, },
'git_url': { "git_url": {
'type': 'string', "type": "string",
'description': 'The GIT url to use for the checkout', "description": "The GIT url to use for the checkout",
}, },
'ref': { "ref": {
'type': 'string', "type": "string",
'description': 'git reference for a git commit', "description": "git reference for a git commit",
'pattern': r'^refs\/(heads|tags|remotes)\/(.+)$', "pattern": r"^refs\/(heads|tags|remotes)\/(.+)$",
}, },
'default_branch': { "default_branch": {
'type': 'string', "type": "string",
'description': 'default branch of the git repository', "description": "default branch of the git repository",
}, },
'commit_info': { "commit_info": {
'type': 'object', "type": "object",
'description': 'metadata about a git commit', "description": "metadata about a git commit",
'properties': { "properties": {
'url': { "url": {"type": "string", "description": "URL to view a git commit"},
'type': 'string', "message": {"type": "string", "description": "git commit message"},
'description': 'URL to view a git commit', "date": {"type": "string", "description": "timestamp for a git commit"},
"author": {
"type": "object",
"description": "metadata about the author of a git commit",
"properties": {
"username": {
"type": "string",
"description": "username of the author",
}, },
'message': { "url": {
'type': 'string', "type": "string",
'description': 'git commit message', "description": "URL to view the profile of the author",
}, },
'date': { "avatar_url": {
'type': 'string', "type": "string",
'description': 'timestamp for a git commit' "description": "URL to view the avatar of the author",
},
'author': {
'type': 'object',
'description': 'metadata about the author of a git commit',
'properties': {
'username': {
'type': 'string',
'description': 'username of the author',
},
'url': {
'type': 'string',
'description': 'URL to view the profile of the author',
},
'avatar_url': {
'type': 'string',
'description': 'URL to view the avatar of the author',
}, },
}, },
'required': ['username'], "required": ["username"],
}, },
'committer': { "committer": {
'type': 'object', "type": "object",
'description': 'metadata about the committer of a git commit', "description": "metadata about the committer of a git commit",
'properties': { "properties": {
'username': { "username": {
'type': 'string', "type": "string",
'description': 'username of the committer', "description": "username of the committer",
}, },
'url': { "url": {
'type': 'string', "type": "string",
'description': 'URL to view the profile of the committer', "description": "URL to view the profile of the committer",
}, },
'avatar_url': { "avatar_url": {
'type': 'string', "type": "string",
'description': 'URL to view the avatar of the committer', "description": "URL to view the avatar of the committer",
}, },
}, },
'required': ['username'], "required": ["username"],
}, },
}, },
'required': ['message'], "required": ["message"],
}, },
}, },
'required': ['commit', 'git_url'], "required": ["commit", "git_url"],
} }
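To see the schema in action, a payload can be checked with jsonschema.validate, which raises ValidationError on mismatch (METADATA_SCHEMA as defined above; the candidate values are placeholders):

    from jsonschema import validate, ValidationError

    candidate = {
        "commit": "abcdef1",
        "git_url": "git@bitbucket.org:org/repo.git",  # placeholder repository
        "ref": "refs/heads/master",
    }
    try:
        validate(candidate, METADATA_SCHEMA)
    except ValidationError as ve:
        print("metadata rejected: %s" % ve.message)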
@ -289,7 +286,9 @@ class BuildTriggerHandler(object):
if subc.service_name() == trigger.service.name: if subc.service_name() == trigger.service.name:
return subc(trigger, override_config) return subc(trigger, override_config)
raise InvalidServiceException('Unable to find service: %s' % trigger.service.name) raise InvalidServiceException(
"Unable to find service: %s" % trigger.service.name
)
def put_config_key(self, key, value): def put_config_key(self, key, value):
""" Updates a config key in the trigger, saving it to the DB. """ """ Updates a config key in the trigger, saving it to the DB. """
@ -298,13 +297,15 @@ class BuildTriggerHandler(object):
def set_auth_token(self, auth_token): def set_auth_token(self, auth_token):
""" Sets the auth token for the trigger, saving it to the DB. """ """ Sets the auth token for the trigger, saving it to the DB. """
model.build.update_build_trigger(self.trigger, self.config, auth_token=auth_token) model.build.update_build_trigger(
self.trigger, self.config, auth_token=auth_token
)
def get_dockerfile_path(self): def get_dockerfile_path(self):
""" Returns the normalized path to the Dockerfile found in the subdirectory """ Returns the normalized path to the Dockerfile found in the subdirectory
in the config. """ in the config. """
dockerfile_path = self.config.get('dockerfile_path') or 'Dockerfile' dockerfile_path = self.config.get("dockerfile_path") or "Dockerfile"
if dockerfile_path[0] == '/': if dockerfile_path[0] == "/":
dockerfile_path = dockerfile_path[1:] dockerfile_path = dockerfile_path[1:]
return dockerfile_path return dockerfile_path
@ -313,13 +314,13 @@ class BuildTriggerHandler(object):
validate(metadata, METADATA_SCHEMA) validate(metadata, METADATA_SCHEMA)
config = self.config config = self.config
ref = metadata.get('ref', None) ref = metadata.get("ref", None)
commit_sha = metadata['commit'] commit_sha = metadata["commit"]
default_branch = metadata.get('default_branch', None) default_branch = metadata.get("default_branch", None)
prepared = PreparedBuild(self.trigger) prepared = PreparedBuild(self.trigger)
prepared.name_from_sha(commit_sha) prepared.name_from_sha(commit_sha)
prepared.subdirectory = config.get('dockerfile_path', None) prepared.subdirectory = config.get("dockerfile_path", None)
prepared.context = config.get('context', None) prepared.context = config.get("context", None)
prepared.is_manual = is_manual prepared.is_manual = is_manual
prepared.metadata = metadata prepared.metadata = metadata


@ -9,176 +9,168 @@ from jsonschema import validate
from app import app, get_app_url from app import app, get_app_url
from buildtrigger.basehandler import BuildTriggerHandler from buildtrigger.basehandler import BuildTriggerHandler
from buildtrigger.triggerutil import (RepositoryReadException, TriggerActivationException, from buildtrigger.triggerutil import (
TriggerDeactivationException, TriggerStartException, RepositoryReadException,
InvalidPayloadException, TriggerProviderException, TriggerActivationException,
TriggerDeactivationException,
TriggerStartException,
InvalidPayloadException,
TriggerProviderException,
SkipRequestException, SkipRequestException,
determine_build_ref, raise_if_skipped_build, determine_build_ref,
find_matching_branches) raise_if_skipped_build,
find_matching_branches,
)
from util.dict_wrappers import JSONPathDict, SafeDictSetter from util.dict_wrappers import JSONPathDict, SafeDictSetter
from util.security.ssh import generate_ssh_keypair from util.security.ssh import generate_ssh_keypair
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_BITBUCKET_COMMIT_URL = 'https://bitbucket.org/%s/commits/%s' _BITBUCKET_COMMIT_URL = "https://bitbucket.org/%s/commits/%s"
_RAW_AUTHOR_REGEX = re.compile(r'.*<(.+)>') _RAW_AUTHOR_REGEX = re.compile(r".*<(.+)>")
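_RAW_AUTHOR_REGEX extracts the e-mail address from a raw author string of the form "Name <email>", for example:

    import re

    _RAW_AUTHOR_REGEX = re.compile(r".*<(.+)>")

    match = _RAW_AUTHOR_REGEX.match("Jane Doe <jane@example.com>")
    assert match is not None
    assert match.group(1) == "jane@example.com"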
BITBUCKET_WEBHOOK_PAYLOAD_SCHEMA = { BITBUCKET_WEBHOOK_PAYLOAD_SCHEMA = {
'type': 'object', "type": "object",
'properties': { "properties": {
'repository': { "repository": {
'type': 'object', "type": "object",
'properties': { "properties": {"full_name": {"type": "string"}},
'full_name': { "required": ["full_name"],
'type': 'string',
},
},
'required': ['full_name'],
}, # /Repository }, # /Repository
'push': { "push": {
'type': 'object', "type": "object",
'properties': { "properties": {
'changes': { "changes": {
'type': 'array', "type": "array",
'items': { "items": {
'type': 'object', "type": "object",
'properties': { "properties": {
'new': { "new": {
'type': 'object', "type": "object",
'properties': { "properties": {
'target': { "target": {
'type': 'object', "type": "object",
'properties': { "properties": {
'hash': { "hash": {"type": "string"},
'type': 'string' "message": {"type": "string"},
"date": {"type": "string"},
"author": {
"type": "object",
"properties": {
"user": {
"type": "object",
"properties": {
"display_name": {
"type": "string"
}, },
'message': { "account_id": {
'type': 'string' "type": "string"
}, },
'date': { "links": {
'type': 'string' "type": "object",
"properties": {
"avatar": {
"type": "object",
"properties": {
"href": {
"type": "string"
}
}, },
'author': { "required": [
'type': 'object', "href"
'properties': { ],
'user': { }
'type': 'object',
'properties': {
'display_name': {
'type': 'string',
}, },
'account_id': { "required": ["avatar"],
'type': 'string',
},
'links': {
'type': 'object',
'properties': {
'avatar': {
'type': 'object',
'properties': {
'href': {
'type': 'string',
},
},
'required': ['href'],
},
},
'required': ['avatar'],
}, # /User }, # /User
}, },
}, # /Author } # /Author
}, },
}, },
}, },
'required': ['hash', 'message', 'date'], "required": ["hash", "message", "date"],
}, # /Target } # /Target
}, },
'required': ['name', 'target'], "required": ["name", "target"],
}, # /New } # /New
}, },
}, # /Changes item }, # /Changes item
}, # /Changes } # /Changes
}, },
'required': ['changes'], "required": ["changes"],
}, # / Push }, # / Push
}, },
'actor': { "actor": {
'type': 'object', "type": "object",
'properties': { "properties": {
'account_id': { "account_id": {"type": "string"},
'type': 'string', "display_name": {"type": "string"},
"links": {
"type": "object",
"properties": {
"avatar": {
"type": "object",
"properties": {"href": {"type": "string"}},
"required": ["href"],
}
}, },
'display_name': { "required": ["avatar"],
'type': 'string',
},
'links': {
'type': 'object',
'properties': {
'avatar': {
'type': 'object',
'properties': {
'href': {
'type': 'string',
},
},
'required': ['href'],
},
},
'required': ['avatar'],
}, },
}, },
}, # /Actor }, # /Actor
'required': ['push', 'repository'], "required": ["push", "repository"],
} # /Root } # /Root
BITBUCKET_COMMIT_INFO_SCHEMA = { BITBUCKET_COMMIT_INFO_SCHEMA = {
'type': 'object', "type": "object",
'properties': { "properties": {
'node': { "node": {"type": "string"},
'type': 'string', "message": {"type": "string"},
"timestamp": {"type": "string"},
"raw_author": {"type": "string"},
}, },
'message': { "required": ["node", "message", "timestamp"],
'type': 'string',
},
'timestamp': {
'type': 'string',
},
'raw_author': {
'type': 'string',
},
},
'required': ['node', 'message', 'timestamp']
} }
def get_transformed_commit_info(bb_commit, ref, default_branch, repository_name, lookup_author):
def get_transformed_commit_info(
bb_commit, ref, default_branch, repository_name, lookup_author
):
""" Returns the BitBucket commit information transformed into our own """ Returns the BitBucket commit information transformed into our own
payload format. payload format.
""" """
try: try:
validate(bb_commit, BITBUCKET_COMMIT_INFO_SCHEMA) validate(bb_commit, BITBUCKET_COMMIT_INFO_SCHEMA)
except Exception as exc: except Exception as exc:
logger.exception('Exception when validating Bitbucket commit information: %s from %s', exc.message, bb_commit) logger.exception(
"Exception when validating Bitbucket commit information: %s from %s",
exc.message,
bb_commit,
)
raise InvalidPayloadException(exc.message) raise InvalidPayloadException(exc.message)
commit = JSONPathDict(bb_commit) commit = JSONPathDict(bb_commit)
config = SafeDictSetter() config = SafeDictSetter()
config['commit'] = commit['node'] config["commit"] = commit["node"]
config['ref'] = ref config["ref"] = ref
config['default_branch'] = default_branch config["default_branch"] = default_branch
config['git_url'] = 'git@bitbucket.org:%s.git' % repository_name config["git_url"] = "git@bitbucket.org:%s.git" % repository_name
config['commit_info.url'] = _BITBUCKET_COMMIT_URL % (repository_name, commit['node']) config["commit_info.url"] = _BITBUCKET_COMMIT_URL % (
config['commit_info.message'] = commit['message'] repository_name,
config['commit_info.date'] = commit['timestamp'] commit["node"],
)
config["commit_info.message"] = commit["message"]
config["commit_info.date"] = commit["timestamp"]
match = _RAW_AUTHOR_REGEX.match(commit['raw_author']) match = _RAW_AUTHOR_REGEX.match(commit["raw_author"])
if match: if match:
author = lookup_author(match.group(1)) author = lookup_author(match.group(1))
author_info = JSONPathDict(author) if author is not None else None author_info = JSONPathDict(author) if author is not None else None
if author_info: if author_info:
config['commit_info.author.username'] = author_info['user.display_name'] config["commit_info.author.username"] = author_info["user.display_name"]
config['commit_info.author.avatar_url'] = author_info['user.avatar'] config["commit_info.author.avatar_url"] = author_info["user.avatar"]
return config.dict_value() return config.dict_value()
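SafeDictSetter itself lives in util.dict_wrappers and is not shown in this diff; judging purely from its use above, a simplified stand-in that expands dotted keys into nested dictionaries and drops None values might look like:

    class DottedDictSetter(object):
        """Hypothetical stand-in for SafeDictSetter; semantics are assumed."""

        def __init__(self):
            self._data = {}

        def __setitem__(self, path, value):
            if value is None:
                return  # assumed: None values are silently dropped
            node = self._data
            keys = path.split(".")
            for key in keys[:-1]:
                node = node.setdefault(key, {})
            node[keys[-1]] = value

        def dict_value(self):
            return self._data

    config = DottedDictSetter()
    config["commit"] = "abcdef1"
    config["commit_info.author.username"] = "jane"
    assert config.dict_value() == {
        "commit": "abcdef1",
        "commit_info": {"author": {"username": "jane"}},
    }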
@ -190,36 +182,39 @@ def get_transformed_webhook_payload(bb_payload, default_branch=None):
try: try:
validate(bb_payload, BITBUCKET_WEBHOOK_PAYLOAD_SCHEMA) validate(bb_payload, BITBUCKET_WEBHOOK_PAYLOAD_SCHEMA)
except Exception as exc: except Exception as exc:
logger.exception('Exception when validating Bitbucket webhook payload: %s from %s', exc.message, logger.exception(
bb_payload) "Exception when validating Bitbucket webhook payload: %s from %s",
exc.message,
bb_payload,
)
raise InvalidPayloadException(exc.message) raise InvalidPayloadException(exc.message)
payload = JSONPathDict(bb_payload) payload = JSONPathDict(bb_payload)
change = payload['push.changes[-1].new'] change = payload["push.changes[-1].new"]
if not change: if not change:
raise SkipRequestException raise SkipRequestException
is_branch = change['type'] == 'branch' is_branch = change["type"] == "branch"
ref = 'refs/heads/' + change['name'] if is_branch else 'refs/tags/' + change['name'] ref = "refs/heads/" + change["name"] if is_branch else "refs/tags/" + change["name"]
repository_name = payload['repository.full_name'] repository_name = payload["repository.full_name"]
target = change['target'] target = change["target"]
config = SafeDictSetter() config = SafeDictSetter()
config['commit'] = target['hash'] config["commit"] = target["hash"]
config['ref'] = ref config["ref"] = ref
config['default_branch'] = default_branch config["default_branch"] = default_branch
config['git_url'] = 'git@bitbucket.org:%s.git' % repository_name config["git_url"] = "git@bitbucket.org:%s.git" % repository_name
config['commit_info.url'] = target['links.html.href'] or '' config["commit_info.url"] = target["links.html.href"] or ""
config['commit_info.message'] = target['message'] config["commit_info.message"] = target["message"]
config['commit_info.date'] = target['date'] config["commit_info.date"] = target["date"]
config['commit_info.author.username'] = target['author.user.display_name'] config["commit_info.author.username"] = target["author.user.display_name"]
config['commit_info.author.avatar_url'] = target['author.user.links.avatar.href'] config["commit_info.author.avatar_url"] = target["author.user.links.avatar.href"]
config['commit_info.committer.username'] = payload['actor.display_name'] config["commit_info.committer.username"] = payload["actor.display_name"]
config['commit_info.committer.avatar_url'] = payload['actor.links.avatar.href'] config["commit_info.committer.avatar_url"] = payload["actor.links.avatar.href"]
return config.dict_value() return config.dict_value()
@ -227,43 +222,49 @@ class BitbucketBuildTrigger(BuildTriggerHandler):
""" """
BuildTrigger for Bitbucket. BuildTrigger for Bitbucket.
""" """
@classmethod @classmethod
def service_name(cls): def service_name(cls):
return 'bitbucket' return "bitbucket"
def _get_client(self): def _get_client(self):
""" Returns a BitBucket API client for this trigger's config. """ """ Returns a BitBucket API client for this trigger's config. """
key = app.config.get('BITBUCKET_TRIGGER_CONFIG', {}).get('CONSUMER_KEY', '') key = app.config.get("BITBUCKET_TRIGGER_CONFIG", {}).get("CONSUMER_KEY", "")
secret = app.config.get('BITBUCKET_TRIGGER_CONFIG', {}).get('CONSUMER_SECRET', '') secret = app.config.get("BITBUCKET_TRIGGER_CONFIG", {}).get(
"CONSUMER_SECRET", ""
)
trigger_uuid = self.trigger.uuid trigger_uuid = self.trigger.uuid
callback_url = '%s/oauth1/bitbucket/callback/trigger/%s' % (get_app_url(), trigger_uuid) callback_url = "%s/oauth1/bitbucket/callback/trigger/%s" % (
get_app_url(),
trigger_uuid,
)
return BitBucket(key, secret, callback_url, timeout=15) return BitBucket(key, secret, callback_url, timeout=15)
def _get_authorized_client(self): def _get_authorized_client(self):
""" Returns an authorized API client. """ """ Returns an authorized API client. """
base_client = self._get_client() base_client = self._get_client()
auth_token = self.auth_token or 'invalid:invalid' auth_token = self.auth_token or "invalid:invalid"
token_parts = auth_token.split(':') token_parts = auth_token.split(":")
if len(token_parts) != 2: if len(token_parts) != 2:
token_parts = ['invalid', 'invalid'] token_parts = ["invalid", "invalid"]
(access_token, access_token_secret) = token_parts (access_token, access_token_secret) = token_parts
return base_client.get_authorized_client(access_token, access_token_secret) return base_client.get_authorized_client(access_token, access_token_secret)
def _get_repository_client(self): def _get_repository_client(self):
""" Returns an API client for working with this config's BB repository. """ """ Returns an API client for working with this config's BB repository. """
source = self.config['build_source'] source = self.config["build_source"]
(namespace, name) = source.split('/') (namespace, name) = source.split("/")
bitbucket_client = self._get_authorized_client() bitbucket_client = self._get_authorized_client()
return bitbucket_client.for_namespace(namespace).repositories().get(name) return bitbucket_client.for_namespace(namespace).repositories().get(name)
def _get_default_branch(self, repository, default_value='master'): def _get_default_branch(self, repository, default_value="master"):
""" Returns the default branch for the repository or the value given. """ """ Returns the default branch for the repository or the value given. """
(result, data, _) = repository.get_main_branch() (result, data, _) = repository.get_main_branch()
if result: if result:
return data['name'] return data["name"]
return default_value return default_value
@ -279,16 +280,18 @@ class BitbucketBuildTrigger(BuildTriggerHandler):
def exchange_verifier(self, verifier): def exchange_verifier(self, verifier):
""" Exchanges the given verifier token to setup this trigger. """ """ Exchanges the given verifier token to setup this trigger. """
bitbucket_client = self._get_client() bitbucket_client = self._get_client()
access_token = self.config.get('access_token', '') access_token = self.config.get("access_token", "")
access_token_secret = self.auth_token access_token_secret = self.auth_token
# Exchange the verifier for a new access token. # Exchange the verifier for a new access token.
(result, data, _) = bitbucket_client.verify_token(access_token, access_token_secret, verifier) (result, data, _) = bitbucket_client.verify_token(
access_token, access_token_secret, verifier
)
if not result: if not result:
return False return False
# Save the updated access token and secret. # Save the updated access token and secret.
self.set_auth_token(data[0] + ':' + data[1]) self.set_auth_token(data[0] + ":" + data[1])
# Retrieve the current authorized user's information and store the username in the config. # Retrieve the current authorized user's information and store the username in the config.
authorized_client = self._get_authorized_client() authorized_client = self._get_authorized_client()
@ -296,68 +299,67 @@ class BitbucketBuildTrigger(BuildTriggerHandler):
if not result: if not result:
return False return False
self.put_config_key('account_id', data['user']['account_id']) self.put_config_key("account_id", data["user"]["account_id"])
self.put_config_key('nickname', data['user']['nickname']) self.put_config_key("nickname", data["user"]["nickname"])
return True return True
def is_active(self): def is_active(self):
return 'webhook_id' in self.config return "webhook_id" in self.config
def activate(self, standard_webhook_url): def activate(self, standard_webhook_url):
config = self.config config = self.config
# Add a deploy key to the repository. # Add a deploy key to the repository.
public_key, private_key = generate_ssh_keypair() public_key, private_key = generate_ssh_keypair()
config['credentials'] = [ config["credentials"] = [{"name": "SSH Public Key", "value": public_key}]
{
'name': 'SSH Public Key',
'value': public_key,
},
]
repository = self._get_repository_client() repository = self._get_repository_client()
(result, created_deploykey, err_msg) = repository.deploykeys().create( (result, created_deploykey, err_msg) = repository.deploykeys().create(
app.config['REGISTRY_TITLE'] + ' webhook key', public_key) app.config["REGISTRY_TITLE"] + " webhook key", public_key
)
if not result: if not result:
msg = 'Unable to add deploy key to repository: %s' % err_msg msg = "Unable to add deploy key to repository: %s" % err_msg
raise TriggerActivationException(msg) raise TriggerActivationException(msg)
config['deploy_key_id'] = created_deploykey['pk'] config["deploy_key_id"] = created_deploykey["pk"]
# Add a webhook callback. # Add a webhook callback.
description = 'Webhook for invoking builds on %s' % app.config['REGISTRY_TITLE_SHORT'] description = (
webhook_events = ['repo:push'] "Webhook for invoking builds on %s" % app.config["REGISTRY_TITLE_SHORT"]
)
webhook_events = ["repo:push"]
(result, created_webhook, err_msg) = repository.webhooks().create( (result, created_webhook, err_msg) = repository.webhooks().create(
description, standard_webhook_url, webhook_events) description, standard_webhook_url, webhook_events
)
if not result: if not result:
msg = 'Unable to add webhook to repository: %s' % err_msg msg = "Unable to add webhook to repository: %s" % err_msg
raise TriggerActivationException(msg) raise TriggerActivationException(msg)
config['webhook_id'] = created_webhook['uuid'] config["webhook_id"] = created_webhook["uuid"]
self.config = config self.config = config
return config, {'private_key': private_key} return config, {"private_key": private_key}
def deactivate(self): def deactivate(self):
config = self.config config = self.config
webhook_id = config.pop('webhook_id', None) webhook_id = config.pop("webhook_id", None)
deploy_key_id = config.pop('deploy_key_id', None) deploy_key_id = config.pop("deploy_key_id", None)
repository = self._get_repository_client() repository = self._get_repository_client()
# Remove the webhook. # Remove the webhook.
if webhook_id is not None: if webhook_id is not None:
(result, _, err_msg) = repository.webhooks().delete(webhook_id) (result, _, err_msg) = repository.webhooks().delete(webhook_id)
if not result: if not result:
msg = 'Unable to remove webhook from repository: %s' % err_msg msg = "Unable to remove webhook from repository: %s" % err_msg
raise TriggerDeactivationException(msg) raise TriggerDeactivationException(msg)
# Remove the public key. # Remove the public key.
if deploy_key_id is not None: if deploy_key_id is not None:
(result, _, err_msg) = repository.deploykeys().delete(deploy_key_id) (result, _, err_msg) = repository.deploykeys().delete(deploy_key_id)
if not result: if not result:
msg = 'Unable to remove deploy key from repository: %s' % err_msg msg = "Unable to remove deploy key from repository: %s" % err_msg
raise TriggerDeactivationException(msg) raise TriggerDeactivationException(msg)
return config return config
@ -366,46 +368,47 @@ class BitbucketBuildTrigger(BuildTriggerHandler):
bitbucket_client = self._get_authorized_client() bitbucket_client = self._get_authorized_client()
(result, data, err_msg) = bitbucket_client.get_visible_repositories() (result, data, err_msg) = bitbucket_client.get_visible_repositories()
if not result: if not result:
raise RepositoryReadException('Could not read repository list: ' + err_msg) raise RepositoryReadException("Could not read repository list: " + err_msg)
namespaces = {} namespaces = {}
for repo in data: for repo in data:
owner = repo['owner'] owner = repo["owner"]
if owner in namespaces: if owner in namespaces:
namespaces[owner]['score'] = namespaces[owner]['score'] + 1 namespaces[owner]["score"] = namespaces[owner]["score"] + 1
else: else:
namespaces[owner] = { namespaces[owner] = {
'personal': owner == self.config.get('nickname', self.config.get('username')), "personal": owner
'id': owner, == self.config.get("nickname", self.config.get("username")),
'title': owner, "id": owner,
'avatar_url': repo['logo'], "title": owner,
'url': 'https://bitbucket.org/%s' % (owner), "avatar_url": repo["logo"],
'score': 1, "url": "https://bitbucket.org/%s" % (owner),
"score": 1,
} }
return BuildTriggerHandler.build_namespaces_response(namespaces) return BuildTriggerHandler.build_namespaces_response(namespaces)
def list_build_sources_for_namespace(self, namespace): def list_build_sources_for_namespace(self, namespace):
def repo_view(repo): def repo_view(repo):
last_modified = dateutil.parser.parse(repo['utc_last_updated']) last_modified = dateutil.parser.parse(repo["utc_last_updated"])
return { return {
'name': repo['slug'], "name": repo["slug"],
'full_name': '%s/%s' % (repo['owner'], repo['slug']), "full_name": "%s/%s" % (repo["owner"], repo["slug"]),
'description': repo['description'] or '', "description": repo["description"] or "",
'last_updated': timegm(last_modified.utctimetuple()), "last_updated": timegm(last_modified.utctimetuple()),
'url': 'https://bitbucket.org/%s/%s' % (repo['owner'], repo['slug']), "url": "https://bitbucket.org/%s/%s" % (repo["owner"], repo["slug"]),
'has_admin_permissions': repo['read_only'] is False, "has_admin_permissions": repo["read_only"] is False,
'private': repo['is_private'], "private": repo["is_private"],
} }
bitbucket_client = self._get_authorized_client() bitbucket_client = self._get_authorized_client()
(result, data, err_msg) = bitbucket_client.get_visible_repositories() (result, data, err_msg) = bitbucket_client.get_visible_repositories()
if not result: if not result:
raise RepositoryReadException('Could not read repository list: ' + err_msg) raise RepositoryReadException("Could not read repository list: " + err_msg)
repos = [repo_view(repo) for repo in data if repo['owner'] == namespace] repos = [repo_view(repo) for repo in data if repo["owner"] == namespace]
return BuildTriggerHandler.build_sources_response(repos) return BuildTriggerHandler.build_sources_response(repos)
def list_build_subdirs(self): def list_build_subdirs(self):
@ -413,50 +416,57 @@ class BitbucketBuildTrigger(BuildTriggerHandler):
repository = self._get_repository_client() repository = self._get_repository_client()
# Find the first matching branch. # Find the first matching branch.
repo_branches = self.list_field_values('branch_name') or [] repo_branches = self.list_field_values("branch_name") or []
branches = find_matching_branches(config, repo_branches) branches = find_matching_branches(config, repo_branches)
if not branches: if not branches:
branches = [self._get_default_branch(repository)] branches = [self._get_default_branch(repository)]
(result, data, err_msg) = repository.get_path_contents('', revision=branches[0]) (result, data, err_msg) = repository.get_path_contents("", revision=branches[0])
if not result: if not result:
raise RepositoryReadException(err_msg) raise RepositoryReadException(err_msg)
files = set([f['path'] for f in data['files']]) files = set([f["path"] for f in data["files"]])
return ["/" + file_path for file_path in files if self.filename_is_dockerfile(os.path.basename(file_path))] return [
"/" + file_path
for file_path in files
if self.filename_is_dockerfile(os.path.basename(file_path))
]
def load_dockerfile_contents(self): def load_dockerfile_contents(self):
repository = self._get_repository_client() repository = self._get_repository_client()
path = self.get_dockerfile_path() path = self.get_dockerfile_path()
(result, data, err_msg) = repository.get_raw_path_contents(path, revision='master') (result, data, err_msg) = repository.get_raw_path_contents(
path, revision="master"
)
if not result: if not result:
return None return None
return data return data
def list_field_values(self, field_name, limit=None): def list_field_values(self, field_name, limit=None):
if 'build_source' not in self.config: if "build_source" not in self.config:
return None return None
source = self.config['build_source'] source = self.config["build_source"]
(namespace, name) = source.split('/') (namespace, name) = source.split("/")
bitbucket_client = self._get_authorized_client() bitbucket_client = self._get_authorized_client()
repository = bitbucket_client.for_namespace(namespace).repositories().get(name) repository = bitbucket_client.for_namespace(namespace).repositories().get(name)
if field_name == 'refs': if field_name == "refs":
(result, data, _) = repository.get_branches_and_tags() (result, data, _) = repository.get_branches_and_tags()
if not result: if not result:
return None return None
branches = [b['name'] for b in data['branches']] branches = [b["name"] for b in data["branches"]]
tags = [t['name'] for t in data['tags']] tags = [t["name"] for t in data["tags"]]
return ([{'kind': 'branch', 'name': b} for b in branches] + return [{"kind": "branch", "name": b} for b in branches] + [
[{'kind': 'tag', 'name': tag} for tag in tags]) {"kind": "tag", "name": tag} for tag in tags
]
if field_name == 'tag_name': if field_name == "tag_name":
(result, data, _) = repository.get_tags() (result, data, _) = repository.get_tags()
if not result: if not result:
return None return None
@ -467,7 +477,7 @@ class BitbucketBuildTrigger(BuildTriggerHandler):
return tags return tags
if field_name == 'branch_name': if field_name == "branch_name":
(result, data, _) = repository.get_branches() (result, data, _) = repository.get_branches()
if not result: if not result:
return None return None
@ -481,21 +491,23 @@ class BitbucketBuildTrigger(BuildTriggerHandler):
return None return None
def get_repository_url(self): def get_repository_url(self):
source = self.config['build_source'] source = self.config["build_source"]
(namespace, name) = source.split('/') (namespace, name) = source.split("/")
return 'https://bitbucket.org/%s/%s' % (namespace, name) return "https://bitbucket.org/%s/%s" % (namespace, name)
def handle_trigger_request(self, request): def handle_trigger_request(self, request):
payload = request.get_json() payload = request.get_json()
if payload is None: if payload is None:
raise InvalidPayloadException('Missing payload') raise InvalidPayloadException("Missing payload")
logger.debug('Got BitBucket request: %s', payload) logger.debug("Got BitBucket request: %s", payload)
repository = self._get_repository_client() repository = self._get_repository_client()
default_branch = self._get_default_branch(repository) default_branch = self._get_default_branch(repository)
metadata = get_transformed_webhook_payload(payload, default_branch=default_branch) metadata = get_transformed_webhook_payload(
payload, default_branch=default_branch
)
prepared = self.prepare_build(metadata) prepared = self.prepare_build(metadata)
# Check if we should skip this build. # Check if we should skip this build.
@ -511,17 +523,17 @@ class BitbucketBuildTrigger(BuildTriggerHandler):
# Lookup the commit SHA for the branch. # Lookup the commit SHA for the branch.
(result, data, _) = repository.get_branch(branch_name) (result, data, _) = repository.get_branch(branch_name)
if not result: if not result:
raise TriggerStartException('Could not find branch in repository') raise TriggerStartException("Could not find branch in repository")
return data['target']['hash'] return data["target"]["hash"]
def get_tag_sha(tag_name): def get_tag_sha(tag_name):
# Lookup the commit SHA for the tag. # Lookup the commit SHA for the tag.
(result, data, _) = repository.get_tag(tag_name) (result, data, _) = repository.get_tag(tag_name)
if not result: if not result:
raise TriggerStartException('Could not find tag in repository') raise TriggerStartException("Could not find tag in repository")
return data['target']['hash'] return data["target"]["hash"]
def lookup_author(email_address): def lookup_author(email_address):
(result, data, _) = bitbucket_client.accounts().get_profile(email_address) (result, data, _) = bitbucket_client.accounts().get_profile(email_address)
@ -529,17 +541,19 @@ class BitbucketBuildTrigger(BuildTriggerHandler):
# Find the branch or tag to build. # Find the branch or tag to build.
default_branch = self._get_default_branch(repository) default_branch = self._get_default_branch(repository)
(commit_sha, ref) = determine_build_ref(run_parameters, get_branch_sha, get_tag_sha, (commit_sha, ref) = determine_build_ref(
default_branch) run_parameters, get_branch_sha, get_tag_sha, default_branch
)
# Lookup the commit SHA in BitBucket. # Lookup the commit SHA in BitBucket.
(result, commit_info, _) = repository.changesets().get(commit_sha) (result, commit_info, _) = repository.changesets().get(commit_sha)
if not result: if not result:
raise TriggerStartException('Could not lookup commit SHA') raise TriggerStartException("Could not lookup commit SHA")
# Return a prepared build for the commit. # Return a prepared build for the commit.
repository_name = '%s/%s' % (repository.namespace, repository.repository_name) repository_name = "%s/%s" % (repository.namespace, repository.repository_name)
metadata = get_transformed_commit_info(commit_info, ref, default_branch, metadata = get_transformed_commit_info(
repository_name, lookup_author) commit_info, ref, default_branch, repository_name, lookup_author
)
return self.prepare_build(metadata, is_manual=True) return self.prepare_build(metadata, is_manual=True)
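For context on the call above: determine_build_ref comes from buildtrigger.triggerutil and its body is not part of this diff. A minimal sketch of the contract it appears to satisfy, assuming hypothetical "branch_name"/"tag_name" keys in run_parameters, is:

# Sketch only; the real triggerutil helper may read run_parameters differently.
def determine_build_ref(run_parameters, get_branch_sha, get_tag_sha, default_branch):
    run_parameters = run_parameters or {}
    tag_name = run_parameters.get("tag_name")  # hypothetical key name
    if tag_name is not None:
        return (get_tag_sha(tag_name), "refs/tags/%s" % tag_name)
    branch_name = run_parameters.get("branch_name") or default_branch  # hypothetical key name
    return (get_branch_sha(branch_name), "refs/heads/%s" % branch_name)

Both callers in this commit pass lookup callbacks that raise TriggerStartException when the ref cannot be resolved, so the helper itself never needs to handle missing refs.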
@ -2,22 +2,33 @@ import logging
import json import json
from jsonschema import validate, ValidationError from jsonschema import validate, ValidationError
from buildtrigger.triggerutil import (RepositoryReadException, TriggerActivationException, from buildtrigger.triggerutil import (
TriggerStartException, ValidationRequestException, RepositoryReadException,
TriggerActivationException,
TriggerStartException,
ValidationRequestException,
InvalidPayloadException, InvalidPayloadException,
SkipRequestException, raise_if_skipped_build, SkipRequestException,
find_matching_branches) raise_if_skipped_build,
find_matching_branches,
)
from buildtrigger.basehandler import BuildTriggerHandler from buildtrigger.basehandler import BuildTriggerHandler
from buildtrigger.bitbuckethandler import (BITBUCKET_WEBHOOK_PAYLOAD_SCHEMA as bb_schema, from buildtrigger.bitbuckethandler import (
get_transformed_webhook_payload as bb_payload) BITBUCKET_WEBHOOK_PAYLOAD_SCHEMA as bb_schema,
get_transformed_webhook_payload as bb_payload,
)
from buildtrigger.githubhandler import (GITHUB_WEBHOOK_PAYLOAD_SCHEMA as gh_schema, from buildtrigger.githubhandler import (
get_transformed_webhook_payload as gh_payload) GITHUB_WEBHOOK_PAYLOAD_SCHEMA as gh_schema,
get_transformed_webhook_payload as gh_payload,
)
from buildtrigger.gitlabhandler import (GITLAB_WEBHOOK_PAYLOAD_SCHEMA as gl_schema, from buildtrigger.gitlabhandler import (
get_transformed_webhook_payload as gl_payload) GITLAB_WEBHOOK_PAYLOAD_SCHEMA as gl_schema,
get_transformed_webhook_payload as gl_payload,
)
from util.security.ssh import generate_ssh_keypair from util.security.ssh import generate_ssh_keypair
@ -49,7 +60,7 @@ def custom_trigger_payload(metadata, git_url):
continue continue
result = handler(metadata) result = handler(metadata)
result['git_url'] = git_url result["git_url"] = git_url
return result return result
# If we have reached this point and no other schemas validated, then raise the error for the # If we have reached this point and no other schemas validated, then raise the error for the
@ -57,95 +68,92 @@ def custom_trigger_payload(metadata, git_url):
if custom_handler_validation_error is not None: if custom_handler_validation_error is not None:
raise InvalidPayloadException(custom_handler_validation_error.message) raise InvalidPayloadException(custom_handler_validation_error.message)
metadata['git_url'] = git_url metadata["git_url"] = git_url
return metadata return metadata
class CustomBuildTrigger(BuildTriggerHandler): class CustomBuildTrigger(BuildTriggerHandler):
payload_schema = { payload_schema = {
'type': 'object', "type": "object",
'properties': { "properties": {
'commit': { "commit": {
'type': 'string', "type": "string",
'description': 'first 7 characters of the SHA-1 identifier for a git commit', "description": "first 7 characters of the SHA-1 identifier for a git commit",
'pattern': '^([A-Fa-f0-9]{7,})$', "pattern": "^([A-Fa-f0-9]{7,})$",
}, },
'ref': { "ref": {
'type': 'string', "type": "string",
'description': 'git reference for a git commit', "description": "git reference for a git commit",
'pattern': '^refs\/(heads|tags|remotes)\/(.+)$', "pattern": "^refs\/(heads|tags|remotes)\/(.+)$",
}, },
'default_branch': { "default_branch": {
'type': 'string', "type": "string",
'description': 'default branch of the git repository', "description": "default branch of the git repository",
}, },
'commit_info': { "commit_info": {
'type': 'object', "type": "object",
'description': 'metadata about a git commit', "description": "metadata about a git commit",
'properties': { "properties": {
'url': { "url": {
'type': 'string', "type": "string",
'description': 'URL to view a git commit', "description": "URL to view a git commit",
}, },
'message': { "message": {"type": "string", "description": "git commit message"},
'type': 'string', "date": {
'description': 'git commit message', "type": "string",
"description": "timestamp for a git commit",
}, },
'date': { "author": {
'type': 'string', "type": "object",
'description': 'timestamp for a git commit' "description": "metadata about the author of a git commit",
"properties": {
"username": {
"type": "string",
"description": "username of the author",
}, },
'author': { "url": {
'type': 'object', "type": "string",
'description': 'metadata about the author of a git commit', "description": "URL to view the profile of the author",
'properties': {
'username': {
'type': 'string',
'description': 'username of the author',
}, },
'url': { "avatar_url": {
'type': 'string', "type": "string",
'description': 'URL to view the profile of the author', "description": "URL to view the avatar of the author",
},
'avatar_url': {
'type': 'string',
'description': 'URL to view the avatar of the author',
}, },
}, },
'required': ['username', 'url', 'avatar_url'], "required": ["username", "url", "avatar_url"],
}, },
'committer': { "committer": {
'type': 'object', "type": "object",
'description': 'metadata about the committer of a git commit', "description": "metadata about the committer of a git commit",
'properties': { "properties": {
'username': { "username": {
'type': 'string', "type": "string",
'description': 'username of the committer', "description": "username of the committer",
}, },
'url': { "url": {
'type': 'string', "type": "string",
'description': 'URL to view the profile of the committer', "description": "URL to view the profile of the committer",
}, },
'avatar_url': { "avatar_url": {
'type': 'string', "type": "string",
'description': 'URL to view the avatar of the committer', "description": "URL to view the avatar of the committer",
}, },
}, },
'required': ['username', 'url', 'avatar_url'], "required": ["username", "url", "avatar_url"],
}, },
}, },
'required': ['url', 'message', 'date'], "required": ["url", "message", "date"],
}, },
}, },
'required': ['commit', 'ref', 'default_branch'], "required": ["commit", "ref", "default_branch"],
} }
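To illustrate what this schema accepts: a payload carrying only the three required top-level fields validates, while a non-hex commit value fails the SHA pattern. This is a sketch using the jsonschema package imported at the top of the file; all payload values are invented, and payload_schema refers to the dict above.

from jsonschema import validate, ValidationError

ok = {"commit": "abcdef1", "ref": "refs/heads/master", "default_branch": "master"}
validate(ok, CustomBuildTrigger.payload_schema)  # passes: 7 hex chars, valid ref

try:
    bad = {"commit": "xyz", "ref": "refs/heads/master", "default_branch": "master"}
    validate(bad, CustomBuildTrigger.payload_schema)
except ValidationError:
    pass  # "xyz" does not match "^([A-Fa-f0-9]{7,})$"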
@classmethod @classmethod
def service_name(cls): def service_name(cls):
return 'custom-git' return "custom-git"
def is_active(self): def is_active(self):
return self.config.has_key('credentials') return self.config.has_key("credentials")
def _metadata_from_payload(self, payload, git_url): def _metadata_from_payload(self, payload, git_url):
# Parse the JSON payload. # Parse the JSON payload.
@ -159,11 +167,11 @@ class CustomBuildTrigger(BuildTriggerHandler):
def handle_trigger_request(self, request): def handle_trigger_request(self, request):
payload = request.data payload = request.data
if not payload: if not payload:
raise InvalidPayloadException('Missing expected payload') raise InvalidPayloadException("Missing expected payload")
logger.debug('Payload %s', payload) logger.debug("Payload %s", payload)
metadata = self._metadata_from_payload(payload, self.config['build_source']) metadata = self._metadata_from_payload(payload, self.config["build_source"])
prepared = self.prepare_build(metadata) prepared = self.prepare_build(metadata)
# Check if we should skip this build. # Check if we should skip this build.
@ -173,15 +181,12 @@ class CustomBuildTrigger(BuildTriggerHandler):
def manual_start(self, run_parameters=None): def manual_start(self, run_parameters=None):
# commit_sha is the only required parameter # commit_sha is the only required parameter
commit_sha = run_parameters.get('commit_sha') commit_sha = run_parameters.get("commit_sha")
if commit_sha is None: if commit_sha is None:
raise TriggerStartException('missing required parameter') raise TriggerStartException("missing required parameter")
config = self.config config = self.config
metadata = { metadata = {"commit": commit_sha, "git_url": config["build_source"]}
'commit': commit_sha,
'git_url': config['build_source'],
}
try: try:
return self.prepare_build(metadata, is_manual=True) return self.prepare_build(metadata, is_manual=True)
@ -191,22 +196,16 @@ class CustomBuildTrigger(BuildTriggerHandler):
def activate(self, standard_webhook_url): def activate(self, standard_webhook_url):
config = self.config config = self.config
public_key, private_key = generate_ssh_keypair() public_key, private_key = generate_ssh_keypair()
config['credentials'] = [ config["credentials"] = [
{ {"name": "SSH Public Key", "value": public_key},
'name': 'SSH Public Key', {"name": "Webhook Endpoint URL", "value": standard_webhook_url},
'value': public_key,
},
{
'name': 'Webhook Endpoint URL',
'value': standard_webhook_url,
},
] ]
self.config = config self.config = config
return config, {'private_key': private_key} return config, {"private_key": private_key}
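generate_ssh_keypair lives in util.security.ssh and is not shown in this diff. A plausible sketch of such a helper, assuming the cryptography package (an assumption, not necessarily what the real module uses), is:

# Sketch only; the real util.security.ssh.generate_ssh_keypair may differ.
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa

def generate_ssh_keypair():
    key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    public_key = key.public_key().public_bytes(
        serialization.Encoding.OpenSSH, serialization.PublicFormat.OpenSSH
    )
    private_key = key.private_bytes(
        serialization.Encoding.PEM,
        serialization.PrivateFormat.TraditionalOpenSSL,
        serialization.NoEncryption(),
    )
    return public_key, private_key

Note that activate() stores only the public half in the trigger config; the private key is handed back to the caller once via the {"private_key": private_key} extras dict and never persisted.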
def deactivate(self): def deactivate(self):
config = self.config config = self.config
config.pop('credentials', None) config.pop("credentials", None)
self.config = config self.config = config
return config return config
@ -7,18 +7,29 @@ from calendar import timegm
from functools import wraps from functools import wraps
from ssl import SSLError from ssl import SSLError
from github import (Github, UnknownObjectException, GithubException, from github import (
BadCredentialsException as GitHubBadCredentialsException) Github,
UnknownObjectException,
GithubException,
BadCredentialsException as GitHubBadCredentialsException,
)
from jsonschema import validate from jsonschema import validate
from app import app, github_trigger from app import app, github_trigger
from buildtrigger.triggerutil import (RepositoryReadException, TriggerActivationException, from buildtrigger.triggerutil import (
TriggerDeactivationException, TriggerStartException, RepositoryReadException,
EmptyRepositoryException, ValidationRequestException, TriggerActivationException,
SkipRequestException, InvalidPayloadException, TriggerDeactivationException,
determine_build_ref, raise_if_skipped_build, TriggerStartException,
find_matching_branches) EmptyRepositoryException,
ValidationRequestException,
SkipRequestException,
InvalidPayloadException,
determine_build_ref,
raise_if_skipped_build,
find_matching_branches,
)
from buildtrigger.basehandler import BuildTriggerHandler from buildtrigger.basehandler import BuildTriggerHandler
from endpoints.exception import ExternalServiceError from endpoints.exception import ExternalServiceError
from util.security.ssh import generate_ssh_keypair from util.security.ssh import generate_ssh_keypair
@ -27,70 +38,45 @@ from util.dict_wrappers import JSONPathDict, SafeDictSetter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
GITHUB_WEBHOOK_PAYLOAD_SCHEMA = { GITHUB_WEBHOOK_PAYLOAD_SCHEMA = {
'type': 'object', "type": "object",
'properties': { "properties": {
'ref': { "ref": {"type": "string"},
'type': 'string', "head_commit": {
"type": ["object", "null"],
"properties": {
"id": {"type": "string"},
"url": {"type": "string"},
"message": {"type": "string"},
"timestamp": {"type": "string"},
"author": {
"type": "object",
"properties": {
"username": {"type": "string"},
"html_url": {"type": "string"},
"avatar_url": {"type": "string"},
}, },
'head_commit': {
'type': ['object', 'null'],
'properties': {
'id': {
'type': 'string',
}, },
'url': { "committer": {
'type': 'string', "type": "object",
}, "properties": {
'message': { "username": {"type": "string"},
'type': 'string', "html_url": {"type": "string"},
}, "avatar_url": {"type": "string"},
'timestamp': {
'type': 'string',
},
'author': {
'type': 'object',
'properties': {
'username': {
'type': 'string'
},
'html_url': {
'type': 'string'
},
'avatar_url': {
'type': 'string'
}, },
}, },
}, },
'committer': { "required": ["id", "url", "message", "timestamp"],
'type': 'object',
'properties': {
'username': {
'type': 'string'
}, },
'html_url': { "repository": {
'type': 'string' "type": "object",
}, "properties": {"ssh_url": {"type": "string"}},
'avatar_url': { "required": ["ssh_url"],
'type': 'string'
}, },
}, },
}, "required": ["ref", "head_commit", "repository"],
},
'required': ['id', 'url', 'message', 'timestamp'],
},
'repository': {
'type': 'object',
'properties': {
'ssh_url': {
'type': 'string',
},
},
'required': ['ssh_url'],
},
},
'required': ['ref', 'head_commit', 'repository'],
} }
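As a sanity check of the reshaped schema: a stripped-down push event with the required keys validates, and a branch-deletion event, where GitHub sends "head_commit": null, also validates because head_commit allows the null type (it is skipped later in get_transformed_webhook_payload). All values below are invented.

from jsonschema import validate

push_event = {
    "ref": "refs/heads/master",
    "head_commit": {
        "id": "da39a3ee5e6b4b0d3255bfef95601890afd80709",
        "url": "https://github.com/acme/app/commit/da39a3e",
        "message": "Fix build",
        "timestamp": "2019-11-20T08:58:47+01:00",
    },
    "repository": {"ssh_url": "git@github.com:acme/app.git"},
}
validate(push_event, GITHUB_WEBHOOK_PAYLOAD_SCHEMA)  # passes

validate(dict(push_event, head_commit=None), GITHUB_WEBHOOK_PAYLOAD_SCHEMA)  # also passes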
def get_transformed_webhook_payload(gh_payload, default_branch=None, lookup_user=None): def get_transformed_webhook_payload(gh_payload, default_branch=None, lookup_user=None):
""" Returns the GitHub webhook JSON payload transformed into our own payload """ Returns the GitHub webhook JSON payload transformed into our own payload
format. If the gh_payload is not valid, returns None. format. If the gh_payload is not valid, returns None.
@ -102,43 +88,54 @@ def get_transformed_webhook_payload(gh_payload, default_branch=None, lookup_user
payload = JSONPathDict(gh_payload) payload = JSONPathDict(gh_payload)
if payload['head_commit'] is None: if payload["head_commit"] is None:
raise SkipRequestException raise SkipRequestException
config = SafeDictSetter() config = SafeDictSetter()
config['commit'] = payload['head_commit.id'] config["commit"] = payload["head_commit.id"]
config['ref'] = payload['ref'] config["ref"] = payload["ref"]
config['default_branch'] = payload['repository.default_branch'] or default_branch config["default_branch"] = payload["repository.default_branch"] or default_branch
config['git_url'] = payload['repository.ssh_url'] config["git_url"] = payload["repository.ssh_url"]
config['commit_info.url'] = payload['head_commit.url'] config["commit_info.url"] = payload["head_commit.url"]
config['commit_info.message'] = payload['head_commit.message'] config["commit_info.message"] = payload["head_commit.message"]
config['commit_info.date'] = payload['head_commit.timestamp'] config["commit_info.date"] = payload["head_commit.timestamp"]
config['commit_info.author.username'] = payload['head_commit.author.username'] config["commit_info.author.username"] = payload["head_commit.author.username"]
config['commit_info.author.url'] = payload.get('head_commit.author.html_url') config["commit_info.author.url"] = payload.get("head_commit.author.html_url")
config['commit_info.author.avatar_url'] = payload.get('head_commit.author.avatar_url') config["commit_info.author.avatar_url"] = payload.get(
"head_commit.author.avatar_url"
)
config['commit_info.committer.username'] = payload.get('head_commit.committer.username') config["commit_info.committer.username"] = payload.get(
config['commit_info.committer.url'] = payload.get('head_commit.committer.html_url') "head_commit.committer.username"
config['commit_info.committer.avatar_url'] = payload.get('head_commit.committer.avatar_url') )
config["commit_info.committer.url"] = payload.get("head_commit.committer.html_url")
config["commit_info.committer.avatar_url"] = payload.get(
"head_commit.committer.avatar_url"
)
# Note: GitHub doesn't always return the extra information for users, so we do the lookup # Note: GitHub doesn't always return the extra information for users, so we do the lookup
# manually if possible. # manually if possible.
if (lookup_user and not payload.get('head_commit.author.html_url') and if (
payload.get('head_commit.author.username')): lookup_user
author_info = lookup_user(payload['head_commit.author.username']) and not payload.get("head_commit.author.html_url")
and payload.get("head_commit.author.username")
):
author_info = lookup_user(payload["head_commit.author.username"])
if author_info: if author_info:
config['commit_info.author.url'] = author_info['html_url'] config["commit_info.author.url"] = author_info["html_url"]
config['commit_info.author.avatar_url'] = author_info['avatar_url'] config["commit_info.author.avatar_url"] = author_info["avatar_url"]
if (lookup_user and if (
payload.get('head_commit.committer.username') and lookup_user
not payload.get('head_commit.committer.html_url')): and payload.get("head_commit.committer.username")
committer_info = lookup_user(payload['head_commit.committer.username']) and not payload.get("head_commit.committer.html_url")
):
committer_info = lookup_user(payload["head_commit.committer.username"])
if committer_info: if committer_info:
config['commit_info.committer.url'] = committer_info['html_url'] config["commit_info.committer.url"] = committer_info["html_url"]
config['commit_info.committer.avatar_url'] = committer_info['avatar_url'] config["commit_info.committer.avatar_url"] = committer_info["avatar_url"]
return config.dict_value() return config.dict_value()
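JSONPathDict and SafeDictSetter come from util.dict_wrappers and are only consumed here, not defined. The behavior the function above relies on is dotted-path reads that return None (or a default) for missing segments; a minimal sketch of that read side, under that assumption, is:

# Sketch of the dotted-path lookup this code assumes; the real
# util.dict_wrappers.JSONPathDict may differ in detail.
class JSONPathDict(object):
    def __init__(self, data):
        self._data = data or {}

    def get(self, path, default=None):
        value = self._data
        for part in path.split("."):
            if not isinstance(value, dict) or part not in value:
                return default
            value = value[part]
        return value

    __getitem__ = get

JSONPathDict({"head_commit": {"id": "abc"}})["head_commit.id"]  # -> "abc"

SafeDictSetter would be the mirror image: dotted-path writes that create intermediate dicts, with dict_value() returning the assembled result.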
@ -149,9 +146,10 @@ def _catch_ssl_errors(func):
try: try:
return func(*args, **kwargs) return func(*args, **kwargs)
except SSLError as se: except SSLError as se:
msg = 'Request to the GitHub API failed: %s' % se.message msg = "Request to the GitHub API failed: %s" % se.message
logger.exception(msg) logger.exception(msg)
raise ExternalServiceError(msg) raise ExternalServiceError(msg)
return wrapper return wrapper
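Only the body of _catch_ssl_errors falls inside this hunk. Matching the surrounding context lines (the functools.wraps import at the top of the file and the trailing return wrapper), the whole decorator plausibly reads:

# Reconstruction from the visible context lines; the elided lines may differ.
from functools import wraps

def _catch_ssl_errors(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except SSLError as se:
            # se.message is a Python 2 idiom, kept as in the source.
            msg = "Request to the GitHub API failed: %s" % se.message
            logger.exception(msg)
            raise ExternalServiceError(msg)

    return wrapper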
@ -159,79 +157,79 @@ class GithubBuildTrigger(BuildTriggerHandler):
""" """
BuildTrigger for GitHub that uses the archive API and buildpacks. BuildTrigger for GitHub that uses the archive API and buildpacks.
""" """
def _get_client(self): def _get_client(self):
""" Returns an authenticated client for talking to the GitHub API. """ """ Returns an authenticated client for talking to the GitHub API. """
return Github(self.auth_token, return Github(
self.auth_token,
base_url=github_trigger.api_endpoint(), base_url=github_trigger.api_endpoint(),
client_id=github_trigger.client_id(), client_id=github_trigger.client_id(),
client_secret=github_trigger.client_secret(), client_secret=github_trigger.client_secret(),
timeout=5) timeout=5,
)
@classmethod @classmethod
def service_name(cls): def service_name(cls):
return 'github' return "github"
def is_active(self): def is_active(self):
return 'hook_id' in self.config return "hook_id" in self.config
def get_repository_url(self): def get_repository_url(self):
source = self.config['build_source'] source = self.config["build_source"]
return github_trigger.get_public_url(source) return github_trigger.get_public_url(source)
@staticmethod @staticmethod
def _get_error_message(ghe, default_msg): def _get_error_message(ghe, default_msg):
if ghe.data.get('errors') and ghe.data['errors'][0].get('message'): if ghe.data.get("errors") and ghe.data["errors"][0].get("message"):
return ghe.data['errors'][0]['message'] return ghe.data["errors"][0]["message"]
return default_msg return default_msg
@_catch_ssl_errors @_catch_ssl_errors
def activate(self, standard_webhook_url): def activate(self, standard_webhook_url):
config = self.config config = self.config
new_build_source = config['build_source'] new_build_source = config["build_source"]
gh_client = self._get_client() gh_client = self._get_client()
# Find the GitHub repository. # Find the GitHub repository.
try: try:
gh_repo = gh_client.get_repo(new_build_source) gh_repo = gh_client.get_repo(new_build_source)
except UnknownObjectException: except UnknownObjectException:
msg = 'Unable to find GitHub repository for source: %s' % new_build_source msg = "Unable to find GitHub repository for source: %s" % new_build_source
raise TriggerActivationException(msg) raise TriggerActivationException(msg)
# Add a deploy key to the GitHub repository. # Add a deploy key to the GitHub repository.
public_key, private_key = generate_ssh_keypair() public_key, private_key = generate_ssh_keypair()
config['credentials'] = [ config["credentials"] = [{"name": "SSH Public Key", "value": public_key}]
{
'name': 'SSH Public Key',
'value': public_key,
},
]
try: try:
deploy_key = gh_repo.create_key('%s Builder' % app.config['REGISTRY_TITLE'], deploy_key = gh_repo.create_key(
public_key) "%s Builder" % app.config["REGISTRY_TITLE"], public_key
config['deploy_key_id'] = deploy_key.id )
config["deploy_key_id"] = deploy_key.id
except GithubException as ghe: except GithubException as ghe:
default_msg = 'Unable to add deploy key to repository: %s' % new_build_source default_msg = (
"Unable to add deploy key to repository: %s" % new_build_source
)
msg = GithubBuildTrigger._get_error_message(ghe, default_msg) msg = GithubBuildTrigger._get_error_message(ghe, default_msg)
raise TriggerActivationException(msg) raise TriggerActivationException(msg)
# Add the webhook to the GitHub repository. # Add the webhook to the GitHub repository.
webhook_config = { webhook_config = {"url": standard_webhook_url, "content_type": "json"}
'url': standard_webhook_url,
'content_type': 'json',
}
try: try:
hook = gh_repo.create_hook('web', webhook_config) hook = gh_repo.create_hook("web", webhook_config)
config['hook_id'] = hook.id config["hook_id"] = hook.id
config['master_branch'] = gh_repo.default_branch config["master_branch"] = gh_repo.default_branch
except GithubException as ghe: except GithubException as ghe:
default_msg = 'Unable to create webhook on repository: %s' % new_build_source default_msg = (
"Unable to create webhook on repository: %s" % new_build_source
)
msg = GithubBuildTrigger._get_error_message(ghe, default_msg) msg = GithubBuildTrigger._get_error_message(ghe, default_msg)
raise TriggerActivationException(msg) raise TriggerActivationException(msg)
return config, {'private_key': private_key} return config, {"private_key": private_key}
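Stripped of the config bookkeeping, the activation above is three PyGithub calls. A minimal usage sketch, with invented repository name, webhook URL, and key material:

from github import Github

gh = Github("<oauth-token>", timeout=5)
repo = gh.get_repo("acme/app")  # hypothetical repository
public_key = "ssh-rsa AAAAB3NzaC1yc2E builder@example"  # invented key material
deploy_key = repo.create_key("Quay Builder", public_key)
hook = repo.create_hook("web", {"url": "https://quay.example/webhook", "content_type": "json"})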
@_catch_ssl_errors @_catch_ssl_errors
def deactivate(self): def deactivate(self):
@ -240,38 +238,41 @@ class GithubBuildTrigger(BuildTriggerHandler):
# Find the GitHub repository. # Find the GitHub repository.
try: try:
repo = gh_client.get_repo(config['build_source']) repo = gh_client.get_repo(config["build_source"])
except UnknownObjectException: except UnknownObjectException:
msg = 'Unable to find GitHub repository for source: %s' % config['build_source'] msg = (
"Unable to find GitHub repository for source: %s"
% config["build_source"]
)
raise TriggerDeactivationException(msg) raise TriggerDeactivationException(msg)
except GitHubBadCredentialsException: except GitHubBadCredentialsException:
msg = 'Unable to access repository to disable trigger' msg = "Unable to access repository to disable trigger"
raise TriggerDeactivationException(msg) raise TriggerDeactivationException(msg)
# If the trigger uses a deploy key, remove it. # If the trigger uses a deploy key, remove it.
try: try:
if config['deploy_key_id']: if config["deploy_key_id"]:
deploy_key = repo.get_key(config['deploy_key_id']) deploy_key = repo.get_key(config["deploy_key_id"])
deploy_key.delete() deploy_key.delete()
except KeyError: except KeyError:
# There was no config['deploy_key_id'], thus this is an old trigger without a deploy key. # There was no config['deploy_key_id'], thus this is an old trigger without a deploy key.
pass pass
except GithubException as ghe: except GithubException as ghe:
default_msg = 'Unable to remove deploy key: %s' % config['deploy_key_id'] default_msg = "Unable to remove deploy key: %s" % config["deploy_key_id"]
msg = GithubBuildTrigger._get_error_message(ghe, default_msg) msg = GithubBuildTrigger._get_error_message(ghe, default_msg)
raise TriggerDeactivationException(msg) raise TriggerDeactivationException(msg)
# Remove the webhook. # Remove the webhook.
if 'hook_id' in config: if "hook_id" in config:
try: try:
hook = repo.get_hook(config['hook_id']) hook = repo.get_hook(config["hook_id"])
hook.delete() hook.delete()
except GithubException as ghe: except GithubException as ghe:
default_msg = 'Unable to remove hook: %s' % config['hook_id'] default_msg = "Unable to remove hook: %s" % config["hook_id"]
msg = GithubBuildTrigger._get_error_message(ghe, default_msg) msg = GithubBuildTrigger._get_error_message(ghe, default_msg)
raise TriggerDeactivationException(msg) raise TriggerDeactivationException(msg)
config.pop('hook_id', None) config.pop("hook_id", None)
self.config = config self.config = config
return config return config
@ -283,12 +284,12 @@ class GithubBuildTrigger(BuildTriggerHandler):
# Build the full set of namespaces for the user, starting with their own. # Build the full set of namespaces for the user, starting with their own.
namespaces = {} namespaces = {}
namespaces[usr.login] = { namespaces[usr.login] = {
'personal': True, "personal": True,
'id': usr.login, "id": usr.login,
'title': usr.name or usr.login, "title": usr.name or usr.login,
'avatar_url': usr.avatar_url, "avatar_url": usr.avatar_url,
'url': usr.html_url, "url": usr.html_url,
'score': usr.plan.private_repos if usr.plan else 0, "score": usr.plan.private_repos if usr.plan else 0,
} }
for org in usr.get_orgs(): for org in usr.get_orgs():
@ -299,12 +300,12 @@ class GithubBuildTrigger(BuildTriggerHandler):
# loop, which was massively slowing down the load time for users when setting # loop, which was massively slowing down the load time for users when setting
# up triggers. # up triggers.
namespaces[organization] = { namespaces[organization] = {
'personal': False, "personal": False,
'id': organization, "id": organization,
'title': organization, "title": organization,
'avatar_url': org.avatar_url, "avatar_url": org.avatar_url,
'url': '', "url": "",
'score': 0, "score": 0,
} }
return BuildTriggerHandler.build_namespaces_response(namespaces) return BuildTriggerHandler.build_namespaces_response(namespaces)
@ -313,19 +314,23 @@ class GithubBuildTrigger(BuildTriggerHandler):
def list_build_sources_for_namespace(self, namespace): def list_build_sources_for_namespace(self, namespace):
def repo_view(repo): def repo_view(repo):
return { return {
'name': repo.name, "name": repo.name,
'full_name': repo.full_name, "full_name": repo.full_name,
'description': repo.description or '', "description": repo.description or "",
'last_updated': timegm(repo.pushed_at.utctimetuple()) if repo.pushed_at else 0, "last_updated": timegm(repo.pushed_at.utctimetuple())
'url': repo.html_url, if repo.pushed_at
'has_admin_permissions': repo.permissions.admin, else 0,
'private': repo.private, "url": repo.html_url,
"has_admin_permissions": repo.permissions.admin,
"private": repo.private,
} }
gh_client = self._get_client() gh_client = self._get_client()
usr = gh_client.get_user() usr = gh_client.get_user()
if namespace == usr.login: if namespace == usr.login:
repos = [repo_view(repo) for repo in usr.get_repos(type='owner', sort='updated')] repos = [
repo_view(repo) for repo in usr.get_repos(type="owner", sort="updated")
]
return BuildTriggerHandler.build_sources_response(repos) return BuildTriggerHandler.build_sources_response(repos)
try: try:
@ -335,31 +340,38 @@ class GithubBuildTrigger(BuildTriggerHandler):
except GithubException: except GithubException:
return [] return []
repos = [repo_view(repo) for repo in org.get_repos(type='member')] repos = [repo_view(repo) for repo in org.get_repos(type="member")]
return BuildTriggerHandler.build_sources_response(repos) return BuildTriggerHandler.build_sources_response(repos)
@_catch_ssl_errors @_catch_ssl_errors
def list_build_subdirs(self): def list_build_subdirs(self):
config = self.config config = self.config
gh_client = self._get_client() gh_client = self._get_client()
source = config['build_source'] source = config["build_source"]
try: try:
repo = gh_client.get_repo(source) repo = gh_client.get_repo(source)
# Find the first matching branch. # Find the first matching branch.
repo_branches = self.list_field_values('branch_name') or [] repo_branches = self.list_field_values("branch_name") or []
branches = find_matching_branches(config, repo_branches) branches = find_matching_branches(config, repo_branches)
branches = branches or [repo.default_branch or 'master'] branches = branches or [repo.default_branch or "master"]
default_commit = repo.get_branch(branches[0]).commit default_commit = repo.get_branch(branches[0]).commit
commit_tree = repo.get_git_tree(default_commit.sha, recursive=True) commit_tree = repo.get_git_tree(default_commit.sha, recursive=True)
return [elem.path for elem in commit_tree.tree return [
if (elem.type == u'blob' and self.filename_is_dockerfile(os.path.basename(elem.path)))] elem.path
for elem in commit_tree.tree
if (
elem.type == u"blob"
and self.filename_is_dockerfile(os.path.basename(elem.path))
)
]
except GithubException as ghe: except GithubException as ghe:
message = ghe.data.get('message', 'Unable to list contents of repository: %s' % source) message = ghe.data.get(
if message == 'Branch not found': "message", "Unable to list contents of repository: %s" % source
)
if message == "Branch not found":
raise EmptyRepositoryException() raise EmptyRepositoryException()
raise RepositoryReadException(message) raise RepositoryReadException(message)
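filename_is_dockerfile is defined on the base handler and not part of this diff. A hypothetical predicate with the behavior the tree walk above needs, matching the default Dockerfile name plus a common variant pattern, might look like:

# Hypothetical; the real BuildTriggerHandler.filename_is_dockerfile may accept
# a different set of variants.
def filename_is_dockerfile(file_name):
    return file_name == "Dockerfile" or file_name.endswith(".Dockerfile")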
@ -368,12 +380,14 @@ class GithubBuildTrigger(BuildTriggerHandler):
def load_dockerfile_contents(self): def load_dockerfile_contents(self):
config = self.config config = self.config
gh_client = self._get_client() gh_client = self._get_client()
source = config['build_source'] source = config["build_source"]
try: try:
repo = gh_client.get_repo(source) repo = gh_client.get_repo(source)
except GithubException as ghe: except GithubException as ghe:
message = ghe.data.get('message', 'Unable to list contents of repository: %s' % source) message = ghe.data.get(
"message", "Unable to list contents of repository: %s" % source
)
raise RepositoryReadException(message) raise RepositoryReadException(message)
path = self.get_dockerfile_path() path = self.get_dockerfile_path()
@ -394,25 +408,26 @@ class GithubBuildTrigger(BuildTriggerHandler):
return None return None
content = file_info.content content = file_info.content
if file_info.encoding == 'base64': if file_info.encoding == "base64":
content = base64.b64decode(content) content = base64.b64decode(content)
return content return content
@_catch_ssl_errors @_catch_ssl_errors
def list_field_values(self, field_name, limit=None): def list_field_values(self, field_name, limit=None):
if field_name == 'refs': if field_name == "refs":
branches = self.list_field_values('branch_name') branches = self.list_field_values("branch_name")
tags = self.list_field_values('tag_name') tags = self.list_field_values("tag_name")
return ([{'kind': 'branch', 'name': b} for b in branches] + return [{"kind": "branch", "name": b} for b in branches] + [
[{'kind': 'tag', 'name': tag} for tag in tags]) {"kind": "tag", "name": tag} for tag in tags
]
config = self.config config = self.config
source = config.get('build_source') source = config.get("build_source")
if source is None: if source is None:
return [] return []
if field_name == 'tag_name': if field_name == "tag_name":
try: try:
gh_client = self._get_client() gh_client = self._get_client()
repo = gh_client.get_repo(source) repo = gh_client.get_repo(source)
@ -424,11 +439,13 @@ class GithubBuildTrigger(BuildTriggerHandler):
except GitHubBadCredentialsException: except GitHubBadCredentialsException:
return [] return []
except GithubException: except GithubException:
logger.exception("Got GitHub Exception when trying to list tags for trigger %s", logger.exception(
self.trigger.id) "Got GitHub Exception when trying to list tags for trigger %s",
self.trigger.id,
)
return [] return []
if field_name == 'branch_name': if field_name == "branch_name":
try: try:
gh_client = self._get_client() gh_client = self._get_client()
repo = gh_client.get_repo(source) repo = gh_client.get_repo(source)
@ -447,11 +464,13 @@ class GithubBuildTrigger(BuildTriggerHandler):
return branches return branches
except GitHubBadCredentialsException: except GitHubBadCredentialsException:
return ['master'] return ["master"]
except GithubException: except GithubException:
logger.exception("Got GitHub Exception when trying to list branches for trigger %s", logger.exception(
self.trigger.id) "Got GitHub Exception when trying to list branches for trigger %s",
return ['master'] self.trigger.id,
)
return ["master"]
return None return None
@ -460,48 +479,50 @@ class GithubBuildTrigger(BuildTriggerHandler):
try: try:
commit = repo.get_commit(commit_sha) commit = repo.get_commit(commit_sha)
except GithubException: except GithubException:
logger.exception('Could not load commit information from GitHub') logger.exception("Could not load commit information from GitHub")
return None return None
commit_info = { commit_info = {
'url': commit.html_url, "url": commit.html_url,
'message': commit.commit.message, "message": commit.commit.message,
'date': commit.last_modified "date": commit.last_modified,
} }
if commit.author: if commit.author:
commit_info['author'] = { commit_info["author"] = {
'username': commit.author.login, "username": commit.author.login,
'avatar_url': commit.author.avatar_url, "avatar_url": commit.author.avatar_url,
'url': commit.author.html_url "url": commit.author.html_url,
} }
if commit.committer: if commit.committer:
commit_info['committer'] = { commit_info["committer"] = {
'username': commit.committer.login, "username": commit.committer.login,
'avatar_url': commit.committer.avatar_url, "avatar_url": commit.committer.avatar_url,
'url': commit.committer.html_url "url": commit.committer.html_url,
} }
return { return {
'commit': commit_sha, "commit": commit_sha,
'ref': ref, "ref": ref,
'default_branch': repo.default_branch, "default_branch": repo.default_branch,
'git_url': repo.ssh_url, "git_url": repo.ssh_url,
'commit_info': commit_info "commit_info": commit_info,
} }
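For an invented commit, the helper above therefore yields a flat metadata dict of the shape the build system consumes; the author/committer sub-dicts are attached only when GitHub returns them:

# Invented example of the metadata produced by _build_metadata_for_commit.
{
    "commit": "da39a3ee5e6b4b0d3255bfef95601890afd80709",
    "ref": "refs/heads/master",
    "default_branch": "master",
    "git_url": "git@github.com:acme/app.git",
    "commit_info": {
        "url": "https://github.com/acme/app/commit/da39a3e",
        "message": "Fix build",
        "date": "Wed, 20 Nov 2019 07:58:47 GMT",
    },
}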
@_catch_ssl_errors @_catch_ssl_errors
def manual_start(self, run_parameters=None): def manual_start(self, run_parameters=None):
config = self.config config = self.config
source = config['build_source'] source = config["build_source"]
try: try:
gh_client = self._get_client() gh_client = self._get_client()
repo = gh_client.get_repo(source) repo = gh_client.get_repo(source)
default_branch = repo.default_branch default_branch = repo.default_branch
except GithubException as ghe: except GithubException as ghe:
msg = GithubBuildTrigger._get_error_message(ghe, 'Unable to start build trigger') msg = GithubBuildTrigger._get_error_message(
ghe, "Unable to start build trigger"
)
raise TriggerStartException(msg) raise TriggerStartException(msg)
def get_branch_sha(branch_name): def get_branch_sha(branch_name):
@ -509,18 +530,19 @@ class GithubBuildTrigger(BuildTriggerHandler):
branch = repo.get_branch(branch_name) branch = repo.get_branch(branch_name)
return branch.commit.sha return branch.commit.sha
except GithubException: except GithubException:
raise TriggerStartException('Could not find branch in repository') raise TriggerStartException("Could not find branch in repository")
def get_tag_sha(tag_name): def get_tag_sha(tag_name):
tags = {tag.name: tag for tag in repo.get_tags()} tags = {tag.name: tag for tag in repo.get_tags()}
if not tag_name in tags: if not tag_name in tags:
raise TriggerStartException('Could not find tag in repository') raise TriggerStartException("Could not find tag in repository")
return tags[tag_name].commit.sha return tags[tag_name].commit.sha
# Find the branch or tag to build. # Find the branch or tag to build.
(commit_sha, ref) = determine_build_ref(run_parameters, get_branch_sha, get_tag_sha, (commit_sha, ref) = determine_build_ref(
default_branch) run_parameters, get_branch_sha, get_tag_sha, default_branch
)
metadata = GithubBuildTrigger._build_metadata_for_commit(commit_sha, ref, repo) metadata = GithubBuildTrigger._build_metadata_for_commit(commit_sha, ref, repo)
return self.prepare_build(metadata, is_manual=True) return self.prepare_build(metadata, is_manual=True)
@ -530,10 +552,7 @@ class GithubBuildTrigger(BuildTriggerHandler):
try: try:
gh_client = self._get_client() gh_client = self._get_client()
user = gh_client.get_user(username) user = gh_client.get_user(username)
return { return {"html_url": user.html_url, "avatar_url": user.avatar_url}
'html_url': user.html_url,
'avatar_url': user.avatar_url
}
except GithubException: except GithubException:
return None return None
@ -542,44 +561,51 @@ class GithubBuildTrigger(BuildTriggerHandler):
# Check the payload to see if we should skip it based on the lack of a head_commit. # Check the payload to see if we should skip it based on the lack of a head_commit.
payload = request.get_json() payload = request.get_json()
if payload is None: if payload is None:
raise InvalidPayloadException('Missing payload') raise InvalidPayloadException("Missing payload")
# This is for GitHub's probing/testing. # This is for GitHub's probing/testing.
if 'zen' in payload: if "zen" in payload:
raise SkipRequestException() raise SkipRequestException()
# Lookup the default branch for the repository. # Lookup the default branch for the repository.
if 'repository' not in payload: if "repository" not in payload:
raise InvalidPayloadException("Missing 'repository' on request") raise InvalidPayloadException("Missing 'repository' on request")
if 'owner' not in payload['repository']: if "owner" not in payload["repository"]:
raise InvalidPayloadException("Missing 'owner' on repository") raise InvalidPayloadException("Missing 'owner' on repository")
if 'name' not in payload['repository']['owner']: if "name" not in payload["repository"]["owner"]:
raise InvalidPayloadException("Missing owner 'name' on repository") raise InvalidPayloadException("Missing owner 'name' on repository")
if 'name' not in payload['repository']: if "name" not in payload["repository"]:
raise InvalidPayloadException("Missing 'name' on repository") raise InvalidPayloadException("Missing 'name' on repository")
default_branch = None default_branch = None
lookup_user = None lookup_user = None
try: try:
repo_full_name = '%s/%s' % (payload['repository']['owner']['name'], repo_full_name = "%s/%s" % (
payload['repository']['name']) payload["repository"]["owner"]["name"],
payload["repository"]["name"],
)
gh_client = self._get_client() gh_client = self._get_client()
repo = gh_client.get_repo(repo_full_name) repo = gh_client.get_repo(repo_full_name)
default_branch = repo.default_branch default_branch = repo.default_branch
lookup_user = self.lookup_user lookup_user = self.lookup_user
except GitHubBadCredentialsException: except GitHubBadCredentialsException:
logger.exception('Got GitHub Credentials Exception; Cannot lookup default branch') logger.exception(
"Got GitHub Credentials Exception; Cannot lookup default branch"
)
except GithubException: except GithubException:
logger.exception("Got GitHub Exception when trying to start trigger %s", self.trigger.id) logger.exception(
"Got GitHub Exception when trying to start trigger %s", self.trigger.id
)
raise SkipRequestException() raise SkipRequestException()
logger.debug('GitHub trigger payload %s', payload) logger.debug("GitHub trigger payload %s", payload)
metadata = get_transformed_webhook_payload(payload, default_branch=default_branch, metadata = get_transformed_webhook_payload(
lookup_user=lookup_user) payload, default_branch=default_branch, lookup_user=lookup_user
)
prepared = self.prepare_build(metadata) prepared = self.prepare_build(metadata)
# Check if we should skip this build. # Check if we should skip this build.
@ -11,12 +11,18 @@ import requests
from jsonschema import validate from jsonschema import validate
from app import app, gitlab_trigger from app import app, gitlab_trigger
from buildtrigger.triggerutil import (RepositoryReadException, TriggerActivationException, from buildtrigger.triggerutil import (
TriggerDeactivationException, TriggerStartException, RepositoryReadException,
SkipRequestException, InvalidPayloadException, TriggerActivationException,
TriggerDeactivationException,
TriggerStartException,
SkipRequestException,
InvalidPayloadException,
TriggerAuthException, TriggerAuthException,
determine_build_ref, raise_if_skipped_build, determine_build_ref,
find_matching_branches) raise_if_skipped_build,
find_matching_branches,
)
from buildtrigger.basehandler import BuildTriggerHandler from buildtrigger.basehandler import BuildTriggerHandler
from endpoints.exception import ExternalServiceError from endpoints.exception import ExternalServiceError
from util.security.ssh import generate_ssh_keypair from util.security.ssh import generate_ssh_keypair
@ -25,55 +31,35 @@ from util.dict_wrappers import JSONPathDict, SafeDictSetter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
GITLAB_WEBHOOK_PAYLOAD_SCHEMA = { GITLAB_WEBHOOK_PAYLOAD_SCHEMA = {
'type': 'object', "type": "object",
'properties': { "properties": {
'ref': { "ref": {"type": "string"},
'type': 'string', "checkout_sha": {"type": ["string", "null"]},
"repository": {
"type": "object",
"properties": {"git_ssh_url": {"type": "string"}},
"required": ["git_ssh_url"],
}, },
'checkout_sha': { "commits": {
'type': ['string', 'null'], "type": "array",
}, "items": {
'repository': { "type": "object",
'type': 'object', "properties": {
'properties': { "id": {"type": "string"},
'git_ssh_url': { "url": {"type": ["string", "null"]},
'type': 'string', "message": {"type": "string"},
"timestamp": {"type": "string"},
"author": {
"type": "object",
"properties": {"email": {"type": "string"}},
"required": ["email"],
}, },
}, },
'required': ['git_ssh_url'], "required": ["id", "message", "timestamp"],
},
'commits': {
'type': 'array',
'items': {
'type': 'object',
'properties': {
'id': {
'type': 'string',
},
'url': {
'type': ['string', 'null'],
},
'message': {
'type': 'string',
},
'timestamp': {
'type': 'string',
},
'author': {
'type': 'object',
'properties': {
'email': {
'type': 'string',
},
},
'required': ['email'],
},
},
'required': ['id', 'message', 'timestamp'],
}, },
}, },
}, },
'required': ['ref', 'checkout_sha', 'repository'], "required": ["ref", "checkout_sha", "repository"],
} }
_ACCESS_LEVEL_MAP = { _ACCESS_LEVEL_MAP = {
@ -93,13 +79,14 @@ def _catch_timeouts_and_errors(func):
try: try:
return func(*args, **kwargs) return func(*args, **kwargs)
except requests.exceptions.Timeout: except requests.exceptions.Timeout:
msg = 'Request to the GitLab API timed out' msg = "Request to the GitLab API timed out"
logger.exception(msg) logger.exception(msg)
raise ExternalServiceError(msg) raise ExternalServiceError(msg)
except gitlab.GitlabError: except gitlab.GitlabError:
msg = 'GitLab API error. Please contact support.' msg = "GitLab API error. Please contact support."
logger.exception(msg) logger.exception(msg)
raise ExternalServiceError(msg) raise ExternalServiceError(msg)
return wrapper return wrapper
@ -124,8 +111,9 @@ def _paginated_iterator(func, exc, **kwargs):
page = page + 1 page = page + 1
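Only the final line of _paginated_iterator is visible here. A reconstruction consistent with that tail, assuming python-gitlab list methods that accept page/per_page keywords, is:

# Sketch consistent with the visible tail; the elided lines may differ.
def _paginated_iterator(func, exc, **kwargs):
    page = 1
    while True:
        result = func(page=page, per_page=20, **kwargs)
        if result is None:
            raise exc("Could not load paginated API response")

        if not result:
            return

        for item in result:
            yield item

        page = page + 1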
def get_transformed_webhook_payload(gl_payload, default_branch=None, lookup_user=None, def get_transformed_webhook_payload(
lookup_commit=None): gl_payload, default_branch=None, lookup_user=None, lookup_commit=None
):
""" Returns the Gitlab webhook JSON payload transformed into our own payload """ Returns the Gitlab webhook JSON payload transformed into our own payload
format. If the gl_payload is not valid, returns None. format. If the gl_payload is not valid, returns None.
""" """
@ -136,58 +124,60 @@ def get_transformed_webhook_payload(gl_payload, default_branch=None, lookup_user
payload = JSONPathDict(gl_payload) payload = JSONPathDict(gl_payload)
if payload['object_kind'] != 'push' and payload['object_kind'] != 'tag_push': if payload["object_kind"] != "push" and payload["object_kind"] != "tag_push":
# Unknown kind of webhook. # Unknown kind of webhook.
raise SkipRequestException raise SkipRequestException
# Check for empty commits. The commits list will be empty if the branch is deleted. # Check for empty commits. The commits list will be empty if the branch is deleted.
commits = payload['commits'] commits = payload["commits"]
if payload['object_kind'] == 'push' and not commits: if payload["object_kind"] == "push" and not commits:
raise SkipRequestException raise SkipRequestException
# Check for missing commit information. # Check for missing commit information.
commit_sha = payload['checkout_sha'] or payload['after'] commit_sha = payload["checkout_sha"] or payload["after"]
if commit_sha is None or commit_sha == '0000000000000000000000000000000000000000': if commit_sha is None or commit_sha == "0000000000000000000000000000000000000000":
raise SkipRequestException raise SkipRequestException
config = SafeDictSetter() config = SafeDictSetter()
config['commit'] = commit_sha config["commit"] = commit_sha
config['ref'] = payload['ref'] config["ref"] = payload["ref"]
config['default_branch'] = default_branch config["default_branch"] = default_branch
config['git_url'] = payload['repository.git_ssh_url'] config["git_url"] = payload["repository.git_ssh_url"]
found_commit = JSONPathDict({}) found_commit = JSONPathDict({})
if payload['object_kind'] == 'push' or payload['object_kind'] == 'tag_push': if payload["object_kind"] == "push" or payload["object_kind"] == "tag_push":
# Find the commit associated with the checkout_sha. Gitlab doesn't (necessarily) send this in # Find the commit associated with the checkout_sha. Gitlab doesn't (necessarily) send this in
# any order, so we cannot simply index into the commits list. # any order, so we cannot simply index into the commits list.
found_commit = None found_commit = None
if commits is not None: if commits is not None:
for commit in commits: for commit in commits:
if commit['id'] == payload['checkout_sha']: if commit["id"] == payload["checkout_sha"]:
found_commit = JSONPathDict(commit) found_commit = JSONPathDict(commit)
break break
if found_commit is None and lookup_commit: if found_commit is None and lookup_commit:
checkout_sha = payload['checkout_sha'] or payload['after'] checkout_sha = payload["checkout_sha"] or payload["after"]
found_commit_info = lookup_commit(payload['project_id'], checkout_sha) found_commit_info = lookup_commit(payload["project_id"], checkout_sha)
found_commit = JSONPathDict(dict(found_commit_info) if found_commit_info else {}) found_commit = JSONPathDict(
dict(found_commit_info) if found_commit_info else {}
)
if found_commit is None: if found_commit is None:
raise SkipRequestException raise SkipRequestException
config['commit_info.url'] = found_commit['url'] config["commit_info.url"] = found_commit["url"]
config['commit_info.message'] = found_commit['message'] config["commit_info.message"] = found_commit["message"]
config['commit_info.date'] = found_commit['timestamp'] config["commit_info.date"] = found_commit["timestamp"]
# Note: Gitlab does not send full user information with the payload, so we have to # Note: Gitlab does not send full user information with the payload, so we have to
# (optionally) look it up. # (optionally) look it up.
author_email = found_commit['author.email'] or found_commit['author_email'] author_email = found_commit["author.email"] or found_commit["author_email"]
if lookup_user and author_email: if lookup_user and author_email:
author_info = lookup_user(author_email) author_info = lookup_user(author_email)
if author_info: if author_info:
config['commit_info.author.username'] = author_info['username'] config["commit_info.author.username"] = author_info["username"]
config['commit_info.author.url'] = author_info['html_url'] config["commit_info.author.url"] = author_info["html_url"]
config['commit_info.author.avatar_url'] = author_info['avatar_url'] config["commit_info.author.avatar_url"] = author_info["avatar_url"]
return config.dict_value() return config.dict_value()
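End to end, the function above maps a GitLab push event onto the internal build metadata shape. With invented values (and the schema and helpers defined in this file):

gl_push = {
    "object_kind": "push",
    "ref": "refs/heads/master",
    "checkout_sha": "da39a3ee5e6b4b0d3255bfef95601890afd80709",
    "repository": {"git_ssh_url": "git@gitlab.com:acme/app.git"},
    "commits": [
        {
            "id": "da39a3ee5e6b4b0d3255bfef95601890afd80709",
            "url": "https://gitlab.com/acme/app/commit/da39a3e",
            "message": "Fix build",
            "timestamp": "2019-11-20T08:58:47+01:00",
            "author": {"email": "dev@example.com"},
        }
    ],
}
metadata = get_transformed_webhook_payload(gl_push, default_branch="master")
# metadata["commit"] is the checkout_sha; a push whose checkout_sha is the
# all-zero SHA (a deleted branch) raises SkipRequestException instead.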
@ -196,15 +186,20 @@ class GitLabBuildTrigger(BuildTriggerHandler):
""" """
BuildTrigger for GitLab. BuildTrigger for GitLab.
""" """
@classmethod @classmethod
def service_name(cls): def service_name(cls):
return 'gitlab' return "gitlab"
def _get_authorized_client(self): def _get_authorized_client(self):
auth_token = self.auth_token or 'invalid' auth_token = self.auth_token or "invalid"
api_version = self.config.get('API_VERSION', '4') api_version = self.config.get("API_VERSION", "4")
client = gitlab.Gitlab(gitlab_trigger.api_endpoint(), oauth_token=auth_token, timeout=20, client = gitlab.Gitlab(
api_version=api_version) gitlab_trigger.api_endpoint(),
oauth_token=auth_token,
timeout=20,
api_version=api_version,
)
try: try:
client.auth() client.auth()
except gitlab.GitlabGetError as ex: except gitlab.GitlabGetError as ex:
@ -213,55 +208,51 @@ class GitLabBuildTrigger(BuildTriggerHandler):
return client return client
def is_active(self): def is_active(self):
return 'hook_id' in self.config return "hook_id" in self.config
@_catch_timeouts_and_errors @_catch_timeouts_and_errors
def activate(self, standard_webhook_url): def activate(self, standard_webhook_url):
config = self.config config = self.config
new_build_source = config['build_source'] new_build_source = config["build_source"]
gl_client = self._get_authorized_client() gl_client = self._get_authorized_client()
# Find the GitLab repository. # Find the GitLab repository.
gl_project = gl_client.projects.get(new_build_source) gl_project = gl_client.projects.get(new_build_source)
if not gl_project: if not gl_project:
msg = 'Unable to find GitLab repository for source: %s' % new_build_source msg = "Unable to find GitLab repository for source: %s" % new_build_source
raise TriggerActivationException(msg) raise TriggerActivationException(msg)
# Add a deploy key to the repository. # Add a deploy key to the repository.
public_key, private_key = generate_ssh_keypair() public_key, private_key = generate_ssh_keypair()
config['credentials'] = [ config["credentials"] = [{"name": "SSH Public Key", "value": public_key}]
{
'name': 'SSH Public Key',
'value': public_key,
},
]
key = gl_project.keys.create({ key = gl_project.keys.create(
'title': '%s Builder' % app.config['REGISTRY_TITLE'], {"title": "%s Builder" % app.config["REGISTRY_TITLE"], "key": public_key}
'key': public_key, )
})
if not key: if not key:
msg = 'Unable to add deploy key to repository: %s' % new_build_source msg = "Unable to add deploy key to repository: %s" % new_build_source
raise TriggerActivationException(msg) raise TriggerActivationException(msg)
config['key_id'] = key.get_id() config["key_id"] = key.get_id()
# Add the webhook to the GitLab repository. # Add the webhook to the GitLab repository.
hook = gl_project.hooks.create({ hook = gl_project.hooks.create(
'url': standard_webhook_url, {
'push': True, "url": standard_webhook_url,
'tag_push': True, "push": True,
'push_events': True, "tag_push": True,
'tag_push_events': True, "push_events": True,
}) "tag_push_events": True,
}
)
if not hook: if not hook:
msg = 'Unable to create webhook on repository: %s' % new_build_source msg = "Unable to create webhook on repository: %s" % new_build_source
raise TriggerActivationException(msg) raise TriggerActivationException(msg)
config['hook_id'] = hook.get_id() config["hook_id"] = hook.get_id()
self.config = config self.config = config
return config, {'private_key': private_key} return config, {"private_key": private_key}
def deactivate(self): def deactivate(self):
config = self.config config = self.config
@ -269,10 +260,10 @@ class GitLabBuildTrigger(BuildTriggerHandler):
# Find the GitLab repository. # Find the GitLab repository.
try: try:
gl_project = gl_client.projects.get(config['build_source']) gl_project = gl_client.projects.get(config["build_source"])
if not gl_project: if not gl_project:
config.pop('key_id', None) config.pop("key_id", None)
config.pop('hook_id', None) config.pop("hook_id", None)
self.config = config self.config = config
return config return config
except gitlab.GitlabGetError as ex: except gitlab.GitlabGetError as ex:
@ -281,21 +272,21 @@ class GitLabBuildTrigger(BuildTriggerHandler):
# Remove the webhook. # Remove the webhook.
try: try:
gl_project.hooks.delete(config['hook_id']) gl_project.hooks.delete(config["hook_id"])
except gitlab.GitlabDeleteError as ex: except gitlab.GitlabDeleteError as ex:
if ex.response_code != 404: if ex.response_code != 404:
raise raise
config.pop('hook_id', None) config.pop("hook_id", None)
# Remove the key # Remove the key
try: try:
gl_project.keys.delete(config['key_id']) gl_project.keys.delete(config["key_id"])
except gitlab.GitlabDeleteError as ex: except gitlab.GitlabDeleteError as ex:
if ex.response_code != 404: if ex.response_code != 404:
raise raise
config.pop('key_id', None) config.pop("key_id", None)
self.config = config self.config = config
return config return config
@ -305,37 +296,41 @@ class GitLabBuildTrigger(BuildTriggerHandler):
gl_client = self._get_authorized_client() gl_client = self._get_authorized_client()
current_user = gl_client.user current_user = gl_client.user
if not current_user: if not current_user:
raise RepositoryReadException('Unable to get current user') raise RepositoryReadException("Unable to get current user")
namespaces = {} namespaces = {}
for namespace in _paginated_iterator(gl_client.namespaces.list, RepositoryReadException): for namespace in _paginated_iterator(
gl_client.namespaces.list, RepositoryReadException
):
namespace_id = namespace.get_id() namespace_id = namespace.get_id()
if namespace_id in namespaces: if namespace_id in namespaces:
namespaces[namespace_id]['score'] = namespaces[namespace_id]['score'] + 1 namespaces[namespace_id]["score"] = (
namespaces[namespace_id]["score"] + 1
)
else: else:
owner = namespace.attributes['name'] owner = namespace.attributes["name"]
namespaces[namespace_id] = { namespaces[namespace_id] = {
'personal': namespace.attributes['kind'] == 'user', "personal": namespace.attributes["kind"] == "user",
'id': str(namespace_id), "id": str(namespace_id),
'title': namespace.attributes['name'], "title": namespace.attributes["name"],
'avatar_url': namespace.attributes.get('avatar_url'), "avatar_url": namespace.attributes.get("avatar_url"),
'score': 1, "score": 1,
'url': namespace.attributes.get('web_url') or '', "url": namespace.attributes.get("web_url") or "",
} }
return BuildTriggerHandler.build_namespaces_response(namespaces) return BuildTriggerHandler.build_namespaces_response(namespaces)
def _get_namespace(self, gl_client, gl_namespace, lazy=False): def _get_namespace(self, gl_client, gl_namespace, lazy=False):
try: try:
if gl_namespace.attributes['kind'] == 'group': if gl_namespace.attributes["kind"] == "group":
return gl_client.groups.get(gl_namespace.attributes['id'], lazy=lazy) return gl_client.groups.get(gl_namespace.attributes["id"], lazy=lazy)
if gl_namespace.attributes['kind'] == 'user': if gl_namespace.attributes["kind"] == "user":
return gl_client.users.get(gl_client.user.attributes['id'], lazy=lazy) return gl_client.users.get(gl_client.user.attributes["id"], lazy=lazy)
# Note: This doesn't seem to work for IDs retrieved via the namespaces API; the IDs are # Note: This doesn't seem to work for IDs retrieved via the namespaces API; the IDs are
# different. # different.
return gl_client.users.get(gl_namespace.attributes['id'], lazy=lazy) return gl_client.users.get(gl_namespace.attributes["id"], lazy=lazy)
except gitlab.GitlabGetError: except gitlab.GitlabGetError:
return None return None
@ -346,15 +341,17 @@ class GitLabBuildTrigger(BuildTriggerHandler):
def repo_view(repo): def repo_view(repo):
# Because *anything* can be None in GitLab API! # Because *anything* can be None in GitLab API!
permissions = repo.attributes.get('permissions') or {} permissions = repo.attributes.get("permissions") or {}
group_access = permissions.get('group_access') or {} group_access = permissions.get("group_access") or {}
project_access = permissions.get('project_access') or {} project_access = permissions.get("project_access") or {}
missing_group_access = permissions.get('group_access') is None missing_group_access = permissions.get("group_access") is None
missing_project_access = permissions.get('project_access') is None missing_project_access = permissions.get("project_access") is None
access_level = max(group_access.get('access_level') or 0, access_level = max(
project_access.get('access_level') or 0) group_access.get("access_level") or 0,
project_access.get("access_level") or 0,
)
has_admin_permission = _ACCESS_LEVEL_MAP.get(access_level, ("", False))[1] has_admin_permission = _ACCESS_LEVEL_MAP.get(access_level, ("", False))[1]
if missing_group_access or missing_project_access: if missing_group_access or missing_project_access:
@ -365,20 +362,24 @@ class GitLabBuildTrigger(BuildTriggerHandler):
has_admin_permission = True has_admin_permission = True
view = { view = {
'name': repo.attributes['path'], "name": repo.attributes["path"],
'full_name': repo.attributes['path_with_namespace'], "full_name": repo.attributes["path_with_namespace"],
'description': repo.attributes.get('description') or '', "description": repo.attributes.get("description") or "",
'url': repo.attributes.get('web_url'), "url": repo.attributes.get("web_url"),
'has_admin_permissions': has_admin_permission, "has_admin_permissions": has_admin_permission,
'private': repo.attributes.get('visibility') == 'private', "private": repo.attributes.get("visibility") == "private",
} }
if repo.attributes.get('last_activity_at'): if repo.attributes.get("last_activity_at"):
try: try:
last_modified = dateutil.parser.parse(repo.attributes['last_activity_at']) last_modified = dateutil.parser.parse(
view['last_updated'] = timegm(last_modified.utctimetuple()) repo.attributes["last_activity_at"]
)
view["last_updated"] = timegm(last_modified.utctimetuple())
except ValueError: except ValueError:
logger.exception('Gitlab gave us an invalid last_activity_at: %s', last_modified) logger.exception(
"Gitlab gave us an invalid last_activity_at: %s", last_modified
)
return view return view
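
The "or {}" chains above are deliberate: dict.get()'s default only applies when a key is missing, not when GitLab returns it with an explicit null. A minimal sketch of the difference:

    permissions = {"group_access": None, "project_access": None}

    permissions.get("group_access", {})    # -> None; default covers missing keys only
    permissions.get("group_access") or {}  # -> {}; also covers an explicit None

    access_level = max(
        (permissions.get("group_access") or {}).get("access_level") or 0,
        (permissions.get("project_access") or {}).get("access_level") or 0,
    )
    assert access_level == 0
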
@ -390,10 +391,14 @@ class GitLabBuildTrigger(BuildTriggerHandler):
return [] return []
namespace_obj = self._get_namespace(gl_client, gl_namespace, lazy=True) namespace_obj = self._get_namespace(gl_client, gl_namespace, lazy=True)
repositories = _paginated_iterator(namespace_obj.projects.list, RepositoryReadException) repositories = _paginated_iterator(
namespace_obj.projects.list, RepositoryReadException
)
try: try:
return BuildTriggerHandler.build_sources_response([repo_view(repo) for repo in repositories]) return BuildTriggerHandler.build_sources_response(
[repo_view(repo) for repo in repositories]
)
except gitlab.GitlabGetError: except gitlab.GitlabGetError:
return [] return []
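
_paginated_iterator is defined elsewhere in this module; a hedged sketch of the shape such a helper typically takes over python-gitlab list endpoints (the real implementation may differ):

    def _paginated_iterator(api_call, exception_class, page_size=20):
        # Sketch only: yield results page by page, converting API errors.
        page = 1
        while True:
            try:
                results = api_call(page=page, per_page=page_size)
            except gitlab.GitlabError as ex:
                raise exception_class(str(ex))
            if not results:
                return
            for result in results:
                yield result
            page += 1
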
@ -401,46 +406,53 @@ class GitLabBuildTrigger(BuildTriggerHandler):
def list_build_subdirs(self): def list_build_subdirs(self):
config = self.config config = self.config
gl_client = self._get_authorized_client() gl_client = self._get_authorized_client()
new_build_source = config['build_source'] new_build_source = config["build_source"]
gl_project = gl_client.projects.get(new_build_source) gl_project = gl_client.projects.get(new_build_source)
if not gl_project: if not gl_project:
msg = 'Unable to find GitLab repository for source: %s' % new_build_source msg = "Unable to find GitLab repository for source: %s" % new_build_source
raise RepositoryReadException(msg) raise RepositoryReadException(msg)
repo_branches = gl_project.branches.list() repo_branches = gl_project.branches.list()
if not repo_branches: if not repo_branches:
msg = 'Unable to find GitLab branches for source: %s' % new_build_source msg = "Unable to find GitLab branches for source: %s" % new_build_source
raise RepositoryReadException(msg) raise RepositoryReadException(msg)
branches = [branch.attributes['name'] for branch in repo_branches] branches = [branch.attributes["name"] for branch in repo_branches]
branches = find_matching_branches(config, branches) branches = find_matching_branches(config, branches)
branches = branches or [gl_project.attributes['default_branch'] or 'master'] branches = branches or [gl_project.attributes["default_branch"] or "master"]
repo_tree = gl_project.repository_tree(ref=branches[0]) repo_tree = gl_project.repository_tree(ref=branches[0])
if not repo_tree: if not repo_tree:
msg = 'Unable to find GitLab repository tree for source: %s' % new_build_source msg = (
"Unable to find GitLab repository tree for source: %s"
% new_build_source
)
raise RepositoryReadException(msg) raise RepositoryReadException(msg)
return [node['name'] for node in repo_tree if self.filename_is_dockerfile(node['name'])] return [
node["name"]
for node in repo_tree
if self.filename_is_dockerfile(node["name"])
]
@_catch_timeouts_and_errors @_catch_timeouts_and_errors
def load_dockerfile_contents(self): def load_dockerfile_contents(self):
gl_client = self._get_authorized_client() gl_client = self._get_authorized_client()
path = self.get_dockerfile_path() path = self.get_dockerfile_path()
gl_project = gl_client.projects.get(self.config['build_source']) gl_project = gl_client.projects.get(self.config["build_source"])
if not gl_project: if not gl_project:
return None return None
branches = self.list_field_values('branch_name') branches = self.list_field_values("branch_name")
branches = find_matching_branches(self.config, branches) branches = find_matching_branches(self.config, branches)
if branches == []: if branches == []:
return None return None
branch_name = branches[0] branch_name = branches[0]
if gl_project.attributes['default_branch'] in branches: if gl_project.attributes["default_branch"] in branches:
branch_name = gl_project.attributes['default_branch'] branch_name = gl_project.attributes["default_branch"]
try: try:
return gl_project.files.get(path, branch_name).decode() return gl_project.files.get(path, branch_name).decode()
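
A usage sketch of the file fetch above; the host, token, and project path are placeholders, and ProjectFile.decode() returns the base64-decoded contents:

    import gitlab

    gl = gitlab.Gitlab("https://gitlab.example.com", oauth_token="<token>")
    project = gl.projects.get("foo/bar")
    dockerfile = project.files.get("Dockerfile", "master")  # (file_path, ref)
    contents = dockerfile.decode()
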
@ -449,19 +461,20 @@ class GitLabBuildTrigger(BuildTriggerHandler):
@_catch_timeouts_and_errors @_catch_timeouts_and_errors
def list_field_values(self, field_name, limit=None): def list_field_values(self, field_name, limit=None):
if field_name == 'refs': if field_name == "refs":
branches = self.list_field_values('branch_name') branches = self.list_field_values("branch_name")
tags = self.list_field_values('tag_name') tags = self.list_field_values("tag_name")
return ([{'kind': 'branch', 'name': b} for b in branches] + return [{"kind": "branch", "name": b} for b in branches] + [
[{'kind': 'tag', 'name': t} for t in tags]) {"kind": "tag", "name": t} for t in tags
]
gl_client = self._get_authorized_client() gl_client = self._get_authorized_client()
gl_project = gl_client.projects.get(self.config['build_source']) gl_project = gl_client.projects.get(self.config["build_source"])
if not gl_project: if not gl_project:
return [] return []
if field_name == 'tag_name': if field_name == "tag_name":
tags = gl_project.tags.list() tags = gl_project.tags.list()
if not tags: if not tags:
return [] return []
@ -469,9 +482,9 @@ class GitLabBuildTrigger(BuildTriggerHandler):
if limit: if limit:
tags = tags[0:limit] tags = tags[0:limit]
return [tag.attributes['name'] for tag in tags] return [tag.attributes["name"] for tag in tags]
if field_name == 'branch_name': if field_name == "branch_name":
branches = gl_project.branches.list() branches = gl_project.branches.list()
if not branches: if not branches:
return [] return []
@ -479,12 +492,12 @@ class GitLabBuildTrigger(BuildTriggerHandler):
if limit: if limit:
branches = branches[0:limit] branches = branches[0:limit]
return [branch.attributes['name'] for branch in branches] return [branch.attributes["name"] for branch in branches]
return None return None
def get_repository_url(self): def get_repository_url(self):
return gitlab_trigger.get_public_url(self.config['build_source']) return gitlab_trigger.get_public_url(self.config["build_source"])
@_catch_timeouts_and_errors @_catch_timeouts_and_errors
def lookup_commit(self, repo_id, commit_sha): def lookup_commit(self, repo_id, commit_sha):
@ -492,7 +505,7 @@ class GitLabBuildTrigger(BuildTriggerHandler):
return None return None
gl_client = self._get_authorized_client() gl_client = self._get_authorized_client()
gl_project = gl_client.projects.get(self.config['build_source'], lazy=True) gl_project = gl_client.projects.get(self.config["build_source"], lazy=True)
commit = gl_project.commits.get(commit_sha) commit = gl_project.commits.get(commit_sha)
if not commit: if not commit:
return None return None
@ -509,9 +522,9 @@ class GitLabBuildTrigger(BuildTriggerHandler):
[user] = result [user] = result
return { return {
'username': user.attributes['username'], "username": user.attributes["username"],
'html_url': user.attributes['web_url'], "html_url": user.attributes["web_url"],
'avatar_url': user.attributes['avatar_url'] "avatar_url": user.attributes["avatar_url"],
} }
except ValueError: except ValueError:
return None return None
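
The [user] = result destructuring above is what makes the except ValueError branch work: single-element unpacking demands exactly one match, so zero results and multiple results both fall through to None.

    [user] = ["knownuser"]   # ok: exactly one element
    try:
        [user] = []          # ValueError: not enough values to unpack
    except ValueError:
        user = None
    try:
        [user] = ["a", "b"]  # ValueError: too many values to unpack
    except ValueError:
        user = None
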
@ -523,37 +536,39 @@ class GitLabBuildTrigger(BuildTriggerHandler):
return None return None
metadata = { metadata = {
'commit': commit.attributes['id'], "commit": commit.attributes["id"],
'ref': ref, "ref": ref,
'default_branch': repo.attributes['default_branch'], "default_branch": repo.attributes["default_branch"],
'git_url': repo.attributes['ssh_url_to_repo'], "git_url": repo.attributes["ssh_url_to_repo"],
'commit_info': { "commit_info": {
'url': os.path.join(repo.attributes['web_url'], 'commit', commit.attributes['id']), "url": os.path.join(
'message': commit.attributes['message'], repo.attributes["web_url"], "commit", commit.attributes["id"]
'date': commit.attributes['committed_date'], ),
"message": commit.attributes["message"],
"date": commit.attributes["committed_date"],
}, },
} }
committer = None committer = None
if 'committer_email' in commit.attributes: if "committer_email" in commit.attributes:
committer = self.lookup_user(commit.attributes['committer_email']) committer = self.lookup_user(commit.attributes["committer_email"])
author = None author = None
if 'author_email' in commit.attributes: if "author_email" in commit.attributes:
author = self.lookup_user(commit.attributes['author_email']) author = self.lookup_user(commit.attributes["author_email"])
if committer is not None: if committer is not None:
metadata['commit_info']['committer'] = { metadata["commit_info"]["committer"] = {
'username': committer['username'], "username": committer["username"],
'avatar_url': committer['avatar_url'], "avatar_url": committer["avatar_url"],
'url': committer.get('http_url', ''), "url": committer.get("http_url", ""),
} }
if author is not None: if author is not None:
metadata['commit_info']['author'] = { metadata["commit_info"]["author"] = {
'username': author['username'], "username": author["username"],
'avatar_url': author['avatar_url'], "avatar_url": author["avatar_url"],
'url': author.get('http_url', ''), "url": author.get("http_url", ""),
} }
return metadata return metadata
@ -561,29 +576,33 @@ class GitLabBuildTrigger(BuildTriggerHandler):
@_catch_timeouts_and_errors @_catch_timeouts_and_errors
def manual_start(self, run_parameters=None): def manual_start(self, run_parameters=None):
gl_client = self._get_authorized_client() gl_client = self._get_authorized_client()
gl_project = gl_client.projects.get(self.config['build_source']) gl_project = gl_client.projects.get(self.config["build_source"])
if not gl_project: if not gl_project:
raise TriggerStartException('Could not find repository') raise TriggerStartException("Could not find repository")
def get_tag_sha(tag_name): def get_tag_sha(tag_name):
try: try:
tag = gl_project.tags.get(tag_name) tag = gl_project.tags.get(tag_name)
except gitlab.GitlabGetError: except gitlab.GitlabGetError:
raise TriggerStartException('Could not find tag in repository') raise TriggerStartException("Could not find tag in repository")
return tag.attributes['commit']['id'] return tag.attributes["commit"]["id"]
def get_branch_sha(branch_name): def get_branch_sha(branch_name):
try: try:
branch = gl_project.branches.get(branch_name) branch = gl_project.branches.get(branch_name)
except gitlab.GitlabGetError: except gitlab.GitlabGetError:
raise TriggerStartException('Could not find branch in repository') raise TriggerStartException("Could not find branch in repository")
return branch.attributes['commit']['id'] return branch.attributes["commit"]["id"]
# Find the branch or tag to build. # Find the branch or tag to build.
(commit_sha, ref) = determine_build_ref(run_parameters, get_branch_sha, get_tag_sha, (commit_sha, ref) = determine_build_ref(
gl_project.attributes['default_branch']) run_parameters,
get_branch_sha,
get_tag_sha,
gl_project.attributes["default_branch"],
)
metadata = self.get_metadata_for_commit(commit_sha, ref, gl_project) metadata = self.get_metadata_for_commit(commit_sha, ref, gl_project)
return self.prepare_build(metadata, is_manual=True) return self.prepare_build(metadata, is_manual=True)
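
determine_build_ref comes from buildtrigger.triggerutil; a hedged sketch of the contract implied by the call above, not the actual helper:

    def determine_build_ref(run_parameters, get_branch_sha, get_tag_sha, default_branch):
        # Sketch only: resolve run parameters to a (commit_sha, ref) pair.
        ref_info = (run_parameters or {}).get("refs") or {}
        kind = ref_info.get("kind", "branch")
        name = ref_info.get("name") or default_branch
        if kind == "tag":
            return (get_tag_sha(name), "refs/tags/%s" % name)
        return (get_branch_sha(name), "refs/heads/%s" % name)
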
@ -594,13 +613,16 @@ class GitLabBuildTrigger(BuildTriggerHandler):
if not payload: if not payload:
raise InvalidPayloadException() raise InvalidPayloadException()
logger.debug('GitLab trigger payload %s', payload) logger.debug("GitLab trigger payload %s", payload)
# Lookup the default branch. # Lookup the default branch.
gl_client = self._get_authorized_client() gl_client = self._get_authorized_client()
gl_project = gl_client.projects.get(self.config['build_source']) gl_project = gl_client.projects.get(self.config["build_source"])
if not gl_project: if not gl_project:
logger.debug('Skipping GitLab build; project %s not found', self.config['build_source']) logger.debug(
"Skipping GitLab build; project %s not found",
self.config["build_source"],
)
raise InvalidPayloadException() raise InvalidPayloadException()
def lookup_commit(repo_id, commit_sha): def lookup_commit(repo_id, commit_sha):
@ -610,10 +632,13 @@ class GitLabBuildTrigger(BuildTriggerHandler):
return dict(commit.attributes) return dict(commit.attributes)
default_branch = gl_project.attributes['default_branch'] default_branch = gl_project.attributes["default_branch"]
metadata = get_transformed_webhook_payload(payload, default_branch=default_branch, metadata = get_transformed_webhook_payload(
payload,
default_branch=default_branch,
lookup_user=self.lookup_user, lookup_user=self.lookup_user,
lookup_commit=lookup_commit) lookup_commit=lookup_commit,
)
prepared = self.prepare_build(metadata) prepared = self.prepare_build(metadata)
# Check if we should skip this build. # Check if we should skip this build.


@ -4,107 +4,113 @@ from mock import Mock
from buildtrigger.bitbuckethandler import BitbucketBuildTrigger from buildtrigger.bitbuckethandler import BitbucketBuildTrigger
from util.morecollections import AttrDict from util.morecollections import AttrDict
def get_bitbucket_trigger(dockerfile_path=''):
trigger_obj = AttrDict(dict(auth_token='foobar', id='sometrigger')) def get_bitbucket_trigger(dockerfile_path=""):
trigger = BitbucketBuildTrigger(trigger_obj, { trigger_obj = AttrDict(dict(auth_token="foobar", id="sometrigger"))
'build_source': 'foo/bar', trigger = BitbucketBuildTrigger(
'dockerfile_path': dockerfile_path, trigger_obj,
'nickname': 'knownuser', {
'account_id': 'foo', "build_source": "foo/bar",
}) "dockerfile_path": dockerfile_path,
"nickname": "knownuser",
"account_id": "foo",
},
)
trigger._get_client = get_mock_bitbucket trigger._get_client = get_mock_bitbucket
return trigger return trigger
def get_repo_path_contents(path, revision): def get_repo_path_contents(path, revision):
data = { data = {"files": [{"path": "Dockerfile"}]}
'files': [{'path': 'Dockerfile'}],
}
return (True, data, None) return (True, data, None)
def get_raw_path_contents(path, revision):
if path == 'Dockerfile':
return (True, 'hello world', None)
if path == 'somesubdir/Dockerfile': def get_raw_path_contents(path, revision):
return (True, 'hi universe', None) if path == "Dockerfile":
return (True, "hello world", None)
if path == "somesubdir/Dockerfile":
return (True, "hi universe", None)
return (False, None, None) return (False, None, None)
def get_branches_and_tags(): def get_branches_and_tags():
data = { data = {
'branches': [{'name': 'master'}, {'name': 'otherbranch'}], "branches": [{"name": "master"}, {"name": "otherbranch"}],
'tags': [{'name': 'sometag'}, {'name': 'someothertag'}], "tags": [{"name": "sometag"}, {"name": "someothertag"}],
} }
return (True, data, None) return (True, data, None)
def get_branches(): def get_branches():
return (True, {'master': {}, 'otherbranch': {}}, None) return (True, {"master": {}, "otherbranch": {}}, None)
def get_tags(): def get_tags():
return (True, {'sometag': {}, 'someothertag': {}}, None) return (True, {"sometag": {}, "someothertag": {}}, None)
def get_branch(branch_name): def get_branch(branch_name):
if branch_name != 'master': if branch_name != "master":
return (False, None, None) return (False, None, None)
data = { data = {"target": {"hash": "aaaaaaa"}}
'target': {
'hash': 'aaaaaaa',
},
}
return (True, data, None) return (True, data, None)
def get_tag(tag_name): def get_tag(tag_name):
if tag_name != 'sometag': if tag_name != "sometag":
return (False, None, None) return (False, None, None)
data = { data = {"target": {"hash": "aaaaaaa"}}
'target': {
'hash': 'aaaaaaa',
},
}
return (True, data, None) return (True, data, None)
def get_changeset_mock(commit_sha): def get_changeset_mock(commit_sha):
if commit_sha != 'aaaaaaa': if commit_sha != "aaaaaaa":
return (False, None, 'Not found') return (False, None, "Not found")
data = { data = {
'node': 'aaaaaaa', "node": "aaaaaaa",
'message': 'some message', "message": "some message",
'timestamp': 'now', "timestamp": "now",
'raw_author': 'foo@bar.com', "raw_author": "foo@bar.com",
} }
return (True, data, None) return (True, data, None)
def get_changesets(): def get_changesets():
changesets_mock = Mock() changesets_mock = Mock()
changesets_mock.get = Mock(side_effect=get_changeset_mock) changesets_mock.get = Mock(side_effect=get_changeset_mock)
return changesets_mock return changesets_mock
def get_deploykeys(): def get_deploykeys():
deploykeys_mock = Mock() deploykeys_mock = Mock()
deploykeys_mock.create = Mock(return_value=(True, {'pk': 'someprivatekey'}, None)) deploykeys_mock.create = Mock(return_value=(True, {"pk": "someprivatekey"}, None))
deploykeys_mock.delete = Mock(return_value=(True, {}, None)) deploykeys_mock.delete = Mock(return_value=(True, {}, None))
return deploykeys_mock return deploykeys_mock
def get_webhooks(): def get_webhooks():
webhooks_mock = Mock() webhooks_mock = Mock()
webhooks_mock.create = Mock(return_value=(True, {'uuid': 'someuuid'}, None)) webhooks_mock.create = Mock(return_value=(True, {"uuid": "someuuid"}, None))
webhooks_mock.delete = Mock(return_value=(True, {}, None)) webhooks_mock.delete = Mock(return_value=(True, {}, None))
return webhooks_mock return webhooks_mock
def get_repo_mock(name): def get_repo_mock(name):
if name != 'bar': if name != "bar":
return None return None
repo_mock = Mock() repo_mock = Mock()
repo_mock.get_main_branch = Mock(return_value=(True, {'name': 'master'}, None)) repo_mock.get_main_branch = Mock(return_value=(True, {"name": "master"}, None))
repo_mock.get_path_contents = Mock(side_effect=get_repo_path_contents) repo_mock.get_path_contents = Mock(side_effect=get_repo_path_contents)
repo_mock.get_raw_path_contents = Mock(side_effect=get_raw_path_contents) repo_mock.get_raw_path_contents = Mock(side_effect=get_raw_path_contents)
repo_mock.get_branches_and_tags = Mock(side_effect=get_branches_and_tags) repo_mock.get_branches_and_tags = Mock(side_effect=get_branches_and_tags)
@ -118,41 +124,47 @@ def get_repo_mock(name):
repo_mock.webhooks = Mock(side_effect=get_webhooks) repo_mock.webhooks = Mock(side_effect=get_webhooks)
return repo_mock return repo_mock
def get_repositories_mock(): def get_repositories_mock():
repos_mock = Mock() repos_mock = Mock()
repos_mock.get = Mock(side_effect=get_repo_mock) repos_mock.get = Mock(side_effect=get_repo_mock)
return repos_mock return repos_mock
def get_namespace_mock(namespace): def get_namespace_mock(namespace):
namespace_mock = Mock() namespace_mock = Mock()
namespace_mock.repositories = Mock(side_effect=get_repositories_mock) namespace_mock.repositories = Mock(side_effect=get_repositories_mock)
return namespace_mock return namespace_mock
def get_repo(namespace, name): def get_repo(namespace, name):
return { return {
'owner': namespace, "owner": namespace,
'logo': 'avatarurl', "logo": "avatarurl",
'slug': name, "slug": name,
'description': 'some %s repo' % (name), "description": "some %s repo" % (name),
'utc_last_updated': str(datetime.utcfromtimestamp(0)), "utc_last_updated": str(datetime.utcfromtimestamp(0)),
'read_only': namespace != 'knownuser', "read_only": namespace != "knownuser",
'is_private': name == 'somerepo', "is_private": name == "somerepo",
} }
def get_visible_repos(): def get_visible_repos():
repos = [ repos = [
get_repo('knownuser', 'somerepo'), get_repo("knownuser", "somerepo"),
get_repo('someorg', 'somerepo'), get_repo("someorg", "somerepo"),
get_repo('someorg', 'anotherrepo'), get_repo("someorg", "anotherrepo"),
] ]
return (True, repos, None) return (True, repos, None)
def get_authed_mock(token, secret): def get_authed_mock(token, secret):
authed_mock = Mock() authed_mock = Mock()
authed_mock.for_namespace = Mock(side_effect=get_namespace_mock) authed_mock.for_namespace = Mock(side_effect=get_namespace_mock)
authed_mock.get_visible_repositories = Mock(side_effect=get_visible_repos) authed_mock.get_visible_repositories = Mock(side_effect=get_visible_repos)
return authed_mock return authed_mock
def get_mock_bitbucket(): def get_mock_bitbucket():
bitbucket_mock = Mock() bitbucket_mock = Mock()
bitbucket_mock.get_authorized_client = Mock(side_effect=get_authed_mock) bitbucket_mock.get_authorized_client = Mock(side_effect=get_authed_mock)
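
All of these fakes lean on Mock's side_effect, which routes a call through a real function so return values track the arguments, just like the live Bitbucket client would:

    from mock import Mock

    def lookup(name):
        return {"slug": name} if name == "bar" else None

    repos = Mock()
    repos.get = Mock(side_effect=lookup)
    assert repos.get("bar") == {"slug": "bar"}
    assert repos.get("baz") is None
    repos.get.assert_called_with("baz")
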


@ -6,41 +6,45 @@ from github import GithubException
from buildtrigger.githubhandler import GithubBuildTrigger from buildtrigger.githubhandler import GithubBuildTrigger
from util.morecollections import AttrDict from util.morecollections import AttrDict
def get_github_trigger(dockerfile_path=''):
trigger_obj = AttrDict(dict(auth_token='foobar', id='sometrigger')) def get_github_trigger(dockerfile_path=""):
trigger = GithubBuildTrigger(trigger_obj, {'build_source': 'foo', 'dockerfile_path': dockerfile_path}) trigger_obj = AttrDict(dict(auth_token="foobar", id="sometrigger"))
trigger = GithubBuildTrigger(
trigger_obj, {"build_source": "foo", "dockerfile_path": dockerfile_path}
)
trigger._get_client = get_mock_github trigger._get_client = get_mock_github
return trigger return trigger
def get_mock_github(): def get_mock_github():
def get_commit_mock(commit_sha): def get_commit_mock(commit_sha):
if commit_sha == 'aaaaaaa': if commit_sha == "aaaaaaa":
commit_mock = Mock() commit_mock = Mock()
commit_mock.sha = commit_sha commit_mock.sha = commit_sha
commit_mock.html_url = 'http://url/to/commit' commit_mock.html_url = "http://url/to/commit"
commit_mock.last_modified = 'now' commit_mock.last_modified = "now"
commit_mock.commit = Mock() commit_mock.commit = Mock()
commit_mock.commit.message = 'some cool message' commit_mock.commit.message = "some cool message"
commit_mock.committer = Mock() commit_mock.committer = Mock()
commit_mock.committer.login = 'someuser' commit_mock.committer.login = "someuser"
commit_mock.committer.avatar_url = 'avatarurl' commit_mock.committer.avatar_url = "avatarurl"
commit_mock.committer.html_url = 'htmlurl' commit_mock.committer.html_url = "htmlurl"
commit_mock.author = Mock() commit_mock.author = Mock()
commit_mock.author.login = 'someuser' commit_mock.author.login = "someuser"
commit_mock.author.avatar_url = 'avatarurl' commit_mock.author.avatar_url = "avatarurl"
commit_mock.author.html_url = 'htmlurl' commit_mock.author.html_url = "htmlurl"
return commit_mock return commit_mock
raise GithubException(None, None) raise GithubException(None, None)
def get_branch_mock(branch_name): def get_branch_mock(branch_name):
if branch_name == 'master': if branch_name == "master":
branch_mock = Mock() branch_mock = Mock()
branch_mock.commit = Mock() branch_mock.commit = Mock()
branch_mock.commit.sha = 'aaaaaaa' branch_mock.commit.sha = "aaaaaaa"
return branch_mock return branch_mock
raise GithubException(None, None) raise GithubException(None, None)
@ -50,39 +54,42 @@ def get_mock_github():
repo_mock.owner = Mock() repo_mock.owner = Mock()
repo_mock.owner.login = namespace repo_mock.owner.login = namespace
repo_mock.full_name = '%s/%s' % (namespace, name) repo_mock.full_name = "%s/%s" % (namespace, name)
repo_mock.name = name repo_mock.name = name
repo_mock.description = 'some %s repo' % (name) repo_mock.description = "some %s repo" % (name)
if name != 'anotherrepo': if name != "anotherrepo":
repo_mock.pushed_at = datetime.utcfromtimestamp(0) repo_mock.pushed_at = datetime.utcfromtimestamp(0)
else: else:
repo_mock.pushed_at = None repo_mock.pushed_at = None
repo_mock.html_url = 'https://bitbucket.org/%s/%s' % (namespace, name) repo_mock.html_url = "https://bitbucket.org/%s/%s" % (namespace, name)
repo_mock.private = name == 'somerepo' repo_mock.private = name == "somerepo"
repo_mock.permissions = Mock() repo_mock.permissions = Mock()
repo_mock.permissions.admin = namespace == 'knownuser' repo_mock.permissions.admin = namespace == "knownuser"
return repo_mock return repo_mock
def get_user_repos_mock(type='all', sort='created'): def get_user_repos_mock(type="all", sort="created"):
return [get_repo_mock('knownuser', 'somerepo')] return [get_repo_mock("knownuser", "somerepo")]
def get_org_repos_mock(type='all'): def get_org_repos_mock(type="all"):
return [get_repo_mock('someorg', 'somerepo'), get_repo_mock('someorg', 'anotherrepo')] return [
get_repo_mock("someorg", "somerepo"),
get_repo_mock("someorg", "anotherrepo"),
]
def get_orgs_mock(): def get_orgs_mock():
return [get_org_mock('someorg')] return [get_org_mock("someorg")]
def get_user_mock(username='knownuser'): def get_user_mock(username="knownuser"):
if username == 'knownuser': if username == "knownuser":
user_mock = Mock() user_mock = Mock()
user_mock.name = username user_mock.name = username
user_mock.plan = Mock() user_mock.plan = Mock()
user_mock.plan.private_repos = 1 user_mock.plan.private_repos = 1
user_mock.login = username user_mock.login = username
user_mock.html_url = 'https://bitbucket.org/%s' % (username) user_mock.html_url = "https://bitbucket.org/%s" % (username)
user_mock.avatar_url = 'avatarurl' user_mock.avatar_url = "avatarurl"
user_mock.get_repos = Mock(side_effect=get_user_repos_mock) user_mock.get_repos = Mock(side_effect=get_user_repos_mock)
user_mock.get_orgs = Mock(side_effect=get_orgs_mock) user_mock.get_orgs = Mock(side_effect=get_orgs_mock)
return user_mock return user_mock
@ -90,12 +97,12 @@ def get_mock_github():
raise GithubException(None, None) raise GithubException(None, None)
def get_org_mock(namespace): def get_org_mock(namespace):
if namespace == 'someorg': if namespace == "someorg":
org_mock = Mock() org_mock = Mock()
org_mock.get_repos = Mock(side_effect=get_org_repos_mock) org_mock.get_repos = Mock(side_effect=get_org_repos_mock)
org_mock.login = namespace org_mock.login = namespace
org_mock.html_url = 'https://bitbucket.org/%s' % (namespace) org_mock.html_url = "https://bitbucket.org/%s" % (namespace)
org_mock.avatar_url = 'avatarurl' org_mock.avatar_url = "avatarurl"
org_mock.name = namespace org_mock.name = namespace
org_mock.plan = Mock() org_mock.plan = Mock()
org_mock.plan.private_repos = 2 org_mock.plan.private_repos = 2
@ -105,64 +112,62 @@ def get_mock_github():
def get_tags_mock(): def get_tags_mock():
sometag = Mock() sometag = Mock()
sometag.name = 'sometag' sometag.name = "sometag"
sometag.commit = get_commit_mock('aaaaaaa') sometag.commit = get_commit_mock("aaaaaaa")
someothertag = Mock() someothertag = Mock()
someothertag.name = 'someothertag' someothertag.name = "someothertag"
someothertag.commit = get_commit_mock('aaaaaaa') someothertag.commit = get_commit_mock("aaaaaaa")
return [sometag, someothertag] return [sometag, someothertag]
def get_branches_mock(): def get_branches_mock():
master = Mock() master = Mock()
master.name = 'master' master.name = "master"
master.commit = get_commit_mock('aaaaaaa') master.commit = get_commit_mock("aaaaaaa")
otherbranch = Mock() otherbranch = Mock()
otherbranch.name = 'otherbranch' otherbranch.name = "otherbranch"
otherbranch.commit = get_commit_mock('aaaaaaa') otherbranch.commit = get_commit_mock("aaaaaaa")
return [master, otherbranch] return [master, otherbranch]
def get_contents_mock(filepath): def get_contents_mock(filepath):
if filepath == 'Dockerfile': if filepath == "Dockerfile":
m = Mock() m = Mock()
m.content = 'hello world' m.content = "hello world"
return m return m
if filepath == 'somesubdir/Dockerfile': if filepath == "somesubdir/Dockerfile":
m = Mock() m = Mock()
m.content = 'hi universe' m.content = "hi universe"
return m return m
raise GithubException(None, None) raise GithubException(None, None)
def get_git_tree_mock(commit_sha, recursive=False): def get_git_tree_mock(commit_sha, recursive=False):
first_file = Mock() first_file = Mock()
first_file.type = 'blob' first_file.type = "blob"
first_file.path = 'Dockerfile' first_file.path = "Dockerfile"
second_file = Mock() second_file = Mock()
second_file.type = 'other' second_file.type = "other"
second_file.path = '/some/Dockerfile' second_file.path = "/some/Dockerfile"
third_file = Mock() third_file = Mock()
third_file.type = 'blob' third_file.type = "blob"
third_file.path = 'somesubdir/Dockerfile' third_file.path = "somesubdir/Dockerfile"
t = Mock() t = Mock()
if commit_sha == 'aaaaaaa': if commit_sha == "aaaaaaa":
t.tree = [ t.tree = [first_file, second_file, third_file]
first_file, second_file, third_file,
]
else: else:
t.tree = [] t.tree = []
return t return t
repo_mock = Mock() repo_mock = Mock()
repo_mock.default_branch = 'master' repo_mock.default_branch = "master"
repo_mock.ssh_url = 'ssh_url' repo_mock.ssh_url = "ssh_url"
repo_mock.get_branch = Mock(side_effect=get_branch_mock) repo_mock.get_branch = Mock(side_effect=get_branch_mock)
repo_mock.get_tags = Mock(side_effect=get_tags_mock) repo_mock.get_tags = Mock(side_effect=get_tags_mock)
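
A hedged sketch of how the fake tree above would be narrowed to Dockerfiles; the real filtering lives in GithubBuildTrigger, and calling the nested helper directly is purely illustrative:

    tree = get_git_tree_mock("aaaaaaa")  # hypothetical direct call
    dockerfiles = [
        entry.path
        for entry in tree.tree
        if entry.type == "blob" and entry.path.endswith("Dockerfile")
    ]
    # -> ["Dockerfile", "somesubdir/Dockerfile"]; "/some/Dockerfile" drops out
    # because its type is "other".
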


@ -11,31 +11,28 @@ from buildtrigger.gitlabhandler import GitLabBuildTrigger
from util.morecollections import AttrDict from util.morecollections import AttrDict
@urlmatch(netloc=r'fakegitlab') @urlmatch(netloc=r"fakegitlab")
def catchall_handler(url, request): def catchall_handler(url, request):
return {'status_code': 404} return {"status_code": 404}
@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/users$') @urlmatch(netloc=r"fakegitlab", path=r"/api/v4/users$")
def users_handler(url, request): def users_handler(url, request):
if not request.headers.get('Authorization') == 'Bearer foobar': if not request.headers.get("Authorization") == "Bearer foobar":
return {'status_code': 401} return {"status_code": 401}
if url.query.find('knownuser') < 0: if url.query.find("knownuser") < 0:
return { return {
'status_code': 200, "status_code": 200,
'headers': { "headers": {"Content-Type": "application/json"},
'Content-Type': 'application/json', "content": json.dumps([]),
},
'content': json.dumps([]),
} }
return { return {
'status_code': 200, "status_code": 200,
'headers': { "headers": {"Content-Type": "application/json"},
'Content-Type': 'application/json', "content": json.dumps(
}, [
'content': json.dumps([
{ {
"id": 1, "id": 1,
"username": "knownuser", "username": "knownuser",
@ -43,42 +40,42 @@ def users_handler(url, request):
"state": "active", "state": "active",
"avatar_url": "avatarurl", "avatar_url": "avatarurl",
"web_url": "https://bitbucket.org/knownuser", "web_url": "https://bitbucket.org/knownuser",
}, }
]), ]
),
} }
@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/user$') @urlmatch(netloc=r"fakegitlab", path=r"/api/v4/user$")
def user_handler(_, request): def user_handler(_, request):
if not request.headers.get('Authorization') == 'Bearer foobar': if not request.headers.get("Authorization") == "Bearer foobar":
return {'status_code': 401} return {"status_code": 401}
return { return {
'status_code': 200, "status_code": 200,
'headers': { "headers": {"Content-Type": "application/json"},
'Content-Type': 'application/json', "content": json.dumps(
}, {
'content': json.dumps({
"id": 1, "id": 1,
"username": "john_smith", "username": "john_smith",
"email": "john@example.com", "email": "john@example.com",
"name": "John Smith", "name": "John Smith",
"state": "active", "state": "active",
}), }
),
} }
@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/projects/foo%2Fbar$') @urlmatch(netloc=r"fakegitlab", path=r"/api/v4/projects/foo%2Fbar$")
def project_handler(_, request): def project_handler(_, request):
if not request.headers.get('Authorization') == 'Bearer foobar': if not request.headers.get("Authorization") == "Bearer foobar":
return {'status_code': 401} return {"status_code": 401}
return { return {
'status_code': 200, "status_code": 200,
'headers': { "headers": {"Content-Type": "application/json"},
'Content-Type': 'application/json', "content": json.dumps(
}, {
'content': json.dumps({
"id": 4, "id": 4,
"description": None, "description": None,
"default_branch": "master", "default_branch": "master",
@ -86,97 +83,89 @@ def project_handler(_, request):
"path_with_namespace": "someorg/somerepo", "path_with_namespace": "someorg/somerepo",
"ssh_url_to_repo": "git@example.com:someorg/somerepo.git", "ssh_url_to_repo": "git@example.com:someorg/somerepo.git",
"web_url": "http://example.com/someorg/somerepo", "web_url": "http://example.com/someorg/somerepo",
}), }
),
} }
@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/projects/4/repository/tree$') @urlmatch(netloc=r"fakegitlab", path=r"/api/v4/projects/4/repository/tree$")
def project_tree_handler(_, request): def project_tree_handler(_, request):
if not request.headers.get('Authorization') == 'Bearer foobar': if not request.headers.get("Authorization") == "Bearer foobar":
return {'status_code': 401} return {"status_code": 401}
return { return {
'status_code': 200, "status_code": 200,
'headers': { "headers": {"Content-Type": "application/json"},
'Content-Type': 'application/json', "content": json.dumps(
}, [
'content': json.dumps([
{ {
"id": "a1e8f8d745cc87e3a9248358d9352bb7f9a0aeba", "id": "a1e8f8d745cc87e3a9248358d9352bb7f9a0aeba",
"name": "Dockerfile", "name": "Dockerfile",
"type": "tree", "type": "tree",
"path": "files/Dockerfile", "path": "files/Dockerfile",
"mode": "040000", "mode": "040000",
}, }
]), ]
),
} }
@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/projects/4/repository/tags$') @urlmatch(netloc=r"fakegitlab", path=r"/api/v4/projects/4/repository/tags$")
def project_tags_handler(_, request): def project_tags_handler(_, request):
if not request.headers.get('Authorization') == 'Bearer foobar': if not request.headers.get("Authorization") == "Bearer foobar":
return {'status_code': 401} return {"status_code": 401}
return { return {
'status_code': 200, "status_code": 200,
'headers': { "headers": {"Content-Type": "application/json"},
'Content-Type': 'application/json', "content": json.dumps(
}, [
'content': json.dumps([
{ {
'name': 'sometag', "name": "sometag",
'commit': { "commit": {"id": "60a8ff033665e1207714d6670fcd7b65304ec02f"},
'id': '60a8ff033665e1207714d6670fcd7b65304ec02f',
},
}, },
{ {
'name': 'someothertag', "name": "someothertag",
'commit': { "commit": {"id": "60a8ff033665e1207714d6670fcd7b65304ec02f"},
'id': '60a8ff033665e1207714d6670fcd7b65304ec02f',
}, },
}, ]
]), ),
} }
@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/projects/4/repository/branches$') @urlmatch(netloc=r"fakegitlab", path=r"/api/v4/projects/4/repository/branches$")
def project_branches_handler(_, request): def project_branches_handler(_, request):
if not request.headers.get('Authorization') == 'Bearer foobar': if not request.headers.get("Authorization") == "Bearer foobar":
return {'status_code': 401} return {"status_code": 401}
return { return {
'status_code': 200, "status_code": 200,
'headers': { "headers": {"Content-Type": "application/json"},
'Content-Type': 'application/json', "content": json.dumps(
}, [
'content': json.dumps([
{ {
'name': 'master', "name": "master",
'commit': { "commit": {"id": "60a8ff033665e1207714d6670fcd7b65304ec02f"},
'id': '60a8ff033665e1207714d6670fcd7b65304ec02f',
},
}, },
{ {
'name': 'otherbranch', "name": "otherbranch",
'commit': { "commit": {"id": "60a8ff033665e1207714d6670fcd7b65304ec02f"},
'id': '60a8ff033665e1207714d6670fcd7b65304ec02f',
}, },
}, ]
]), ),
} }
@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/projects/4/repository/branches/master$') @urlmatch(netloc=r"fakegitlab", path=r"/api/v4/projects/4/repository/branches/master$")
def project_branch_handler(_, request): def project_branch_handler(_, request):
if not request.headers.get('Authorization') == 'Bearer foobar': if not request.headers.get("Authorization") == "Bearer foobar":
return {'status_code': 401} return {"status_code": 401}
return { return {
'status_code': 200, "status_code": 200,
'headers': { "headers": {"Content-Type": "application/json"},
'Content-Type': 'application/json', "content": json.dumps(
}, {
'content': json.dumps({
"name": "master", "name": "master",
"merged": True, "merged": True,
"protected": True, "protected": True,
@ -193,69 +182,68 @@ def project_branch_handler(_, request):
"short_id": "7b5c3cc", "short_id": "7b5c3cc",
"title": "add projects API", "title": "add projects API",
"message": "add projects API", "message": "add projects API",
"parent_ids": [ "parent_ids": ["4ad91d3c1144c406e50c7b33bae684bd6837faf8"],
"4ad91d3c1144c406e50c7b33bae684bd6837faf8",
],
}, },
}), }
),
} }
@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/namespaces/someorg$') @urlmatch(netloc=r"fakegitlab", path=r"/api/v4/namespaces/someorg$")
def namespace_handler(_, request): def namespace_handler(_, request):
if not request.headers.get('Authorization') == 'Bearer foobar': if not request.headers.get("Authorization") == "Bearer foobar":
return {'status_code': 401} return {"status_code": 401}
return { return {
'status_code': 200, "status_code": 200,
'headers': { "headers": {"Content-Type": "application/json"},
'Content-Type': 'application/json', "content": json.dumps(
}, {
'content': json.dumps({
"id": 2, "id": 2,
"name": "someorg", "name": "someorg",
"path": "someorg", "path": "someorg",
"kind": "group", "kind": "group",
"full_path": "someorg", "full_path": "someorg",
"parent_id": None, "parent_id": None,
"members_count_with_descendants": 2 "members_count_with_descendants": 2,
}), }
),
} }
@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/namespaces/knownuser$') @urlmatch(netloc=r"fakegitlab", path=r"/api/v4/namespaces/knownuser$")
def user_namespace_handler(_, request): def user_namespace_handler(_, request):
if not request.headers.get('Authorization') == 'Bearer foobar': if not request.headers.get("Authorization") == "Bearer foobar":
return {'status_code': 401} return {"status_code": 401}
return { return {
'status_code': 200, "status_code": 200,
'headers': { "headers": {"Content-Type": "application/json"},
'Content-Type': 'application/json', "content": json.dumps(
}, {
'content': json.dumps({
"id": 1, "id": 1,
"name": "knownuser", "name": "knownuser",
"path": "knownuser", "path": "knownuser",
"kind": "user", "kind": "user",
"full_path": "knownuser", "full_path": "knownuser",
"parent_id": None, "parent_id": None,
"members_count_with_descendants": 2 "members_count_with_descendants": 2,
}), }
),
} }
@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/namespaces(/)?$') @urlmatch(netloc=r"fakegitlab", path=r"/api/v4/namespaces(/)?$")
def namespaces_handler(_, request): def namespaces_handler(_, request):
if not request.headers.get('Authorization') == 'Bearer foobar': if not request.headers.get("Authorization") == "Bearer foobar":
return {'status_code': 401} return {"status_code": 401}
return { return {
'status_code': 200, "status_code": 200,
'headers': { "headers": {"Content-Type": "application/json"},
'Content-Type': 'application/json', "content": json.dumps(
}, [
'content': json.dumps([{ {
"id": 2, "id": 2,
"name": "someorg", "name": "someorg",
"path": "someorg", "path": "someorg",
@ -263,34 +251,30 @@ def namespaces_handler(_, request):
"full_path": "someorg", "full_path": "someorg",
"parent_id": None, "parent_id": None,
"web_url": "http://gitlab.com/groups/someorg", "web_url": "http://gitlab.com/groups/someorg",
"members_count_with_descendants": 2 "members_count_with_descendants": 2,
}]), }
]
),
} }
def get_projects_handler(add_permissions_block): def get_projects_handler(add_permissions_block):
@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/groups/2/projects$') @urlmatch(netloc=r"fakegitlab", path=r"/api/v4/groups/2/projects$")
def projects_handler(_, request): def projects_handler(_, request):
if not request.headers.get('Authorization') == 'Bearer foobar': if not request.headers.get("Authorization") == "Bearer foobar":
return {'status_code': 401} return {"status_code": 401}
permissions_block = { permissions_block = {
"project_access": { "project_access": {"access_level": 10, "notification_level": 3},
"access_level": 10, "group_access": {"access_level": 20, "notification_level": 3},
"notification_level": 3
},
"group_access": {
"access_level": 20,
"notification_level": 3
},
} }
return { return {
'status_code': 200, "status_code": 200,
'headers': { "headers": {"Content-Type": "application/json"},
'Content-Type': 'application/json', "content": json.dumps(
}, [
'content': json.dumps([{ {
"id": 4, "id": 4,
"name": "Some project", "name": "Some project",
"description": None, "description": None,
@ -300,7 +284,9 @@ def get_projects_handler(add_permissions_block):
"path_with_namespace": "someorg/someproject", "path_with_namespace": "someorg/someproject",
"last_activity_at": "2013-09-30T13:46:02Z", "last_activity_at": "2013-09-30T13:46:02Z",
"web_url": "http://example.com/someorg/someproject", "web_url": "http://example.com/someorg/someproject",
"permissions": permissions_block if add_permissions_block else None, "permissions": permissions_block
if add_permissions_block
else None,
}, },
{ {
"id": 5, "id": 5,
@ -312,99 +298,105 @@ def get_projects_handler(add_permissions_block):
"path_with_namespace": "someorg/anotherproject", "path_with_namespace": "someorg/anotherproject",
"last_activity_at": "2013-09-30T13:46:02Z", "last_activity_at": "2013-09-30T13:46:02Z",
"web_url": "http://example.com/someorg/anotherproject", "web_url": "http://example.com/someorg/anotherproject",
}]), },
]
),
} }
return projects_handler return projects_handler
def get_group_handler(null_avatar): def get_group_handler(null_avatar):
@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/groups/2$') @urlmatch(netloc=r"fakegitlab", path=r"/api/v4/groups/2$")
def group_handler(_, request): def group_handler(_, request):
if not request.headers.get('Authorization') == 'Bearer foobar': if not request.headers.get("Authorization") == "Bearer foobar":
return {'status_code': 401} return {"status_code": 401}
return { return {
'status_code': 200, "status_code": 200,
'headers': { "headers": {"Content-Type": "application/json"},
'Content-Type': 'application/json', "content": json.dumps(
}, {
'content': json.dumps({
"id": 1, "id": 1,
"name": "SomeOrg Group", "name": "SomeOrg Group",
"path": "someorg", "path": "someorg",
"description": "An interesting group", "description": "An interesting group",
"visibility": "public", "visibility": "public",
"lfs_enabled": True, "lfs_enabled": True,
"avatar_url": 'avatar_url' if not null_avatar else None, "avatar_url": "avatar_url" if not null_avatar else None,
"web_url": "http://gitlab.com/groups/someorg", "web_url": "http://gitlab.com/groups/someorg",
"request_access_enabled": False, "request_access_enabled": False,
"full_name": "SomeOrg Group", "full_name": "SomeOrg Group",
"full_path": "someorg", "full_path": "someorg",
"parent_id": None, "parent_id": None,
}),
} }
),
}
return group_handler return group_handler
@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/projects/4/repository/files/Dockerfile$') @urlmatch(netloc=r"fakegitlab", path=r"/api/v4/projects/4/repository/files/Dockerfile$")
def dockerfile_handler(_, request): def dockerfile_handler(_, request):
if not request.headers.get('Authorization') == 'Bearer foobar': if not request.headers.get("Authorization") == "Bearer foobar":
return {'status_code': 401} return {"status_code": 401}
return { return {
'status_code': 200, "status_code": 200,
'headers': { "headers": {"Content-Type": "application/json"},
'Content-Type': 'application/json', "content": json.dumps(
}, {
'content': json.dumps({
"file_name": "Dockerfile", "file_name": "Dockerfile",
"file_path": "Dockerfile", "file_path": "Dockerfile",
"size": 10, "size": 10,
"encoding": "base64", "encoding": "base64",
"content": base64.b64encode('hello world'), "content": base64.b64encode("hello world"),
"ref": "master", "ref": "master",
"blob_id": "79f7bbd25901e8334750839545a9bd021f0e4c83", "blob_id": "79f7bbd25901e8334750839545a9bd021f0e4c83",
"commit_id": "d5a3ff139356ce33e37e73add446f16869741b50", "commit_id": "d5a3ff139356ce33e37e73add446f16869741b50",
"last_commit_id": "570e7b2abdd848b95f2f578043fc23bd6f6fd24d" "last_commit_id": "570e7b2abdd848b95f2f578043fc23bd6f6fd24d",
}), }
),
} }
@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/projects/4/repository/files/somesubdir%2FDockerfile$') @urlmatch(
netloc=r"fakegitlab",
path=r"/api/v4/projects/4/repository/files/somesubdir%2FDockerfile$",
)
def sub_dockerfile_handler(_, request): def sub_dockerfile_handler(_, request):
if not request.headers.get('Authorization') == 'Bearer foobar': if not request.headers.get("Authorization") == "Bearer foobar":
return {'status_code': 401} return {"status_code": 401}
return { return {
'status_code': 200, "status_code": 200,
'headers': { "headers": {"Content-Type": "application/json"},
'Content-Type': 'application/json', "content": json.dumps(
}, {
'content': json.dumps({
"file_name": "Dockerfile", "file_name": "Dockerfile",
"file_path": "somesubdir/Dockerfile", "file_path": "somesubdir/Dockerfile",
"size": 10, "size": 10,
"encoding": "base64", "encoding": "base64",
"content": base64.b64encode('hi universe'), "content": base64.b64encode("hi universe"),
"ref": "master", "ref": "master",
"blob_id": "79f7bbd25901e8334750839545a9bd021f0e4c83", "blob_id": "79f7bbd25901e8334750839545a9bd021f0e4c83",
"commit_id": "d5a3ff139356ce33e37e73add446f16869741b50", "commit_id": "d5a3ff139356ce33e37e73add446f16869741b50",
"last_commit_id": "570e7b2abdd848b95f2f578043fc23bd6f6fd24d" "last_commit_id": "570e7b2abdd848b95f2f578043fc23bd6f6fd24d",
}), }
),
} }
@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/projects/4/repository/tags/sometag$') @urlmatch(netloc=r"fakegitlab", path=r"/api/v4/projects/4/repository/tags/sometag$")
def tag_handler(_, request): def tag_handler(_, request):
if not request.headers.get('Authorization') == 'Bearer foobar': if not request.headers.get("Authorization") == "Bearer foobar":
return {'status_code': 401} return {"status_code": 401}
return { return {
'status_code': 200, "status_code": 200,
'headers': { "headers": {"Content-Type": "application/json"},
'Content-Type': 'application/json', "content": json.dumps(
}, {
'content': json.dumps({
"name": "sometag", "name": "sometag",
"message": "some cool message", "message": "some cool message",
"target": "60a8ff033665e1207714d6670fcd7b65304ec02f", "target": "60a8ff033665e1207714d6670fcd7b65304ec02f",
@ -413,33 +405,34 @@ def tag_handler(_, request):
"short_id": "60a8ff03", "short_id": "60a8ff03",
"title": "Initial commit", "title": "Initial commit",
"created_at": "2017-07-26T11:08:53.000+02:00", "created_at": "2017-07-26T11:08:53.000+02:00",
"parent_ids": [ "parent_ids": ["f61c062ff8bcbdb00e0a1b3317a91aed6ceee06b"],
"f61c062ff8bcbdb00e0a1b3317a91aed6ceee06b"
],
"message": "v5.0.0\n", "message": "v5.0.0\n",
"author_name": "Arthur Verschaeve", "author_name": "Arthur Verschaeve",
"author_email": "contact@arthurverschaeve.be", "author_email": "contact@arthurverschaeve.be",
"authored_date": "2015-02-01T21:56:31.000+01:00", "authored_date": "2015-02-01T21:56:31.000+01:00",
"committer_name": "Arthur Verschaeve", "committer_name": "Arthur Verschaeve",
"committer_email": "contact@arthurverschaeve.be", "committer_email": "contact@arthurverschaeve.be",
"committed_date": "2015-02-01T21:56:31.000+01:00" "committed_date": "2015-02-01T21:56:31.000+01:00",
}, },
"release": None, "release": None,
}), }
),
} }
@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/projects/foo%2Fbar/repository/commits/60a8ff033665e1207714d6670fcd7b65304ec02f$') @urlmatch(
netloc=r"fakegitlab",
path=r"/api/v4/projects/foo%2Fbar/repository/commits/60a8ff033665e1207714d6670fcd7b65304ec02f$",
)
def commit_handler(_, request): def commit_handler(_, request):
if not request.headers.get('Authorization') == 'Bearer foobar': if not request.headers.get("Authorization") == "Bearer foobar":
return {'status_code': 401} return {"status_code": 401}
return { return {
'status_code': 200, "status_code": 200,
'headers': { "headers": {"Content-Type": "application/json"},
'Content-Type': 'application/json', "content": json.dumps(
}, {
'content': json.dumps({
"id": "60a8ff033665e1207714d6670fcd7b65304ec02f", "id": "60a8ff033665e1207714d6670fcd7b65304ec02f",
"short_id": "60a8ff03366", "short_id": "60a8ff03366",
"title": "Sanitize for network graph", "title": "Sanitize for network graph",
@ -451,56 +444,50 @@ def commit_handler(_, request):
"message": "Sanitize for network graph", "message": "Sanitize for network graph",
"committed_date": "2012-09-20T09:06:12+03:00", "committed_date": "2012-09-20T09:06:12+03:00",
"authored_date": "2012-09-20T09:06:12+03:00", "authored_date": "2012-09-20T09:06:12+03:00",
"parent_ids": [ "parent_ids": ["ae1d9fb46aa2b07ee9836d49862ec4e2c46fbbba"],
"ae1d9fb46aa2b07ee9836d49862ec4e2c46fbbba"
],
"last_pipeline": { "last_pipeline": {
"id": 8, "id": 8,
"ref": "master", "ref": "master",
"sha": "2dc6aa325a317eda67812f05600bdf0fcdc70ab0", "sha": "2dc6aa325a317eda67812f05600bdf0fcdc70ab0",
"status": "created", "status": "created",
}, },
"stats": { "stats": {"additions": 15, "deletions": 10, "total": 25},
"additions": 15, "status": "running",
"deletions": 10, }
"total": 25 ),
},
"status": "running"
}),
} }
@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/projects/4/deploy_keys$', method='POST') @urlmatch(netloc=r"fakegitlab", path=r"/api/v4/projects/4/deploy_keys$", method="POST")
def create_deploykey_handler(_, request): def create_deploykey_handler(_, request):
if not request.headers.get('Authorization') == 'Bearer foobar': if not request.headers.get("Authorization") == "Bearer foobar":
return {'status_code': 401} return {"status_code": 401}
return { return {
'status_code': 200, "status_code": 200,
'headers': { "headers": {"Content-Type": "application/json"},
'Content-Type': 'application/json', "content": json.dumps(
}, {
'content': json.dumps({
"id": 1, "id": 1,
"title": "Public key", "title": "Public key",
"key": "ssh-rsa some stuff", "key": "ssh-rsa some stuff",
"created_at": "2013-10-02T10:12:29Z", "created_at": "2013-10-02T10:12:29Z",
"can_push": False, "can_push": False,
}), }
),
} }
@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/projects/4/hooks$', method='POST') @urlmatch(netloc=r"fakegitlab", path=r"/api/v4/projects/4/hooks$", method="POST")
def create_hook_handler(_, request): def create_hook_handler(_, request):
if not request.headers.get('Authorization') == 'Bearer foobar': if not request.headers.get("Authorization") == "Bearer foobar":
return {'status_code': 401} return {"status_code": 401}
return { return {
'status_code': 200, "status_code": 200,
'headers': { "headers": {"Content-Type": "application/json"},
'Content-Type': 'application/json', "content": json.dumps(
}, {
'content': json.dumps({
"id": 1, "id": 1,
"url": "http://example.com/hook", "url": "http://example.com/hook",
"project_id": 4, "project_id": 4,
@ -515,49 +502,47 @@ def create_hook_handler(_, request):
"wiki_page_events": True, "wiki_page_events": True,
"enable_ssl_verification": True, "enable_ssl_verification": True,
"created_at": "2012-10-12T17:04:47Z", "created_at": "2012-10-12T17:04:47Z",
}), }
),
} }
@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/projects/4/hooks/1$', method='DELETE') @urlmatch(netloc=r"fakegitlab", path=r"/api/v4/projects/4/hooks/1$", method="DELETE")
def delete_hook_handler(_, request): def delete_hook_handler(_, request):
if not request.headers.get('Authorization') == 'Bearer foobar': if not request.headers.get("Authorization") == "Bearer foobar":
return {'status_code': 401} return {"status_code": 401}
return { return {
'status_code': 200, "status_code": 200,
'headers': { "headers": {"Content-Type": "application/json"},
'Content-Type': 'application/json', "content": json.dumps({}),
},
'content': json.dumps({}),
} }
@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/projects/4/deploy_keys/1$', method='DELETE') @urlmatch(
netloc=r"fakegitlab", path=r"/api/v4/projects/4/deploy_keys/1$", method="DELETE"
)
def delete_deploykey_handker(_, request): def delete_deploykey_handker(_, request):
if not request.headers.get('Authorization') == 'Bearer foobar': if not request.headers.get("Authorization") == "Bearer foobar":
return {'status_code': 401} return {"status_code": 401}
return { return {
'status_code': 200, "status_code": 200,
'headers': { "headers": {"Content-Type": "application/json"},
'Content-Type': 'application/json', "content": json.dumps({}),
},
'content': json.dumps({}),
} }
@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/users/1/projects$') @urlmatch(netloc=r"fakegitlab", path=r"/api/v4/users/1/projects$")
def user_projects_list_handler(_, request): def user_projects_list_handler(_, request):
if not request.headers.get('Authorization') == 'Bearer foobar': if not request.headers.get("Authorization") == "Bearer foobar":
return {'status_code': 401} return {"status_code": 401}
return { return {
'status_code': 200, "status_code": 200,
'headers': { "headers": {"Content-Type": "application/json"},
'Content-Type': 'application/json', "content": json.dumps(
}, [
'content': json.dumps([
{ {
"id": 2, "id": 2,
"name": "Another project", "name": "Another project",
@ -569,29 +554,54 @@ def user_projects_list_handler(_, request):
"last_activity_at": "2013-09-30T13:46:02Z", "last_activity_at": "2013-09-30T13:46:02Z",
"web_url": "http://example.com/knownuser/anotherproject", "web_url": "http://example.com/knownuser/anotherproject",
} }
]), ]
),
} }
@contextmanager @contextmanager
def get_gitlab_trigger(dockerfile_path='', add_permissions=True, missing_avatar_url=False): def get_gitlab_trigger(
handlers = [user_handler, users_handler, project_branches_handler, project_tree_handler, dockerfile_path="", add_permissions=True, missing_avatar_url=False
project_handler, get_projects_handler(add_permissions), tag_handler, ):
project_branch_handler, get_group_handler(missing_avatar_url), dockerfile_handler, handlers = [
sub_dockerfile_handler, namespace_handler, user_namespace_handler, namespaces_handler, user_handler,
commit_handler, create_deploykey_handler, delete_deploykey_handker, users_handler,
create_hook_handler, delete_hook_handler, project_tags_handler, project_branches_handler,
user_projects_list_handler, catchall_handler] project_tree_handler,
project_handler,
get_projects_handler(add_permissions),
tag_handler,
project_branch_handler,
get_group_handler(missing_avatar_url),
dockerfile_handler,
sub_dockerfile_handler,
namespace_handler,
user_namespace_handler,
namespaces_handler,
commit_handler,
create_deploykey_handler,
delete_deploykey_handker,
create_hook_handler,
delete_hook_handler,
project_tags_handler,
user_projects_list_handler,
catchall_handler,
]
with HTTMock(*handlers): with HTTMock(*handlers):
trigger_obj = AttrDict(dict(auth_token='foobar', id='sometrigger')) trigger_obj = AttrDict(dict(auth_token="foobar", id="sometrigger"))
trigger = GitLabBuildTrigger(trigger_obj, { trigger = GitLabBuildTrigger(
'build_source': 'foo/bar', trigger_obj,
'dockerfile_path': dockerfile_path, {
'username': 'knownuser' "build_source": "foo/bar",
}) "dockerfile_path": dockerfile_path,
"username": "knownuser",
},
)
client = gitlab.Gitlab('http://fakegitlab', oauth_token='foobar', timeout=20, api_version=4) client = gitlab.Gitlab(
"http://fakegitlab", oauth_token="foobar", timeout=20, api_version=4
)
client.auth() client.auth()
trigger._get_authorized_client = lambda: client trigger._get_authorized_client = lambda: client
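
A usage sketch, assuming the context manager goes on to yield the wired-up trigger (the yield falls outside this hunk); every python-gitlab request is answered by the @urlmatch handlers above:

    with get_gitlab_trigger(dockerfile_path="/Dockerfile") as trigger:
        assert trigger.list_field_values("branch_name") == ["master", "otherbranch"]
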


@ -3,19 +3,24 @@ import pytest
from buildtrigger.basehandler import BuildTriggerHandler from buildtrigger.basehandler import BuildTriggerHandler
@pytest.mark.parametrize('input,output', [ @pytest.mark.parametrize(
"input,output",
[
("Dockerfile", True), ("Dockerfile", True),
("server.Dockerfile", True), ("server.Dockerfile", True),
(u"Dockerfile", True), (u"Dockerfile", True),
(u"server.Dockerfile", True), (u"server.Dockerfile", True),
("bad file name", False), ("bad file name", False),
(u"bad file name", False), (u"bad file name", False),
]) ],
)
def test_path_is_dockerfile(input, output): def test_path_is_dockerfile(input, output):
assert BuildTriggerHandler.filename_is_dockerfile(input) == output assert BuildTriggerHandler.filename_is_dockerfile(input) == output
@pytest.mark.parametrize('input,output', [ @pytest.mark.parametrize(
"input,output",
[
("", {}), ("", {}),
("/a", {"/a": ["/"]}), ("/a", {"/a": ["/"]}),
("a", {"/a": ["/"]}), ("a", {"/a": ["/"]}),
@ -24,7 +29,8 @@ def test_path_is_dockerfile(input, output):
("/c/b/a", {"/c/b/a": ["/c/b", "/c", "/"]}), ("/c/b/a", {"/c/b/a": ["/c/b", "/c", "/"]}),
("/a//b//c", {"/a/b/c": ["/", "/a", "/a/b"]}), ("/a//b//c", {"/a/b/c": ["/", "/a", "/a/b"]}),
("/a", {"/a": ["/"]}), ("/a", {"/a": ["/"]}),
]) ],
)
def test_subdir_path_map_no_previous(input, output): def test_subdir_path_map_no_previous(input, output):
actual_mapping = BuildTriggerHandler.get_parent_directory_mappings(input) actual_mapping = BuildTriggerHandler.get_parent_directory_mappings(input)
for key in actual_mapping: for key in actual_mapping:
@ -37,14 +43,29 @@ def test_subdir_path_map_no_previous(input, output):
assert actual_mapping == output assert actual_mapping == output
@pytest.mark.parametrize('new_path,original_dictionary,output', [ @pytest.mark.parametrize(
"new_path,original_dictionary,output",
[
("/a", {}, {"/a": ["/"]}), ("/a", {}, {"/a": ["/"]}),
("b", {"/a": ["some_path", "another_path"]}, {"/a": ["some_path", "another_path"], "/b": ["/"]}), (
("/a/b/c/d", {"/e": ["some_path", "another_path"]}, "b",
{"/e": ["some_path", "another_path"], "/a/b/c/d": ["/", "/a", "/a/b", "/a/b/c"]}), {"/a": ["some_path", "another_path"]},
]) {"/a": ["some_path", "another_path"], "/b": ["/"]},
),
(
"/a/b/c/d",
{"/e": ["some_path", "another_path"]},
{
"/e": ["some_path", "another_path"],
"/a/b/c/d": ["/", "/a", "/a/b", "/a/b/c"],
},
),
],
)
def test_subdir_path_map(new_path, original_dictionary, output): def test_subdir_path_map(new_path, original_dictionary, output):
actual_mapping = BuildTriggerHandler.get_parent_directory_mappings(new_path, original_dictionary) actual_mapping = BuildTriggerHandler.get_parent_directory_mappings(
new_path, original_dictionary
)
for key in actual_mapping: for key in actual_mapping:
value = actual_mapping[key] value = actual_mapping[key]
actual_mapping[key] = value.sort() actual_mapping[key] = value.sort()
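A hedged reconstruction of the mapping logic from the expectations above (the real BuildTriggerHandler.get_parent_directory_mappings may differ): normalize the path to a leading slash, collapse duplicate slashes, and map it to all of its ancestor directories, merging into any previously built dictionary. Ordering of the ancestor lists is unspecified, since the tests normalize values with list.sort() before comparing.

import os

def parent_directory_mappings_sketch(dockerfile_path, current_paths=None):
    mappings = dict(current_paths or {})
    if not dockerfile_path:
        return mappings
    # "/a//b//c" -> "/a/b/c"; "b" -> "/b"
    path = os.path.normpath("/" + dockerfile_path.lstrip("/"))
    parents = []
    parent = os.path.dirname(path)
    while parent != "/":
        parents.append(parent)
        parent = os.path.dirname(parent)
    parents.append("/")
    mappings[path] = parents
    return mappings

got = parent_directory_mappings_sketch("/a/b/c/d", {"/e": ["x"]})
assert sorted(got["/a/b/c/d"]) == ["/", "/a", "/a/b", "/a/b/c"]
assert got["/e"] == ["x"]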
@ -2,11 +2,15 @@ import json
import pytest import pytest
from buildtrigger.test.bitbucketmock import get_bitbucket_trigger from buildtrigger.test.bitbucketmock import get_bitbucket_trigger
from buildtrigger.triggerutil import (SkipRequestException, ValidationRequestException, from buildtrigger.triggerutil import (
InvalidPayloadException) SkipRequestException,
ValidationRequestException,
InvalidPayloadException,
)
from endpoints.building import PreparedBuild from endpoints.building import PreparedBuild
from util.morecollections import AttrDict from util.morecollections import AttrDict
@pytest.fixture @pytest.fixture
def bitbucket_trigger(): def bitbucket_trigger():
return get_bitbucket_trigger() return get_bitbucket_trigger()
@ -16,21 +20,26 @@ def test_list_build_subdirs(bitbucket_trigger):
assert bitbucket_trigger.list_build_subdirs() == ["/Dockerfile"] assert bitbucket_trigger.list_build_subdirs() == ["/Dockerfile"]
@pytest.mark.parametrize('dockerfile_path, contents', [ @pytest.mark.parametrize(
('/Dockerfile', 'hello world'), "dockerfile_path, contents",
('somesubdir/Dockerfile', 'hi universe'), [
('unknownpath', None), ("/Dockerfile", "hello world"),
]) ("somesubdir/Dockerfile", "hi universe"),
("unknownpath", None),
],
)
def test_load_dockerfile_contents(dockerfile_path, contents): def test_load_dockerfile_contents(dockerfile_path, contents):
trigger = get_bitbucket_trigger(dockerfile_path) trigger = get_bitbucket_trigger(dockerfile_path)
assert trigger.load_dockerfile_contents() == contents assert trigger.load_dockerfile_contents() == contents
@pytest.mark.parametrize('payload, expected_error, expected_message', [ @pytest.mark.parametrize(
('{}', InvalidPayloadException, "'push' is a required property"), "payload, expected_error, expected_message",
[
("{}", InvalidPayloadException, "'push' is a required property"),
# Valid payload: # Valid payload:
('''{ (
"""{
"push": { "push": {
"changes": [{ "changes": [{
"new": { "new": {
@ -51,10 +60,13 @@ def test_load_dockerfile_contents(dockerfile_path, contents):
"repository": { "repository": {
"full_name": "foo/bar" "full_name": "foo/bar"
} }
}''', None, None), }""",
None,
None,
),
# Skip message: # Skip message:
('''{ (
"""{
"push": { "push": {
"changes": [{ "changes": [{
"new": { "new": {
@ -75,9 +87,15 @@ def test_load_dockerfile_contents(dockerfile_path, contents):
"repository": { "repository": {
"full_name": "foo/bar" "full_name": "foo/bar"
} }
}''', SkipRequestException, ''), }""",
]) SkipRequestException,
def test_handle_trigger_request(bitbucket_trigger, payload, expected_error, expected_message): "",
),
],
)
def test_handle_trigger_request(
bitbucket_trigger, payload, expected_error, expected_message
):
def get_payload(): def get_payload():
return json.loads(payload) return json.loads(payload)
@ -88,4 +106,6 @@ def test_handle_trigger_request(bitbucket_trigger, payload, expected_error, expected_message):
bitbucket_trigger.handle_trigger_request(request) bitbucket_trigger.handle_trigger_request(request)
assert str(ipe.value) == expected_message assert str(ipe.value) == expected_message
else: else:
assert isinstance(bitbucket_trigger.handle_trigger_request(request), PreparedBuild) assert isinstance(
bitbucket_trigger.handle_trigger_request(request), PreparedBuild
)
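These tests stand in for Flask's request object with util.morecollections.AttrDict. A minimal sketch of that idea (attribute access over dict keys; the real class may differ):

class AttrDictSketch(dict):
    # Expose dict keys as attributes, as the AttrDict(dict(...)) call sites rely on.
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

request = AttrDictSketch(data='{"push": {}}')
assert request.data == '{"push": {}}'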
@ -1,20 +1,32 @@
import pytest import pytest
from buildtrigger.customhandler import CustomBuildTrigger from buildtrigger.customhandler import CustomBuildTrigger
from buildtrigger.triggerutil import (InvalidPayloadException, SkipRequestException, from buildtrigger.triggerutil import (
TriggerStartException) InvalidPayloadException,
SkipRequestException,
TriggerStartException,
)
from endpoints.building import PreparedBuild from endpoints.building import PreparedBuild
from util.morecollections import AttrDict from util.morecollections import AttrDict
@pytest.mark.parametrize('payload, expected_error, expected_message', [
('', InvalidPayloadException, 'Missing expected payload'),
('{}', InvalidPayloadException, "'commit' is a required property"),
('{"commit": "foo", "ref": "refs/heads/something", "default_branch": "baz"}', @pytest.mark.parametrize(
InvalidPayloadException, "u'foo' does not match '^([A-Fa-f0-9]{7,})$'"), "payload, expected_error, expected_message",
[
('{"commit": "11d6fbc", "ref": "refs/heads/something", "default_branch": "baz"}', None, None), ("", InvalidPayloadException, "Missing expected payload"),
('''{ ("{}", InvalidPayloadException, "'commit' is a required property"),
(
'{"commit": "foo", "ref": "refs/heads/something", "default_branch": "baz"}',
InvalidPayloadException,
"u'foo' does not match '^([A-Fa-f0-9]{7,})$'",
),
(
'{"commit": "11d6fbc", "ref": "refs/heads/something", "default_branch": "baz"}',
None,
None,
),
(
"""{
"commit": "11d6fbc", "commit": "11d6fbc",
"ref": "refs/heads/something", "ref": "refs/heads/something",
"default_branch": "baz", "default_branch": "baz",
@ -23,10 +35,14 @@ from util.morecollections import AttrDict
"url": "http://foo.bar", "url": "http://foo.bar",
"date": "NOW" "date": "NOW"
} }
}''', SkipRequestException, ''), }""",
]) SkipRequestException,
"",
),
],
)
def test_handle_trigger_request(payload, expected_error, expected_message): def test_handle_trigger_request(payload, expected_error, expected_message):
trigger = CustomBuildTrigger(None, {'build_source': 'foo'}) trigger = CustomBuildTrigger(None, {"build_source": "foo"})
request = AttrDict(dict(data=payload)) request = AttrDict(dict(data=payload))
if expected_error is not None: if expected_error is not None:
@ -36,13 +52,21 @@ def test_handle_trigger_request(payload, expected_error, expected_message):
else: else:
assert isinstance(trigger.handle_trigger_request(request), PreparedBuild) assert isinstance(trigger.handle_trigger_request(request), PreparedBuild)
@pytest.mark.parametrize('run_parameters, expected_error, expected_message', [
({}, TriggerStartException, 'missing required parameter'), @pytest.mark.parametrize(
({'commit_sha': 'foo'}, TriggerStartException, "'foo' does not match '^([A-Fa-f0-9]{7,})$'"), "run_parameters, expected_error, expected_message",
({'commit_sha': '11d6fbc'}, None, None), [
]) ({}, TriggerStartException, "missing required parameter"),
(
{"commit_sha": "foo"},
TriggerStartException,
"'foo' does not match '^([A-Fa-f0-9]{7,})$'",
),
({"commit_sha": "11d6fbc"}, None, None),
],
)
def test_manual_start(run_parameters, expected_error, expected_message): def test_manual_start(run_parameters, expected_error, expected_message):
trigger = CustomBuildTrigger(None, {'build_source': 'foo'}) trigger = CustomBuildTrigger(None, {"build_source": "foo"})
if expected_error is not None: if expected_error is not None:
with pytest.raises(expected_error) as ipe: with pytest.raises(expected_error) as ipe:
trigger.manual_start(run_parameters) trigger.manual_start(run_parameters)
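The expected messages above come from validating the commit SHA against the pattern quoted in them; a quick standalone check of that same pattern:

import re

sha_pattern = re.compile(r"^([A-Fa-f0-9]{7,})$")

assert sha_pattern.match("11d6fbc")      # at least seven hex characters: accepted
assert not sha_pattern.match("foo")      # non-hex and too short: rejected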
@ -11,25 +11,33 @@ from endpoints.building import PreparedBuild
def githost_trigger(request): def githost_trigger(request):
return request.param return request.param
@pytest.mark.parametrize('run_parameters, expected_error, expected_message', [
@pytest.mark.parametrize(
"run_parameters, expected_error, expected_message",
[
# No branch or tag specified: use the commit of the default branch. # No branch or tag specified: use the commit of the default branch.
({}, None, None), ({}, None, None),
# Invalid branch. # Invalid branch.
({'refs': {'kind': 'branch', 'name': 'invalid'}}, TriggerStartException, (
'Could not find branch in repository'), {"refs": {"kind": "branch", "name": "invalid"}},
TriggerStartException,
"Could not find branch in repository",
),
# Invalid tag. # Invalid tag.
({'refs': {'kind': 'tag', 'name': 'invalid'}}, TriggerStartException, (
'Could not find tag in repository'), {"refs": {"kind": "tag", "name": "invalid"}},
TriggerStartException,
"Could not find tag in repository",
),
# Valid branch. # Valid branch.
({'refs': {'kind': 'branch', 'name': 'master'}}, None, None), ({"refs": {"kind": "branch", "name": "master"}}, None, None),
# Valid tag. # Valid tag.
({'refs': {'kind': 'tag', 'name': 'sometag'}}, None, None), ({"refs": {"kind": "tag", "name": "sometag"}}, None, None),
]) ],
def test_manual_start(run_parameters, expected_error, expected_message, githost_trigger): )
def test_manual_start(
run_parameters, expected_error, expected_message, githost_trigger
):
if expected_error is not None: if expected_error is not None:
with pytest.raises(expected_error) as ipe: with pytest.raises(expected_error) as ipe:
githost_trigger.manual_start(run_parameters) githost_trigger.manual_start(run_parameters)
@ -38,17 +46,23 @@ def test_manual_start(run_parameters, expected_error, expected_message, githost_trigger):
assert isinstance(githost_trigger.manual_start(run_parameters), PreparedBuild) assert isinstance(githost_trigger.manual_start(run_parameters), PreparedBuild)
@pytest.mark.parametrize('name, expected', [ @pytest.mark.parametrize(
('refs', [ "name, expected",
{'kind': 'branch', 'name': 'master'}, [
{'kind': 'branch', 'name': 'otherbranch'}, (
{'kind': 'tag', 'name': 'sometag'}, "refs",
{'kind': 'tag', 'name': 'someothertag'}, [
]), {"kind": "branch", "name": "master"},
('tag_name', set(['sometag', 'someothertag'])), {"kind": "branch", "name": "otherbranch"},
('branch_name', set(['master', 'otherbranch'])), {"kind": "tag", "name": "sometag"},
('invalid', None) {"kind": "tag", "name": "someothertag"},
]) ],
),
("tag_name", set(["sometag", "someothertag"])),
("branch_name", set(["master", "otherbranch"])),
("invalid", None),
],
)
def test_list_field_values(name, expected, githost_trigger): def test_list_field_values(name, expected, githost_trigger):
if expected is None: if expected is None:
assert githost_trigger.list_field_values(name) is None assert githost_trigger.list_field_values(name) is None
@ -61,21 +75,21 @@ def test_list_field_values(name, expected, githost_trigger):
def test_list_build_source_namespaces(): def test_list_build_source_namespaces():
namespaces_expected = [ namespaces_expected = [
{ {
'personal': True, "personal": True,
'score': 1, "score": 1,
'avatar_url': 'avatarurl', "avatar_url": "avatarurl",
'id': 'knownuser', "id": "knownuser",
'title': 'knownuser', "title": "knownuser",
'url': 'https://bitbucket.org/knownuser', "url": "https://bitbucket.org/knownuser",
}, },
{ {
'score': 2, "score": 2,
'title': 'someorg', "title": "someorg",
'personal': False, "personal": False,
'url': 'https://bitbucket.org/someorg', "url": "https://bitbucket.org/someorg",
'avatar_url': 'avatarurl', "avatar_url": "avatarurl",
'id': 'someorg' "id": "someorg",
} },
] ]
found = get_bitbucket_trigger().list_build_source_namespaces() found = get_bitbucket_trigger().list_build_source_namespaces()
@ -85,37 +99,55 @@ def test_list_build_source_namespaces():
assert found == namespaces_expected assert found == namespaces_expected
@pytest.mark.parametrize('namespace, expected', [ @pytest.mark.parametrize(
('', []), "namespace, expected",
('unknown', []), [
("", []),
('knownuser', [ ("unknown", []),
(
"knownuser",
[
{ {
'last_updated': 0, 'name': 'somerepo', "last_updated": 0,
'url': 'https://bitbucket.org/knownuser/somerepo', 'private': True, "name": "somerepo",
'full_name': 'knownuser/somerepo', 'has_admin_permissions': True, "url": "https://bitbucket.org/knownuser/somerepo",
'description': 'some somerepo repo' "private": True,
}]), "full_name": "knownuser/somerepo",
"has_admin_permissions": True,
('someorg', [ "description": "some somerepo repo",
}
],
),
(
"someorg",
[
{ {
'last_updated': 0, 'name': 'somerepo', "last_updated": 0,
'url': 'https://bitbucket.org/someorg/somerepo', 'private': True, "name": "somerepo",
'full_name': 'someorg/somerepo', 'has_admin_permissions': False, "url": "https://bitbucket.org/someorg/somerepo",
'description': 'some somerepo repo' "private": True,
"full_name": "someorg/somerepo",
"has_admin_permissions": False,
"description": "some somerepo repo",
}, },
{ {
'last_updated': 0, 'name': 'anotherrepo', "last_updated": 0,
'url': 'https://bitbucket.org/someorg/anotherrepo', 'private': False, "name": "anotherrepo",
'full_name': 'someorg/anotherrepo', 'has_admin_permissions': False, "url": "https://bitbucket.org/someorg/anotherrepo",
'description': 'some anotherrepo repo' "private": False,
}]), "full_name": "someorg/anotherrepo",
]) "has_admin_permissions": False,
"description": "some anotherrepo repo",
},
],
),
],
)
def test_list_build_sources_for_namespace(namespace, expected, githost_trigger): def test_list_build_sources_for_namespace(namespace, expected, githost_trigger):
assert githost_trigger.list_build_sources_for_namespace(namespace) == expected assert githost_trigger.list_build_sources_for_namespace(namespace) == expected
def test_activate_and_deactivate(githost_trigger): def test_activate_and_deactivate(githost_trigger):
_, private_key = githost_trigger.activate('http://some/url') _, private_key = githost_trigger.activate("http://some/url")
assert 'private_key' in private_key assert "private_key" in private_key
githost_trigger.deactivate() githost_trigger.deactivate()
@ -2,24 +2,33 @@ import json
import pytest import pytest
from buildtrigger.test.githubmock import get_github_trigger from buildtrigger.test.githubmock import get_github_trigger
from buildtrigger.triggerutil import (SkipRequestException, ValidationRequestException, from buildtrigger.triggerutil import (
InvalidPayloadException) SkipRequestException,
ValidationRequestException,
InvalidPayloadException,
)
from endpoints.building import PreparedBuild from endpoints.building import PreparedBuild
from util.morecollections import AttrDict from util.morecollections import AttrDict
@pytest.fixture @pytest.fixture
def github_trigger(): def github_trigger():
return get_github_trigger() return get_github_trigger()
@pytest.mark.parametrize('payload, expected_error, expected_message', [ @pytest.mark.parametrize(
"payload, expected_error, expected_message",
[
('{"zen": true}', SkipRequestException, ""), ('{"zen": true}', SkipRequestException, ""),
("{}", InvalidPayloadException, "Missing 'repository' on request"),
('{}', InvalidPayloadException, "Missing 'repository' on request"), (
('{"repository": "foo"}', InvalidPayloadException, "Missing 'owner' on repository"), '{"repository": "foo"}',
InvalidPayloadException,
"Missing 'owner' on repository",
),
# Valid payload: # Valid payload:
('''{ (
"""{
"repository": { "repository": {
"owner": { "owner": {
"name": "someguy" "name": "someguy"
@ -34,10 +43,13 @@ def github_trigger():
"message": "some message", "message": "some message",
"timestamp": "NOW" "timestamp": "NOW"
} }
}''', None, None), }""",
None,
None,
),
# Skip message: # Skip message:
('''{ (
"""{
"repository": { "repository": {
"owner": { "owner": {
"name": "someguy" "name": "someguy"
@ -52,9 +64,15 @@ def github_trigger():
"message": "[skip build]", "message": "[skip build]",
"timestamp": "NOW" "timestamp": "NOW"
} }
}''', SkipRequestException, ''), }""",
]) SkipRequestException,
def test_handle_trigger_request(github_trigger, payload, expected_error, expected_message): "",
),
],
)
def test_handle_trigger_request(
github_trigger, payload, expected_error, expected_message
):
def get_payload(): def get_payload():
return json.loads(payload) return json.loads(payload)
@ -68,46 +86,58 @@ def test_handle_trigger_request(github_trigger, payload, expected_error, expected_message):
assert isinstance(github_trigger.handle_trigger_request(request), PreparedBuild) assert isinstance(github_trigger.handle_trigger_request(request), PreparedBuild)
@pytest.mark.parametrize('dockerfile_path, contents', [ @pytest.mark.parametrize(
('/Dockerfile', 'hello world'), "dockerfile_path, contents",
('somesubdir/Dockerfile', 'hi universe'), [
('unknownpath', None), ("/Dockerfile", "hello world"),
]) ("somesubdir/Dockerfile", "hi universe"),
("unknownpath", None),
],
)
def test_load_dockerfile_contents(dockerfile_path, contents): def test_load_dockerfile_contents(dockerfile_path, contents):
trigger = get_github_trigger(dockerfile_path) trigger = get_github_trigger(dockerfile_path)
assert trigger.load_dockerfile_contents() == contents assert trigger.load_dockerfile_contents() == contents
@pytest.mark.parametrize('username, expected_response', [ @pytest.mark.parametrize(
('unknownuser', None), "username, expected_response",
('knownuser', {'html_url': 'https://bitbucket.org/knownuser', 'avatar_url': 'avatarurl'}), [
]) ("unknownuser", None),
(
"knownuser",
{"html_url": "https://bitbucket.org/knownuser", "avatar_url": "avatarurl"},
),
],
)
def test_lookup_user(username, expected_response, github_trigger): def test_lookup_user(username, expected_response, github_trigger):
assert github_trigger.lookup_user(username) == expected_response assert github_trigger.lookup_user(username) == expected_response
def test_list_build_subdirs(github_trigger): def test_list_build_subdirs(github_trigger):
assert github_trigger.list_build_subdirs() == ['Dockerfile', 'somesubdir/Dockerfile'] assert github_trigger.list_build_subdirs() == [
"Dockerfile",
"somesubdir/Dockerfile",
]
def test_list_build_source_namespaces(github_trigger): def test_list_build_source_namespaces(github_trigger):
namespaces_expected = [ namespaces_expected = [
{ {
'personal': True, "personal": True,
'score': 1, "score": 1,
'avatar_url': 'avatarurl', "avatar_url": "avatarurl",
'id': 'knownuser', "id": "knownuser",
'title': 'knownuser', "title": "knownuser",
'url': 'https://bitbucket.org/knownuser', "url": "https://bitbucket.org/knownuser",
}, },
{ {
'score': 0, "score": 0,
'title': 'someorg', "title": "someorg",
'personal': False, "personal": False,
'url': '', "url": "",
'avatar_url': 'avatarurl', "avatar_url": "avatarurl",
'id': 'someorg' "id": "someorg",
} },
] ]
found = github_trigger.list_build_source_namespaces() found = github_trigger.list_build_source_namespaces()
@ -4,11 +4,16 @@ import pytest
from mock import Mock from mock import Mock
from buildtrigger.test.gitlabmock import get_gitlab_trigger from buildtrigger.test.gitlabmock import get_gitlab_trigger
from buildtrigger.triggerutil import (SkipRequestException, ValidationRequestException, from buildtrigger.triggerutil import (
InvalidPayloadException, TriggerStartException) SkipRequestException,
ValidationRequestException,
InvalidPayloadException,
TriggerStartException,
)
from endpoints.building import PreparedBuild from endpoints.building import PreparedBuild
from util.morecollections import AttrDict from util.morecollections import AttrDict
@pytest.fixture() @pytest.fixture()
def gitlab_trigger(): def gitlab_trigger():
with get_gitlab_trigger() as t: with get_gitlab_trigger() as t:
@ -16,79 +21,94 @@ def gitlab_trigger():
def test_list_build_subdirs(gitlab_trigger): def test_list_build_subdirs(gitlab_trigger):
assert gitlab_trigger.list_build_subdirs() == ['Dockerfile'] assert gitlab_trigger.list_build_subdirs() == ["Dockerfile"]
@pytest.mark.parametrize('dockerfile_path, contents', [ @pytest.mark.parametrize(
('/Dockerfile', 'hello world'), "dockerfile_path, contents",
('somesubdir/Dockerfile', 'hi universe'), [
('unknownpath', None), ("/Dockerfile", "hello world"),
]) ("somesubdir/Dockerfile", "hi universe"),
("unknownpath", None),
],
)
def test_load_dockerfile_contents(dockerfile_path, contents): def test_load_dockerfile_contents(dockerfile_path, contents):
with get_gitlab_trigger(dockerfile_path=dockerfile_path) as trigger: with get_gitlab_trigger(dockerfile_path=dockerfile_path) as trigger:
assert trigger.load_dockerfile_contents() == contents assert trigger.load_dockerfile_contents() == contents
@pytest.mark.parametrize('email, expected_response', [ @pytest.mark.parametrize(
('unknown@email.com', None), "email, expected_response",
('knownuser', {'username': 'knownuser', 'html_url': 'https://bitbucket.org/knownuser', [
'avatar_url': 'avatarurl'}), ("unknown@email.com", None),
]) (
"knownuser",
{
"username": "knownuser",
"html_url": "https://bitbucket.org/knownuser",
"avatar_url": "avatarurl",
},
),
],
)
def test_lookup_user(email, expected_response, gitlab_trigger): def test_lookup_user(email, expected_response, gitlab_trigger):
assert gitlab_trigger.lookup_user(email) == expected_response assert gitlab_trigger.lookup_user(email) == expected_response
def test_null_permissions(): def test_null_permissions():
with get_gitlab_trigger(add_permissions=False) as trigger: with get_gitlab_trigger(add_permissions=False) as trigger:
sources = trigger.list_build_sources_for_namespace('someorg') sources = trigger.list_build_sources_for_namespace("someorg")
source = sources[0] source = sources[0]
assert source['has_admin_permissions'] assert source["has_admin_permissions"]
def test_list_build_sources(): def test_list_build_sources():
with get_gitlab_trigger() as trigger: with get_gitlab_trigger() as trigger:
sources = trigger.list_build_sources_for_namespace('someorg') sources = trigger.list_build_sources_for_namespace("someorg")
assert sources == [ assert sources == [
{ {
'last_updated': 1380548762, "last_updated": 1380548762,
'name': u'someproject', "name": u"someproject",
'url': u'http://example.com/someorg/someproject', "url": u"http://example.com/someorg/someproject",
'private': True, "private": True,
'full_name': u'someorg/someproject', "full_name": u"someorg/someproject",
'has_admin_permissions': False, "has_admin_permissions": False,
'description': '' "description": "",
}, },
{ {
'last_updated': 1380548762, "last_updated": 1380548762,
'name': u'anotherproject', "name": u"anotherproject",
'url': u'http://example.com/someorg/anotherproject', "url": u"http://example.com/someorg/anotherproject",
'private': False, "private": False,
'full_name': u'someorg/anotherproject', "full_name": u"someorg/anotherproject",
'has_admin_permissions': True, "has_admin_permissions": True,
'description': '', "description": "",
}] },
]
def test_null_avatar(): def test_null_avatar():
with get_gitlab_trigger(missing_avatar_url=True) as trigger: with get_gitlab_trigger(missing_avatar_url=True) as trigger:
namespace_data = trigger.list_build_source_namespaces() namespace_data = trigger.list_build_source_namespaces()
expected = { expected = {
'avatar_url': None, "avatar_url": None,
'personal': False, "personal": False,
'title': u'someorg', "title": u"someorg",
'url': u'http://gitlab.com/groups/someorg', "url": u"http://gitlab.com/groups/someorg",
'score': 1, "score": 1,
'id': '2', "id": "2",
} }
assert namespace_data == [expected] assert namespace_data == [expected]
@pytest.mark.parametrize('payload, expected_error, expected_message', [ @pytest.mark.parametrize(
('{}', InvalidPayloadException, ''), "payload, expected_error, expected_message",
[
("{}", InvalidPayloadException, ""),
# Valid payload: # Valid payload:
('''{ (
"""{
"object_kind": "push", "object_kind": "push",
"ref": "refs/heads/master", "ref": "refs/heads/master",
"checkout_sha": "aaaaaaa", "checkout_sha": "aaaaaaa",
@ -103,10 +123,13 @@ def test_null_avatar():
"timestamp": "now" "timestamp": "now"
} }
] ]
}''', None, None), }""",
None,
None,
),
# Skip message: # Skip message:
('''{ (
"""{
"object_kind": "push", "object_kind": "push",
"ref": "refs/heads/master", "ref": "refs/heads/master",
"checkout_sha": "aaaaaaa", "checkout_sha": "aaaaaaa",
@ -121,9 +144,15 @@ def test_null_avatar():
"timestamp": "now" "timestamp": "now"
} }
] ]
}''', SkipRequestException, ''), }""",
]) SkipRequestException,
def test_handle_trigger_request(gitlab_trigger, payload, expected_error, expected_message): "",
),
],
)
def test_handle_trigger_request(
gitlab_trigger, payload, expected_error, expected_message
):
def get_payload(): def get_payload():
return json.loads(payload) return json.loads(payload)
@ -137,24 +166,29 @@ def test_handle_trigger_request(gitlab_trigger, payload, expected_error, expected_message):
assert isinstance(gitlab_trigger.handle_trigger_request(request), PreparedBuild) assert isinstance(gitlab_trigger.handle_trigger_request(request), PreparedBuild)
@pytest.mark.parametrize('run_parameters, expected_error, expected_message', [ @pytest.mark.parametrize(
"run_parameters, expected_error, expected_message",
[
# No branch or tag specified: use the commit of the default branch. # No branch or tag specified: use the commit of the default branch.
({}, None, None), ({}, None, None),
# Invalid branch. # Invalid branch.
({'refs': {'kind': 'branch', 'name': 'invalid'}}, TriggerStartException, (
'Could not find branch in repository'), {"refs": {"kind": "branch", "name": "invalid"}},
TriggerStartException,
"Could not find branch in repository",
),
# Invalid tag. # Invalid tag.
({'refs': {'kind': 'tag', 'name': 'invalid'}}, TriggerStartException, (
'Could not find tag in repository'), {"refs": {"kind": "tag", "name": "invalid"}},
TriggerStartException,
"Could not find tag in repository",
),
# Valid branch. # Valid branch.
({'refs': {'kind': 'branch', 'name': 'master'}}, None, None), ({"refs": {"kind": "branch", "name": "master"}}, None, None),
# Valid tag. # Valid tag.
({'refs': {'kind': 'tag', 'name': 'sometag'}}, None, None), ({"refs": {"kind": "tag", "name": "sometag"}}, None, None),
]) ],
)
def test_manual_start(run_parameters, expected_error, expected_message, gitlab_trigger): def test_manual_start(run_parameters, expected_error, expected_message, gitlab_trigger):
if expected_error is not None: if expected_error is not None:
with pytest.raises(expected_error) as ipe: with pytest.raises(expected_error) as ipe:
@ -165,23 +199,29 @@ def test_manual_start(run_parameters, expected_error, expected_message, gitlab_trigger):
def test_activate_and_deactivate(gitlab_trigger): def test_activate_and_deactivate(gitlab_trigger):
_, private_key = gitlab_trigger.activate('http://some/url') _, private_key = gitlab_trigger.activate("http://some/url")
assert 'private_key' in private_key assert "private_key" in private_key
gitlab_trigger.deactivate() gitlab_trigger.deactivate()
@pytest.mark.parametrize('name, expected', [ @pytest.mark.parametrize(
('refs', [ "name, expected",
{'kind': 'branch', 'name': 'master'}, [
{'kind': 'branch', 'name': 'otherbranch'}, (
{'kind': 'tag', 'name': 'sometag'}, "refs",
{'kind': 'tag', 'name': 'someothertag'}, [
]), {"kind": "branch", "name": "master"},
('tag_name', set(['sometag', 'someothertag'])), {"kind": "branch", "name": "otherbranch"},
('branch_name', set(['master', 'otherbranch'])), {"kind": "tag", "name": "sometag"},
('invalid', None) {"kind": "tag", "name": "someothertag"},
]) ],
),
("tag_name", set(["sometag", "someothertag"])),
("branch_name", set(["master", "otherbranch"])),
("invalid", None),
],
)
def test_list_field_values(name, expected, gitlab_trigger): def test_list_field_values(name, expected, gitlab_trigger):
if expected is None: if expected is None:
assert gitlab_trigger.list_field_values(name) is None assert gitlab_trigger.list_field_values(name) is None
@ -191,41 +231,49 @@ def test_list_field_values(name, expected, gitlab_trigger):
assert gitlab_trigger.list_field_values(name) == expected assert gitlab_trigger.list_field_values(name) == expected
@pytest.mark.parametrize('namespace, expected', [ @pytest.mark.parametrize(
('', []), "namespace, expected",
('unknown', []), [
("", []),
('knownuser', [ ("unknown", []),
(
"knownuser",
[
{ {
'last_updated': 1380548762, "last_updated": 1380548762,
'name': u'anotherproject', "name": u"anotherproject",
'url': u'http://example.com/knownuser/anotherproject', "url": u"http://example.com/knownuser/anotherproject",
'private': False, "private": False,
'full_name': u'knownuser/anotherproject', "full_name": u"knownuser/anotherproject",
'has_admin_permissions': True, "has_admin_permissions": True,
'description': '' "description": "",
}, }
]), ],
),
('someorg', [ (
"someorg",
[
{ {
'last_updated': 1380548762, "last_updated": 1380548762,
'name': u'someproject', "name": u"someproject",
'url': u'http://example.com/someorg/someproject', "url": u"http://example.com/someorg/someproject",
'private': True, "private": True,
'full_name': u'someorg/someproject', "full_name": u"someorg/someproject",
'has_admin_permissions': False, "has_admin_permissions": False,
'description': '' "description": "",
}, },
{ {
'last_updated': 1380548762, "last_updated": 1380548762,
'name': u'anotherproject', "name": u"anotherproject",
'url': u'http://example.com/someorg/anotherproject', "url": u"http://example.com/someorg/anotherproject",
'private': False, "private": False,
'full_name': u'someorg/anotherproject', "full_name": u"someorg/anotherproject",
'has_admin_permissions': True, "has_admin_permissions": True,
'description': '', "description": "",
}]), },
]) ],
),
],
)
def test_list_build_sources_for_namespace(namespace, expected, gitlab_trigger): def test_list_build_sources_for_namespace(namespace, expected, gitlab_trigger):
assert gitlab_trigger.list_build_sources_for_namespace(namespace) == expected assert gitlab_trigger.list_build_sources_for_namespace(namespace) == expected
@ -12,8 +12,9 @@ from buildtrigger.githubhandler import get_transformed_webhook_payload as gh_webhook
from buildtrigger.gitlabhandler import get_transformed_webhook_payload as gl_webhook from buildtrigger.gitlabhandler import get_transformed_webhook_payload as gl_webhook
from buildtrigger.triggerutil import SkipRequestException from buildtrigger.triggerutil import SkipRequestException
def assertSkipped(filename, processor, *args, **kwargs): def assertSkipped(filename, processor, *args, **kwargs):
with open('buildtrigger/test/triggerjson/%s.json' % filename) as f: with open("buildtrigger/test/triggerjson/%s.json" % filename) as f:
payload = json.loads(f.read()) payload = json.loads(f.read())
nargs = [payload] nargs = [payload]
@ -24,7 +25,7 @@ def assertSkipped(filename, processor, *args, **kwargs):
def assertSchema(filename, expected, processor, *args, **kwargs): def assertSchema(filename, expected, processor, *args, **kwargs):
with open('buildtrigger/test/triggerjson/%s.json' % filename) as f: with open("buildtrigger/test/triggerjson/%s.json" % filename) as f:
payload = json.loads(f.read()) payload = json.loads(f.read())
nargs = [payload] nargs = [payload]
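Both helpers are cut off by the hunk boundaries here. A hedged sketch of their likely remainder, inferred only from the call sites in this file (extra positional and keyword callables forwarded to the processor, assertSkipped expecting SkipRequestException); the real bodies may differ:

import json

from buildtrigger.triggerutil import SkipRequestException

def assert_schema_sketch(filename, expected, processor, *args, **kwargs):
    with open("buildtrigger/test/triggerjson/%s.json" % filename) as f:
        payload = json.loads(f.read())
    assert processor(payload, *args, **kwargs) == expected

def assert_skipped_sketch(filename, processor, *args, **kwargs):
    with open("buildtrigger/test/triggerjson/%s.json" % filename) as f:
        payload = json.loads(f.read())
    try:
        processor(payload, *args, **kwargs)
        assert False, "expected SkipRequestException"
    except SkipRequestException:
        pass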
@ -37,66 +38,71 @@ def assertSchema(filename, expected, processor, *args, **kwargs):
def test_custom_custom(): def test_custom_custom():
expected = { expected = {
u'commit':u'1c002dd', u"commit": u"1c002dd",
u'commit_info': { u"commit_info": {
u'url': u'gitsoftware.com/repository/commits/1234567', u"url": u"gitsoftware.com/repository/commits/1234567",
u'date': u'timestamp', u"date": u"timestamp",
u'message': u'initial commit', u"message": u"initial commit",
u'committer': { u"committer": {
u'username': u'user', u"username": u"user",
u'url': u'gitsoftware.com/users/user', u"url": u"gitsoftware.com/users/user",
u'avatar_url': u'gravatar.com/user.png' u"avatar_url": u"gravatar.com/user.png",
}, },
u'author': { u"author": {
u'username': u'user', u"username": u"user",
u'url': u'gitsoftware.com/users/user', u"url": u"gitsoftware.com/users/user",
u'avatar_url': u'gravatar.com/user.png' u"avatar_url": u"gravatar.com/user.png",
}
}, },
u'ref': u'refs/heads/master', },
u'default_branch': u'master', u"ref": u"refs/heads/master",
u'git_url': u'foobar', u"default_branch": u"master",
u"git_url": u"foobar",
} }
assertSchema('custom_webhook', expected, custom_trigger_payload, git_url='foobar') assertSchema("custom_webhook", expected, custom_trigger_payload, git_url="foobar")
def test_custom_gitlab(): def test_custom_gitlab():
expected = { expected = {
'commit': u'fb88379ee45de28a0a4590fddcbd8eff8b36026e', "commit": u"fb88379ee45de28a0a4590fddcbd8eff8b36026e",
'ref': u'refs/heads/master', "ref": u"refs/heads/master",
'git_url': u'git@gitlab.com:jsmith/somerepo.git', "git_url": u"git@gitlab.com:jsmith/somerepo.git",
'commit_info': { "commit_info": {
'url': u'https://gitlab.com/jsmith/somerepo/commit/fb88379ee45de28a0a4590fddcbd8eff8b36026e', "url": u"https://gitlab.com/jsmith/somerepo/commit/fb88379ee45de28a0a4590fddcbd8eff8b36026e",
'date': u'2015-08-13T19:33:18+00:00', "date": u"2015-08-13T19:33:18+00:00",
'message': u'Fix link\n', "message": u"Fix link\n",
}, },
} }
assertSchema('gitlab_webhook', expected, custom_trigger_payload, git_url='git@gitlab.com:jsmith/somerepo.git') assertSchema(
"gitlab_webhook",
expected,
custom_trigger_payload,
git_url="git@gitlab.com:jsmith/somerepo.git",
)
def test_custom_github(): def test_custom_github():
expected = { expected = {
'commit': u'410f4cdf8ff09b87f245b13845e8497f90b90a4c', "commit": u"410f4cdf8ff09b87f245b13845e8497f90b90a4c",
'ref': u'refs/heads/master', "ref": u"refs/heads/master",
'default_branch': u'master', "default_branch": u"master",
'git_url': u'git@github.com:jsmith/anothertest.git', "git_url": u"git@github.com:jsmith/anothertest.git",
'commit_info': { "commit_info": {
'url': u'https://github.com/jsmith/anothertest/commit/410f4cdf8ff09b87f245b13845e8497f90b90a4c', "url": u"https://github.com/jsmith/anothertest/commit/410f4cdf8ff09b87f245b13845e8497f90b90a4c",
'date': u'2015-09-11T14:26:16-04:00', "date": u"2015-09-11T14:26:16-04:00",
'message': u'Update Dockerfile', "message": u"Update Dockerfile",
'committer': { "committer": {"username": u"jsmith"},
'username': u'jsmith', "author": {"username": u"jsmith"},
},
'author': {
'username': u'jsmith',
},
}, },
} }
assertSchema('github_webhook', expected, custom_trigger_payload, assertSchema(
git_url='git@github.com:jsmith/anothertest.git') "github_webhook",
expected,
custom_trigger_payload,
git_url="git@github.com:jsmith/anothertest.git",
)
def test_custom_bitbucket(): def test_custom_bitbucket():
@ -119,7 +125,12 @@ def test_custom_bitbucket():
}, },
} }
assertSchema('bitbucket_webhook', expected, custom_trigger_payload, git_url='git@bitbucket.org:jsmith/another-repo.git') assertSchema(
"bitbucket_webhook",
expected,
custom_trigger_payload,
git_url="git@bitbucket.org:jsmith/another-repo.git",
)
def test_bitbucket_customer_payload_noauthor(): def test_bitbucket_customer_payload_noauthor():
@ -138,7 +149,7 @@ def test_bitbucket_customer_payload_noauthor():
}, },
} }
assertSchema('bitbucket_customer_example_noauthor', expected, bb_webhook) assertSchema("bitbucket_customer_example_noauthor", expected, bb_webhook)
def test_bitbucket_customer_payload_tag(): def test_bitbucket_customer_payload_tag():
@ -157,20 +168,17 @@ def test_bitbucket_customer_payload_tag():
}, },
} }
assertSchema('bitbucket_customer_example_tag', expected, bb_webhook) assertSchema("bitbucket_customer_example_tag", expected, bb_webhook)
def test_bitbucket_commit(): def test_bitbucket_commit():
ref = 'refs/heads/somebranch' ref = "refs/heads/somebranch"
default_branch = 'somebranch' default_branch = "somebranch"
repository_name = 'foo/bar' repository_name = "foo/bar"
def lookup_author(_): def lookup_author(_):
return { return {
'user': { "user": {"display_name": "cooluser", "avatar": "http://some/avatar/url"}
'display_name': 'cooluser',
'avatar': 'http://some/avatar/url'
}
} }
expected = { expected = {
@ -185,12 +193,20 @@ def test_bitbucket_commit():
"author": { "author": {
"avatar_url": u"http://some/avatar/url", "avatar_url": u"http://some/avatar/url",
"username": u"cooluser", "username": u"cooluser",
} },
} },
} }
assertSchema('bitbucket_commit', expected, bb_commit, ref, default_branch, assertSchema(
repository_name, lookup_author) "bitbucket_commit",
expected,
bb_commit,
ref,
default_branch,
repository_name,
lookup_author,
)
def test_bitbucket_webhook_payload(): def test_bitbucket_webhook_payload():
expected = { expected = {
@ -212,123 +228,117 @@ def test_bitbucket_webhook_payload():
}, },
} }
assertSchema('bitbucket_webhook', expected, bb_webhook) assertSchema("bitbucket_webhook", expected, bb_webhook)
def test_github_webhook_payload_slash_branch(): def test_github_webhook_payload_slash_branch():
expected = { expected = {
'commit': u'410f4cdf8ff09b87f245b13845e8497f90b90a4c', "commit": u"410f4cdf8ff09b87f245b13845e8497f90b90a4c",
'ref': u'refs/heads/slash/branch', "ref": u"refs/heads/slash/branch",
'default_branch': u'master', "default_branch": u"master",
'git_url': u'git@github.com:jsmith/anothertest.git', "git_url": u"git@github.com:jsmith/anothertest.git",
'commit_info': { "commit_info": {
'url': u'https://github.com/jsmith/anothertest/commit/410f4cdf8ff09b87f245b13845e8497f90b90a4c', "url": u"https://github.com/jsmith/anothertest/commit/410f4cdf8ff09b87f245b13845e8497f90b90a4c",
'date': u'2015-09-11T14:26:16-04:00', "date": u"2015-09-11T14:26:16-04:00",
'message': u'Update Dockerfile', "message": u"Update Dockerfile",
'committer': { "committer": {"username": u"jsmith"},
'username': u'jsmith', "author": {"username": u"jsmith"},
},
'author': {
'username': u'jsmith',
},
}, },
} }
assertSchema('github_webhook_slash_branch', expected, gh_webhook) assertSchema("github_webhook_slash_branch", expected, gh_webhook)
def test_github_webhook_payload(): def test_github_webhook_payload():
expected = { expected = {
'commit': u'410f4cdf8ff09b87f245b13845e8497f90b90a4c', "commit": u"410f4cdf8ff09b87f245b13845e8497f90b90a4c",
'ref': u'refs/heads/master', "ref": u"refs/heads/master",
'default_branch': u'master', "default_branch": u"master",
'git_url': u'git@github.com:jsmith/anothertest.git', "git_url": u"git@github.com:jsmith/anothertest.git",
'commit_info': { "commit_info": {
'url': u'https://github.com/jsmith/anothertest/commit/410f4cdf8ff09b87f245b13845e8497f90b90a4c', "url": u"https://github.com/jsmith/anothertest/commit/410f4cdf8ff09b87f245b13845e8497f90b90a4c",
'date': u'2015-09-11T14:26:16-04:00', "date": u"2015-09-11T14:26:16-04:00",
'message': u'Update Dockerfile', "message": u"Update Dockerfile",
'committer': { "committer": {"username": u"jsmith"},
'username': u'jsmith', "author": {"username": u"jsmith"},
},
'author': {
'username': u'jsmith',
},
}, },
} }
assertSchema('github_webhook', expected, gh_webhook) assertSchema("github_webhook", expected, gh_webhook)
def test_github_webhook_payload_with_lookup(): def test_github_webhook_payload_with_lookup():
expected = { expected = {
'commit': u'410f4cdf8ff09b87f245b13845e8497f90b90a4c', "commit": u"410f4cdf8ff09b87f245b13845e8497f90b90a4c",
'ref': u'refs/heads/master', "ref": u"refs/heads/master",
'default_branch': u'master', "default_branch": u"master",
'git_url': u'git@github.com:jsmith/anothertest.git', "git_url": u"git@github.com:jsmith/anothertest.git",
'commit_info': { "commit_info": {
'url': u'https://github.com/jsmith/anothertest/commit/410f4cdf8ff09b87f245b13845e8497f90b90a4c', "url": u"https://github.com/jsmith/anothertest/commit/410f4cdf8ff09b87f245b13845e8497f90b90a4c",
'date': u'2015-09-11T14:26:16-04:00', "date": u"2015-09-11T14:26:16-04:00",
'message': u'Update Dockerfile', "message": u"Update Dockerfile",
'committer': { "committer": {
'username': u'jsmith', "username": u"jsmith",
'url': u'http://github.com/jsmith', "url": u"http://github.com/jsmith",
'avatar_url': u'http://some/avatar/url', "avatar_url": u"http://some/avatar/url",
}, },
'author': { "author": {
'username': u'jsmith', "username": u"jsmith",
'url': u'http://github.com/jsmith', "url": u"http://github.com/jsmith",
'avatar_url': u'http://some/avatar/url', "avatar_url": u"http://some/avatar/url",
}, },
}, },
} }
def lookup_user(_): def lookup_user(_):
return { return {
'html_url': 'http://github.com/jsmith', "html_url": "http://github.com/jsmith",
'avatar_url': 'http://some/avatar/url' "avatar_url": "http://some/avatar/url",
} }
assertSchema('github_webhook', expected, gh_webhook, lookup_user=lookup_user) assertSchema("github_webhook", expected, gh_webhook, lookup_user=lookup_user)
def test_github_webhook_payload_missing_fields_with_lookup(): def test_github_webhook_payload_missing_fields_with_lookup():
expected = { expected = {
'commit': u'410f4cdf8ff09b87f245b13845e8497f90b90a4c', "commit": u"410f4cdf8ff09b87f245b13845e8497f90b90a4c",
'ref': u'refs/heads/master', "ref": u"refs/heads/master",
'default_branch': u'master', "default_branch": u"master",
'git_url': u'git@github.com:jsmith/anothertest.git', "git_url": u"git@github.com:jsmith/anothertest.git",
'commit_info': { "commit_info": {
'url': u'https://github.com/jsmith/anothertest/commit/410f4cdf8ff09b87f245b13845e8497f90b90a4c', "url": u"https://github.com/jsmith/anothertest/commit/410f4cdf8ff09b87f245b13845e8497f90b90a4c",
'date': u'2015-09-11T14:26:16-04:00', "date": u"2015-09-11T14:26:16-04:00",
'message': u'Update Dockerfile' "message": u"Update Dockerfile",
}, },
} }
def lookup_user(username): def lookup_user(username):
if not username: if not username:
raise Exception('Fail!') raise Exception("Fail!")
return { return {
'html_url': 'http://github.com/jsmith', "html_url": "http://github.com/jsmith",
'avatar_url': 'http://some/avatar/url' "avatar_url": "http://some/avatar/url",
} }
assertSchema('github_webhook_missing', expected, gh_webhook, lookup_user=lookup_user) assertSchema(
"github_webhook_missing", expected, gh_webhook, lookup_user=lookup_user
)
def test_gitlab_webhook_payload(): def test_gitlab_webhook_payload():
expected = { expected = {
'commit': u'fb88379ee45de28a0a4590fddcbd8eff8b36026e', "commit": u"fb88379ee45de28a0a4590fddcbd8eff8b36026e",
'ref': u'refs/heads/master', "ref": u"refs/heads/master",
'git_url': u'git@gitlab.com:jsmith/somerepo.git', "git_url": u"git@gitlab.com:jsmith/somerepo.git",
'commit_info': { "commit_info": {
'url': u'https://gitlab.com/jsmith/somerepo/commit/fb88379ee45de28a0a4590fddcbd8eff8b36026e', "url": u"https://gitlab.com/jsmith/somerepo/commit/fb88379ee45de28a0a4590fddcbd8eff8b36026e",
'date': u'2015-08-13T19:33:18+00:00', "date": u"2015-08-13T19:33:18+00:00",
'message': u'Fix link\n', "message": u"Fix link\n",
}, },
} }
assertSchema('gitlab_webhook', expected, gl_webhook) assertSchema("gitlab_webhook", expected, gl_webhook)
def test_github_webhook_payload_known_issue(): def test_github_webhook_payload_known_issue():
@ -344,82 +354,84 @@ def test_github_webhook_payload_known_issue():
}, },
} }
assertSchema('github_webhook_noname', expected, gh_webhook) assertSchema("github_webhook_noname", expected, gh_webhook)
def test_github_webhook_payload_missing_fields(): def test_github_webhook_payload_missing_fields():
expected = { expected = {
'commit': u'410f4cdf8ff09b87f245b13845e8497f90b90a4c', "commit": u"410f4cdf8ff09b87f245b13845e8497f90b90a4c",
'ref': u'refs/heads/master', "ref": u"refs/heads/master",
'default_branch': u'master', "default_branch": u"master",
'git_url': u'git@github.com:jsmith/anothertest.git', "git_url": u"git@github.com:jsmith/anothertest.git",
'commit_info': { "commit_info": {
'url': u'https://github.com/jsmith/anothertest/commit/410f4cdf8ff09b87f245b13845e8497f90b90a4c', "url": u"https://github.com/jsmith/anothertest/commit/410f4cdf8ff09b87f245b13845e8497f90b90a4c",
'date': u'2015-09-11T14:26:16-04:00', "date": u"2015-09-11T14:26:16-04:00",
'message': u'Update Dockerfile' "message": u"Update Dockerfile",
}, },
} }
assertSchema('github_webhook_missing', expected, gh_webhook) assertSchema("github_webhook_missing", expected, gh_webhook)
def test_gitlab_webhook_nocommit_payload(): def test_gitlab_webhook_nocommit_payload():
assertSkipped('gitlab_webhook_nocommit', gl_webhook) assertSkipped("gitlab_webhook_nocommit", gl_webhook)
def test_gitlab_webhook_multiple_commits(): def test_gitlab_webhook_multiple_commits():
expected = { expected = {
'commit': u'9a052a0b2fbe01d4a1a88638dd9fe31c1c56ef53', "commit": u"9a052a0b2fbe01d4a1a88638dd9fe31c1c56ef53",
'ref': u'refs/heads/master', "ref": u"refs/heads/master",
'git_url': u'git@gitlab.com:jsmith/some-test-project.git', "git_url": u"git@gitlab.com:jsmith/some-test-project.git",
'commit_info': { "commit_info": {
'url': u'https://gitlab.com/jsmith/some-test-project/commit/9a052a0b2fbe01d4a1a88638dd9fe31c1c56ef53', "url": u"https://gitlab.com/jsmith/some-test-project/commit/9a052a0b2fbe01d4a1a88638dd9fe31c1c56ef53",
'date': u'2016-09-29T15:02:41+00:00', "date": u"2016-09-29T15:02:41+00:00",
'message': u"Merge branch 'foobar' into 'master'\r\n\r\nAdd changelog\r\n\r\nSome merge thing\r\n\r\nSee merge request !1", "message": u"Merge branch 'foobar' into 'master'\r\n\r\nAdd changelog\r\n\r\nSome merge thing\r\n\r\nSee merge request !1",
'author': { "author": {
'username': 'jsmith', "username": "jsmith",
'url': 'http://gitlab.com/jsmith', "url": "http://gitlab.com/jsmith",
'avatar_url': 'http://some/avatar/url' "avatar_url": "http://some/avatar/url",
}, },
}, },
} }
def lookup_user(_): def lookup_user(_):
return { return {
'username': 'jsmith', "username": "jsmith",
'html_url': 'http://gitlab.com/jsmith', "html_url": "http://gitlab.com/jsmith",
'avatar_url': 'http://some/avatar/url', "avatar_url": "http://some/avatar/url",
} }
assertSchema('gitlab_webhook_multicommit', expected, gl_webhook, lookup_user=lookup_user) assertSchema(
"gitlab_webhook_multicommit", expected, gl_webhook, lookup_user=lookup_user
)
def test_gitlab_webhook_for_tag(): def test_gitlab_webhook_for_tag():
expected = { expected = {
'commit': u'82b3d5ae55f7080f1e6022629cdb57bfae7cccc7', "commit": u"82b3d5ae55f7080f1e6022629cdb57bfae7cccc7",
'commit_info': { "commit_info": {
'author': { "author": {
'avatar_url': 'http://some/avatar/url', "avatar_url": "http://some/avatar/url",
'url': 'http://gitlab.com/jsmith', "url": "http://gitlab.com/jsmith",
'username': 'jsmith' "username": "jsmith",
}, },
'date': '2015-08-13T19:33:18+00:00', "date": "2015-08-13T19:33:18+00:00",
'message': 'Fix link\n', "message": "Fix link\n",
'url': 'https://some/url', "url": "https://some/url",
}, },
'git_url': u'git@example.com:jsmith/example.git', "git_url": u"git@example.com:jsmith/example.git",
'ref': u'refs/tags/v1.0.0', "ref": u"refs/tags/v1.0.0",
} }
def lookup_user(_): def lookup_user(_):
return { return {
'username': 'jsmith', "username": "jsmith",
'html_url': 'http://gitlab.com/jsmith', "html_url": "http://gitlab.com/jsmith",
'avatar_url': 'http://some/avatar/url', "avatar_url": "http://some/avatar/url",
} }
def lookup_commit(repo_id, commit_sha): def lookup_commit(repo_id, commit_sha):
if commit_sha == '82b3d5ae55f7080f1e6022629cdb57bfae7cccc7': if commit_sha == "82b3d5ae55f7080f1e6022629cdb57bfae7cccc7":
return { return {
"id": "82b3d5ae55f7080f1e6022629cdb57bfae7cccc7", "id": "82b3d5ae55f7080f1e6022629cdb57bfae7cccc7",
"message": "Fix link\n", "message": "Fix link\n",
@ -431,142 +443,146 @@ def test_gitlab_webhook_for_tag():
return None return None
assertSchema('gitlab_webhook_tag', expected, gl_webhook, lookup_user=lookup_user, assertSchema(
lookup_commit=lookup_commit) "gitlab_webhook_tag",
expected,
gl_webhook,
lookup_user=lookup_user,
lookup_commit=lookup_commit,
)
def test_gitlab_webhook_for_tag_nocommit(): def test_gitlab_webhook_for_tag_nocommit():
assertSkipped('gitlab_webhook_tag', gl_webhook) assertSkipped("gitlab_webhook_tag", gl_webhook)
def test_gitlab_webhook_for_tag_commit_sha_null(): def test_gitlab_webhook_for_tag_commit_sha_null():
assertSkipped('gitlab_webhook_tag_commit_sha_null', gl_webhook) assertSkipped("gitlab_webhook_tag_commit_sha_null", gl_webhook)
def test_gitlab_webhook_for_tag_known_issue(): def test_gitlab_webhook_for_tag_known_issue():
expected = { expected = {
'commit': u'770830e7ca132856991e6db4f7fc0f4dbe20bd5f', "commit": u"770830e7ca132856991e6db4f7fc0f4dbe20bd5f",
'ref': u'refs/tags/thirdtag', "ref": u"refs/tags/thirdtag",
'git_url': u'git@gitlab.com:someuser/some-test-project.git', "git_url": u"git@gitlab.com:someuser/some-test-project.git",
'commit_info': { "commit_info": {
'url': u'https://gitlab.com/someuser/some-test-project/commit/770830e7ca132856991e6db4f7fc0f4dbe20bd5f', "url": u"https://gitlab.com/someuser/some-test-project/commit/770830e7ca132856991e6db4f7fc0f4dbe20bd5f",
'date': u'2019-10-17T18:07:48Z', "date": u"2019-10-17T18:07:48Z",
'message': u'Update Dockerfile', "message": u"Update Dockerfile",
'author': { "author": {
'username': 'someuser', "username": "someuser",
'url': 'http://gitlab.com/someuser', "url": "http://gitlab.com/someuser",
'avatar_url': 'http://some/avatar/url', "avatar_url": "http://some/avatar/url",
}, },
}, },
} }
def lookup_user(_): def lookup_user(_):
return { return {
'username': 'someuser', "username": "someuser",
'html_url': 'http://gitlab.com/someuser', "html_url": "http://gitlab.com/someuser",
'avatar_url': 'http://some/avatar/url', "avatar_url": "http://some/avatar/url",
} }
assertSchema('gitlab_webhook_tag_commit_issue', expected, gl_webhook, lookup_user=lookup_user) assertSchema(
"gitlab_webhook_tag_commit_issue", expected, gl_webhook, lookup_user=lookup_user
)
def test_gitlab_webhook_payload_known_issue(): def test_gitlab_webhook_payload_known_issue():
expected = { expected = {
'commit': u'770830e7ca132856991e6db4f7fc0f4dbe20bd5f', "commit": u"770830e7ca132856991e6db4f7fc0f4dbe20bd5f",
'ref': u'refs/tags/fourthtag', "ref": u"refs/tags/fourthtag",
'git_url': u'git@gitlab.com:someuser/some-test-project.git', "git_url": u"git@gitlab.com:someuser/some-test-project.git",
'commit_info': { "commit_info": {
'url': u'https://gitlab.com/someuser/some-test-project/commit/770830e7ca132856991e6db4f7fc0f4dbe20bd5f', "url": u"https://gitlab.com/someuser/some-test-project/commit/770830e7ca132856991e6db4f7fc0f4dbe20bd5f",
'date': u'2019-10-17T18:07:48Z', "date": u"2019-10-17T18:07:48Z",
'message': u'Update Dockerfile', "message": u"Update Dockerfile",
}, },
} }
def lookup_commit(repo_id, commit_sha): def lookup_commit(repo_id, commit_sha):
if commit_sha == '770830e7ca132856991e6db4f7fc0f4dbe20bd5f': if commit_sha == "770830e7ca132856991e6db4f7fc0f4dbe20bd5f":
return { return {
"added": [], "added": [],
"author": { "author": {"name": "Some User", "email": "someuser@somedomain.com"},
"name": "Some User",
"email": "someuser@somedomain.com"
},
"url": "https://gitlab.com/someuser/some-test-project/commit/770830e7ca132856991e6db4f7fc0f4dbe20bd5f", "url": "https://gitlab.com/someuser/some-test-project/commit/770830e7ca132856991e6db4f7fc0f4dbe20bd5f",
"message": "Update Dockerfile", "message": "Update Dockerfile",
"removed": [], "removed": [],
"modified": [ "modified": ["Dockerfile"],
"Dockerfile" "id": "770830e7ca132856991e6db4f7fc0f4dbe20bd5f",
],
"id": "770830e7ca132856991e6db4f7fc0f4dbe20bd5f"
} }
return None return None
assertSchema('gitlab_webhook_known_issue', expected, gl_webhook, lookup_commit=lookup_commit) assertSchema(
"gitlab_webhook_known_issue", expected, gl_webhook, lookup_commit=lookup_commit
)
def test_gitlab_webhook_for_other(): def test_gitlab_webhook_for_other():
assertSkipped('gitlab_webhook_other', gl_webhook) assertSkipped("gitlab_webhook_other", gl_webhook)
def test_gitlab_webhook_payload_with_lookup(): def test_gitlab_webhook_payload_with_lookup():
expected = { expected = {
'commit': u'fb88379ee45de28a0a4590fddcbd8eff8b36026e', "commit": u"fb88379ee45de28a0a4590fddcbd8eff8b36026e",
'ref': u'refs/heads/master', "ref": u"refs/heads/master",
'git_url': u'git@gitlab.com:jsmith/somerepo.git', "git_url": u"git@gitlab.com:jsmith/somerepo.git",
'commit_info': { "commit_info": {
'url': u'https://gitlab.com/jsmith/somerepo/commit/fb88379ee45de28a0a4590fddcbd8eff8b36026e', "url": u"https://gitlab.com/jsmith/somerepo/commit/fb88379ee45de28a0a4590fddcbd8eff8b36026e",
'date': u'2015-08-13T19:33:18+00:00', "date": u"2015-08-13T19:33:18+00:00",
'message': u'Fix link\n', "message": u"Fix link\n",
'author': { "author": {
'username': 'jsmith', "username": "jsmith",
'url': 'http://gitlab.com/jsmith', "url": "http://gitlab.com/jsmith",
'avatar_url': 'http://some/avatar/url', "avatar_url": "http://some/avatar/url",
}, },
}, },
} }
def lookup_user(_): def lookup_user(_):
return { return {
'username': 'jsmith', "username": "jsmith",
'html_url': 'http://gitlab.com/jsmith', "html_url": "http://gitlab.com/jsmith",
'avatar_url': 'http://some/avatar/url', "avatar_url": "http://some/avatar/url",
} }
assertSchema('gitlab_webhook', expected, gl_webhook, lookup_user=lookup_user) assertSchema("gitlab_webhook", expected, gl_webhook, lookup_user=lookup_user)
def test_github_webhook_payload_deleted_commit(): def test_github_webhook_payload_deleted_commit():
expected = { expected = {
'commit': u'456806b662cb903a0febbaed8344f3ed42f27bab', "commit": u"456806b662cb903a0febbaed8344f3ed42f27bab",
'commit_info': { "commit_info": {
'author': { "author": {"username": u"jsmith"},
'username': u'jsmith' "committer": {"username": u"jsmith"},
"date": u"2015-12-08T18:07:03-05:00",
"message": (
u"Merge pull request #1044 from jsmith/errerror\n\n"
+ "Assign the exception to a variable to log it"
),
"url": u"https://github.com/jsmith/somerepo/commit/456806b662cb903a0febbaed8344f3ed42f27bab",
}, },
'committer': { "git_url": u"git@github.com:jsmith/somerepo.git",
'username': u'jsmith' "ref": u"refs/heads/master",
}, "default_branch": u"master",
'date': u'2015-12-08T18:07:03-05:00',
'message': (u'Merge pull request #1044 from jsmith/errerror\n\n' +
'Assign the exception to a variable to log it'),
'url': u'https://github.com/jsmith/somerepo/commit/456806b662cb903a0febbaed8344f3ed42f27bab'
},
'git_url': u'git@github.com:jsmith/somerepo.git',
'ref': u'refs/heads/master',
'default_branch': u'master',
} }
def lookup_user(_): def lookup_user(_):
return None return None
assertSchema('github_webhook_deletedcommit', expected, gh_webhook, lookup_user=lookup_user) assertSchema(
"github_webhook_deletedcommit", expected, gh_webhook, lookup_user=lookup_user
)
def test_github_webhook_known_issue(): def test_github_webhook_known_issue():
def lookup_user(_): def lookup_user(_):
return None return None
assertSkipped('github_webhook_knownissue', gh_webhook, lookup_user=lookup_user) assertSkipped("github_webhook_knownissue", gh_webhook, lookup_user=lookup_user)
def test_bitbucket_webhook_known_issue(): def test_bitbucket_webhook_known_issue():
assertSkipped('bitbucket_knownissue', bb_webhook) assertSkipped("bitbucket_knownissue", bb_webhook)
@ -4,22 +4,43 @@ import pytest
from buildtrigger.triggerutil import matches_ref from buildtrigger.triggerutil import matches_ref
@pytest.mark.parametrize('ref, filt, matches', [
('ref/heads/master', '.+', True),
('ref/heads/master', 'heads/.+', True),
('ref/heads/master', 'heads/master', True),
('ref/heads/slash/branch', 'heads/slash/branch', True),
('ref/heads/slash/branch', 'heads/.+', True),
('ref/heads/foobar', 'heads/master', False), @pytest.mark.parametrize(
('ref/heads/master', 'tags/master', False), "ref, filt, matches",
[
('ref/heads/master', '(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)', True), ("ref/heads/master", ".+", True),
('ref/heads/alpha', '(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)', True), ("ref/heads/master", "heads/.+", True),
('ref/heads/beta', '(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)', True), ("ref/heads/master", "heads/master", True),
('ref/heads/gamma', '(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)', True), ("ref/heads/slash/branch", "heads/slash/branch", True),
("ref/heads/slash/branch", "heads/.+", True),
('ref/heads/delta', '(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)', False), ("ref/heads/foobar", "heads/master", False),
]) ("ref/heads/master", "tags/master", False),
(
"ref/heads/master",
"(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)",
True,
),
(
"ref/heads/alpha",
"(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)",
True,
),
(
"ref/heads/beta",
"(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)",
True,
),
(
"ref/heads/gamma",
"(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)",
True,
),
(
"ref/heads/delta",
"(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)",
False,
),
],
)
def test_matches_ref(ref, filt, matches): def test_matches_ref(ref, filt, matches):
assert matches_ref(ref, re.compile(filt)) == matches assert matches_ref(ref, re.compile(filt)) == matches
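A hypothetical matches_ref consistent with every case above: strip the leading path segment and require the filter to match the remainder in full. The real triggerutil.matches_ref may be implemented differently.

import re

def matches_ref_sketch(ref, regex):
    # "ref/heads/master" -> match "heads/master" against the compiled filter.
    match_string = ref.split("/", 1)[1]
    m = regex.match(match_string)
    return m is not None and m.group(0) == match_string

assert matches_ref_sketch("ref/heads/slash/branch", re.compile("heads/.+"))
assert not matches_ref_sketch("ref/heads/master", re.compile("tags/master"))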

View file

@ -3,74 +3,92 @@ import io
import logging import logging
import re import re
class TriggerException(Exception): class TriggerException(Exception):
pass pass
class TriggerAuthException(TriggerException): class TriggerAuthException(TriggerException):
pass pass
class InvalidPayloadException(TriggerException): class InvalidPayloadException(TriggerException):
pass pass
class BuildArchiveException(TriggerException): class BuildArchiveException(TriggerException):
pass pass
class InvalidServiceException(TriggerException): class InvalidServiceException(TriggerException):
pass pass
class TriggerActivationException(TriggerException): class TriggerActivationException(TriggerException):
pass pass
class TriggerDeactivationException(TriggerException): class TriggerDeactivationException(TriggerException):
pass pass
class TriggerStartException(TriggerException): class TriggerStartException(TriggerException):
pass pass
class ValidationRequestException(TriggerException): class ValidationRequestException(TriggerException):
pass pass
class SkipRequestException(TriggerException): class SkipRequestException(TriggerException):
pass pass
class EmptyRepositoryException(TriggerException): class EmptyRepositoryException(TriggerException):
pass pass
class RepositoryReadException(TriggerException): class RepositoryReadException(TriggerException):
pass pass
class TriggerProviderException(TriggerException): class TriggerProviderException(TriggerException):
pass pass
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def determine_build_ref(run_parameters, get_branch_sha, get_tag_sha, default_branch): def determine_build_ref(run_parameters, get_branch_sha, get_tag_sha, default_branch):
run_parameters = run_parameters or {} run_parameters = run_parameters or {}
kind = '' kind = ""
value = '' value = ""
if 'refs' in run_parameters and run_parameters['refs']: if "refs" in run_parameters and run_parameters["refs"]:
kind = run_parameters['refs']['kind'] kind = run_parameters["refs"]["kind"]
value = run_parameters['refs']['name'] value = run_parameters["refs"]["name"]
elif 'branch_name' in run_parameters: elif "branch_name" in run_parameters:
kind = 'branch' kind = "branch"
value = run_parameters['branch_name'] value = run_parameters["branch_name"]
kind = kind or 'branch' kind = kind or "branch"
value = value or default_branch or 'master' value = value or default_branch or "master"
ref = 'refs/tags/' + value if kind == 'tag' else 'refs/heads/' + value ref = "refs/tags/" + value if kind == "tag" else "refs/heads/" + value
commit_sha = get_tag_sha(value) if kind == 'tag' else get_branch_sha(value) commit_sha = get_tag_sha(value) if kind == "tag" else get_branch_sha(value)
return (commit_sha, ref) return (commit_sha, ref)
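A quick usage sketch of determine_build_ref, with hypothetical functions standing in for the SHA lookups a trigger would normally supply:

def get_branch_sha(name):  # hypothetical lookup
    return "branch-sha"

def get_tag_sha(name):  # hypothetical lookup
    return "tag-sha"

# Explicit tag ref:
determine_build_ref({"refs": {"kind": "tag", "name": "v1.0"}},
                    get_branch_sha, get_tag_sha, "master")
# -> ("tag-sha", "refs/tags/v1.0")

# No parameters: falls back to the default branch (or "master"):
determine_build_ref(None, get_branch_sha, get_tag_sha, "develop")
# -> ("branch-sha", "refs/heads/develop")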
def find_matching_branches(config, branches): def find_matching_branches(config, branches):
if 'branchtag_regex' in config: if "branchtag_regex" in config:
try: try:
regex = re.compile(config['branchtag_regex']) regex = re.compile(config["branchtag_regex"])
return [branch for branch in branches
        if matches_ref('refs/heads/' + branch, regex)]

            return [
                branch
                for branch in branches
                if matches_ref("refs/heads/" + branch, regex)
            ]
except: except:
pass pass
@ -78,9 +96,9 @@ def find_matching_branches(config, branches):
def should_skip_commit(metadata): def should_skip_commit(metadata):
if 'commit_info' in metadata: if "commit_info" in metadata:
message = metadata['commit_info']['message'] message = metadata["commit_info"]["message"]
return '[skip build]' in message or '[build skip]' in message return "[skip build]" in message or "[build skip]" in message
return False return False
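In other words, either marker anywhere in the commit message suppresses the build:

assert should_skip_commit({"commit_info": {"message": "fix typo [skip build]"}})
assert should_skip_commit({"commit_info": {"message": "[build skip] WIP"}})
assert not should_skip_commit({"commit_info": {"message": "fix typo"}})
assert not should_skip_commit({})  # no commit_info at all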
@ -88,27 +106,27 @@ def raise_if_skipped_build(prepared_build, config):
""" Raises a SkipRequestException if the given build should be skipped. """ """ Raises a SkipRequestException if the given build should be skipped. """
# Check to ensure we have metadata. # Check to ensure we have metadata.
if not prepared_build.metadata: if not prepared_build.metadata:
logger.debug('Skipping request due to missing metadata for prepared build') logger.debug("Skipping request due to missing metadata for prepared build")
raise SkipRequestException() raise SkipRequestException()
# Check the branchtag regex. # Check the branchtag regex.
if 'branchtag_regex' in config: if "branchtag_regex" in config:
try: try:
regex = re.compile(config['branchtag_regex']) regex = re.compile(config["branchtag_regex"])
except: except:
regex = re.compile('.*') regex = re.compile(".*")
if not matches_ref(prepared_build.metadata.get('ref'), regex): if not matches_ref(prepared_build.metadata.get("ref"), regex):
raise SkipRequestException() raise SkipRequestException()
# Check the commit message. # Check the commit message.
if should_skip_commit(prepared_build.metadata): if should_skip_commit(prepared_build.metadata):
logger.debug('Skipping request due to commit message request') logger.debug("Skipping request due to commit message request")
raise SkipRequestException() raise SkipRequestException()
def matches_ref(ref, regex): def matches_ref(ref, regex):
match_string = ref.split('/', 1)[1] match_string = ref.split("/", 1)[1]
if not regex: if not regex:
return False return False

View file

@ -1,5 +1,6 @@
import sys import sys
import os import os
sys.path.append(os.path.join(os.path.dirname(__file__), "../")) sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging import logging
@ -10,18 +11,24 @@ from util.workers import get_worker_count
logconfig = logfile_path(debug=True) logconfig = logfile_path(debug=True)
bind = '0.0.0.0:5000' bind = "0.0.0.0:5000"
workers = get_worker_count('local', 2, minimum=2, maximum=8) workers = get_worker_count("local", 2, minimum=2, maximum=8)
worker_class = 'gevent' worker_class = "gevent"
daemon = False daemon = False
pythonpath = '.' pythonpath = "."
preload_app = True preload_app = True
def post_fork(server, worker): def post_fork(server, worker):
# Reset the Random library to ensure it won't raise the "PID check failed." error after # Reset the Random library to ensure it won't raise the "PID check failed." error after
# gunicorn forks. # gunicorn forks.
Random.atfork() Random.atfork()
def when_ready(server): def when_ready(server):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.debug('Starting local gunicorn with %s workers and %s worker class', workers, worker_class) logger.debug(
"Starting local gunicorn with %s workers and %s worker class",
workers,
worker_class,
)
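get_worker_count itself is not part of this diff; a plausible sketch of the (kind, multiplier, minimum, maximum) contract, offered only as an assumption about its shape:

import multiprocessing

def get_worker_count_sketch(worker_kind, multiplier, minimum, maximum):
    # Scale the worker count with available CPUs, clamped to [minimum, maximum].
    count = multiprocessing.cpu_count() * multiplier
    return max(minimum, min(count, maximum))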

View file

@ -1,5 +1,6 @@
import sys import sys
import os import os
sys.path.append(os.path.join(os.path.dirname(__file__), "../")) sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging import logging
@ -10,10 +11,10 @@ from util.workers import get_worker_count
logconfig = logfile_path(debug=False) logconfig = logfile_path(debug=False)
bind = 'unix:/tmp/gunicorn_registry.sock' bind = "unix:/tmp/gunicorn_registry.sock"
workers = get_worker_count('registry', 4, minimum=8, maximum=64) workers = get_worker_count("registry", 4, minimum=8, maximum=64)
worker_class = 'gevent' worker_class = "gevent"
pythonpath = '.' pythonpath = "."
preload_app = True preload_app = True
@ -22,7 +23,11 @@ def post_fork(server, worker):
# gunicorn forks. # gunicorn forks.
Random.atfork() Random.atfork()
def when_ready(server): def when_ready(server):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.debug('Starting registry gunicorn with %s workers and %s worker class', workers, logger.debug(
worker_class) "Starting registry gunicorn with %s workers and %s worker class",
workers,
worker_class,
)

View file

@ -1,5 +1,6 @@
import sys import sys
import os import os
sys.path.append(os.path.join(os.path.dirname(__file__), "../")) sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging import logging
@ -10,10 +11,10 @@ from util.workers import get_worker_count
logconfig = logfile_path(debug=False) logconfig = logfile_path(debug=False)
bind = 'unix:/tmp/gunicorn_secscan.sock' bind = "unix:/tmp/gunicorn_secscan.sock"
workers = get_worker_count('secscan', 2, minimum=2, maximum=4) workers = get_worker_count("secscan", 2, minimum=2, maximum=4)
worker_class = 'gevent' worker_class = "gevent"
pythonpath = '.' pythonpath = "."
preload_app = True preload_app = True
@ -22,7 +23,11 @@ def post_fork(server, worker):
# gunicorn forks. # gunicorn forks.
Random.atfork() Random.atfork()
def when_ready(server): def when_ready(server):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.debug('Starting secscan gunicorn with %s workers and %s worker class', workers, logger.debug(
worker_class) "Starting secscan gunicorn with %s workers and %s worker class",
workers,
worker_class,
)

View file

@ -1,5 +1,6 @@
import sys import sys
import os import os
sys.path.append(os.path.join(os.path.dirname(__file__), "../")) sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging import logging
@ -10,9 +11,9 @@ from util.workers import get_worker_count
logconfig = logfile_path(debug=False) logconfig = logfile_path(debug=False)
bind = 'unix:/tmp/gunicorn_verbs.sock' bind = "unix:/tmp/gunicorn_verbs.sock"
workers = get_worker_count('verbs', 2, minimum=2, maximum=32) workers = get_worker_count("verbs", 2, minimum=2, maximum=32)
pythonpath = '.' pythonpath = "."
preload_app = True preload_app = True
timeout = 2000 # Because sync workers timeout = 2000 # Because sync workers
@ -22,6 +23,9 @@ def post_fork(server, worker):
# gunicorn forks. # gunicorn forks.
Random.atfork() Random.atfork()
def when_ready(server): def when_ready(server):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.debug('Starting verbs gunicorn with %s workers and sync worker class', workers) logger.debug(
"Starting verbs gunicorn with %s workers and sync worker class", workers
)

View file

@ -1,5 +1,6 @@
import sys import sys
import os import os
sys.path.append(os.path.join(os.path.dirname(__file__), "../")) sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging import logging
@ -11,18 +12,23 @@ from util.workers import get_worker_count
logconfig = logfile_path(debug=False) logconfig = logfile_path(debug=False)
bind = 'unix:/tmp/gunicorn_web.sock' bind = "unix:/tmp/gunicorn_web.sock"
workers = get_worker_count('web', 2, minimum=2, maximum=32) workers = get_worker_count("web", 2, minimum=2, maximum=32)
worker_class = 'gevent' worker_class = "gevent"
pythonpath = '.' pythonpath = "."
preload_app = True preload_app = True
def post_fork(server, worker): def post_fork(server, worker):
# Reset the Random library to ensure it won't raise the "PID check failed." error after # Reset the Random library to ensure it won't raise the "PID check failed." error after
# gunicorn forks. # gunicorn forks.
Random.atfork() Random.atfork()
def when_ready(server): def when_ready(server):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.debug('Starting web gunicorn with %s workers and %s worker class', workers, logger.debug(
worker_class) "Starting web gunicorn with %s workers and %s worker class",
workers,
worker_class,
)

View file

@ -7,59 +7,60 @@ import jinja2
QUAYPATH = os.getenv("QUAYPATH", ".") QUAYPATH = os.getenv("QUAYPATH", ".")
QUAYDIR = os.getenv("QUAYDIR", "/") QUAYDIR = os.getenv("QUAYDIR", "/")
QUAYCONF_DIR = os.getenv("QUAYCONF", os.path.join(QUAYDIR, QUAYPATH, "conf")) QUAYCONF_DIR = os.getenv("QUAYCONF", os.path.join(QUAYDIR, QUAYPATH, "conf"))
STATIC_DIR = os.path.join(QUAYDIR, 'static') STATIC_DIR = os.path.join(QUAYDIR, "static")
SSL_PROTOCOL_DEFAULTS = ['TLSv1', 'TLSv1.1', 'TLSv1.2'] SSL_PROTOCOL_DEFAULTS = ["TLSv1", "TLSv1.1", "TLSv1.2"]
SSL_CIPHER_DEFAULTS = [ SSL_CIPHER_DEFAULTS = [
'ECDHE-RSA-AES128-GCM-SHA256', "ECDHE-RSA-AES128-GCM-SHA256",
'ECDHE-ECDSA-AES128-GCM-SHA256', "ECDHE-ECDSA-AES128-GCM-SHA256",
'ECDHE-RSA-AES256-GCM-SHA384', "ECDHE-RSA-AES256-GCM-SHA384",
'ECDHE-ECDSA-AES256-GCM-SHA384', "ECDHE-ECDSA-AES256-GCM-SHA384",
'DHE-RSA-AES128-GCM-SHA256', "DHE-RSA-AES128-GCM-SHA256",
'DHE-DSS-AES128-GCM-SHA256', "DHE-DSS-AES128-GCM-SHA256",
'kEDH+AESGCM', "kEDH+AESGCM",
'ECDHE-RSA-AES128-SHA256', "ECDHE-RSA-AES128-SHA256",
'ECDHE-ECDSA-AES128-SHA256', "ECDHE-ECDSA-AES128-SHA256",
'ECDHE-RSA-AES128-SHA', "ECDHE-RSA-AES128-SHA",
'ECDHE-ECDSA-AES128-SHA', "ECDHE-ECDSA-AES128-SHA",
'ECDHE-RSA-AES256-SHA384', "ECDHE-RSA-AES256-SHA384",
'ECDHE-ECDSA-AES256-SHA384', "ECDHE-ECDSA-AES256-SHA384",
'ECDHE-RSA-AES256-SHA', "ECDHE-RSA-AES256-SHA",
'ECDHE-ECDSA-AES256-SHA', "ECDHE-ECDSA-AES256-SHA",
'DHE-RSA-AES128-SHA256', "DHE-RSA-AES128-SHA256",
'DHE-RSA-AES128-SHA', "DHE-RSA-AES128-SHA",
'DHE-DSS-AES128-SHA256', "DHE-DSS-AES128-SHA256",
'DHE-RSA-AES256-SHA256', "DHE-RSA-AES256-SHA256",
'DHE-DSS-AES256-SHA', "DHE-DSS-AES256-SHA",
'DHE-RSA-AES256-SHA', "DHE-RSA-AES256-SHA",
'AES128-GCM-SHA256', "AES128-GCM-SHA256",
'AES256-GCM-SHA384', "AES256-GCM-SHA384",
'AES128-SHA256', "AES128-SHA256",
'AES256-SHA256', "AES256-SHA256",
'AES128-SHA', "AES128-SHA",
'AES256-SHA', "AES256-SHA",
'AES', "AES",
'CAMELLIA', "CAMELLIA",
'!3DES', "!3DES",
'!aNULL', "!aNULL",
'!eNULL', "!eNULL",
'!EXPORT', "!EXPORT",
'!DES', "!DES",
'!RC4', "!RC4",
'!MD5', "!MD5",
'!PSK', "!PSK",
'!aECDH', "!aECDH",
'!EDH-DSS-DES-CBC3-SHA', "!EDH-DSS-DES-CBC3-SHA",
'!EDH-RSA-DES-CBC3-SHA', "!EDH-RSA-DES-CBC3-SHA",
'!KRB5-DES-CBC3-SHA', "!KRB5-DES-CBC3-SHA",
] ]
def write_config(filename, **kwargs): def write_config(filename, **kwargs):
with open(filename + ".jnj") as f: with open(filename + ".jnj") as f:
template = jinja2.Template(f.read()) template = jinja2.Template(f.read())
rendered = template.render(kwargs) rendered = template.render(kwargs)
with open(filename, 'w') as f: with open(filename, "w") as f:
f.write(rendered) f.write(rendered)
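write_config therefore renders the Jinja template at filename + ".jnj" with the given keyword context and writes the result to filename itself; for example, assuming a template conf/nginx/nginx.conf.jnj containing `listen {{ port }};`:

write_config("conf/nginx/nginx.conf", port=8443)
# conf/nginx/nginx.conf now contains: listen 8443;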
@ -68,19 +69,22 @@ def generate_nginx_config(config):
Generates nginx config from the app config Generates nginx config from the app config
""" """
config = config or {} config = config or {}
use_https = os.path.exists(os.path.join(QUAYCONF_DIR, 'stack/ssl.key')) use_https = os.path.exists(os.path.join(QUAYCONF_DIR, "stack/ssl.key"))
use_old_certs = os.path.exists(os.path.join(QUAYCONF_DIR, 'stack/ssl.old.key')) use_old_certs = os.path.exists(os.path.join(QUAYCONF_DIR, "stack/ssl.old.key"))
v1_only_domain = config.get('V1_ONLY_DOMAIN', None) v1_only_domain = config.get("V1_ONLY_DOMAIN", None)
enable_rate_limits = config.get('FEATURE_RATE_LIMITS', False) enable_rate_limits = config.get("FEATURE_RATE_LIMITS", False)
ssl_protocols = config.get('SSL_PROTOCOLS', SSL_PROTOCOL_DEFAULTS) ssl_protocols = config.get("SSL_PROTOCOLS", SSL_PROTOCOL_DEFAULTS)
ssl_ciphers = config.get('SSL_CIPHERS', SSL_CIPHER_DEFAULTS) ssl_ciphers = config.get("SSL_CIPHERS", SSL_CIPHER_DEFAULTS)
write_config(os.path.join(QUAYCONF_DIR, 'nginx/nginx.conf'), use_https=use_https,
             use_old_certs=use_old_certs,
             enable_rate_limits=enable_rate_limits,
             v1_only_domain=v1_only_domain,
             ssl_protocols=ssl_protocols,
             ssl_ciphers=':'.join(ssl_ciphers))

    write_config(
        os.path.join(QUAYCONF_DIR, "nginx/nginx.conf"),
        use_https=use_https,
        use_old_certs=use_old_certs,
        enable_rate_limits=enable_rate_limits,
        v1_only_domain=v1_only_domain,
        ssl_protocols=ssl_protocols,
        ssl_ciphers=":".join(ssl_ciphers),
    )
def generate_server_config(config): def generate_server_config(config):
@ -88,17 +92,21 @@ def generate_server_config(config):
Generates server config from the app config Generates server config from the app config
""" """
config = config or {} config = config or {}
tuf_server = config.get('TUF_SERVER', None) tuf_server = config.get("TUF_SERVER", None)
tuf_host = config.get('TUF_HOST', None) tuf_host = config.get("TUF_HOST", None)
signing_enabled = config.get('FEATURE_SIGNING', False) signing_enabled = config.get("FEATURE_SIGNING", False)
maximum_layer_size = config.get('MAXIMUM_LAYER_SIZE', '20G') maximum_layer_size = config.get("MAXIMUM_LAYER_SIZE", "20G")
enable_rate_limits = config.get('FEATURE_RATE_LIMITS', False) enable_rate_limits = config.get("FEATURE_RATE_LIMITS", False)
write_config(
    os.path.join(QUAYCONF_DIR, 'nginx/server-base.conf'), tuf_server=tuf_server, tuf_host=tuf_host,
    signing_enabled=signing_enabled, maximum_layer_size=maximum_layer_size,
    enable_rate_limits=enable_rate_limits,
    static_dir=STATIC_DIR)

    write_config(
        os.path.join(QUAYCONF_DIR, "nginx/server-base.conf"),
        tuf_server=tuf_server,
        tuf_host=tuf_host,
        signing_enabled=signing_enabled,
        maximum_layer_size=maximum_layer_size,
        enable_rate_limits=enable_rate_limits,
        static_dir=STATIC_DIR,
    )
def generate_rate_limiting_config(config): def generate_rate_limiting_config(config):
@ -106,17 +114,19 @@ def generate_rate_limiting_config(config):
Generates rate limiting config from the app config Generates rate limiting config from the app config
""" """
config = config or {} config = config or {}
non_rate_limited_namespaces = config.get('NON_RATE_LIMITED_NAMESPACES') or set() non_rate_limited_namespaces = config.get("NON_RATE_LIMITED_NAMESPACES") or set()
enable_rate_limits = config.get('FEATURE_RATE_LIMITS', False) enable_rate_limits = config.get("FEATURE_RATE_LIMITS", False)
write_config( write_config(
os.path.join(QUAYCONF_DIR, 'nginx/rate-limiting.conf'), os.path.join(QUAYCONF_DIR, "nginx/rate-limiting.conf"),
non_rate_limited_namespaces=non_rate_limited_namespaces, non_rate_limited_namespaces=non_rate_limited_namespaces,
enable_rate_limits=enable_rate_limits, enable_rate_limits=enable_rate_limits,
static_dir=STATIC_DIR) static_dir=STATIC_DIR,
)
if __name__ == "__main__": if __name__ == "__main__":
if os.path.exists(os.path.join(QUAYCONF_DIR, 'stack/config.yaml')): if os.path.exists(os.path.join(QUAYCONF_DIR, "stack/config.yaml")):
with open(os.path.join(QUAYCONF_DIR, 'stack/config.yaml'), 'r') as f: with open(os.path.join(QUAYCONF_DIR, "stack/config.yaml"), "r") as f:
config = yaml.load(f) config = yaml.load(f)
else: else:
config = None config = None
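One note on the yaml.load(f) call retained here: without an explicit Loader it is deprecated in later PyYAML releases and can construct arbitrary Python objects; safe_load is the usual drop-in when the file is plain configuration:

with open(os.path.join(QUAYCONF_DIR, "stack/config.yaml"), "r") as f:
    config = yaml.safe_load(f)  # same result for plain YAML, no arbitrary object construction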

View file

@ -13,99 +13,37 @@ QUAY_OVERRIDE_SERVICES = os.getenv("QUAY_OVERRIDE_SERVICES", [])
def default_services():
    return {
        "blobuploadcleanupworker": {"autostart": "true"},
        "buildlogsarchiver": {"autostart": "true"},
        "builder": {"autostart": "true"},
        "chunkcleanupworker": {"autostart": "true"},
        "expiredappspecifictokenworker": {"autostart": "true"},
        "exportactionlogsworker": {"autostart": "true"},
        "gcworker": {"autostart": "true"},
        "globalpromstats": {"autostart": "true"},
        "labelbackfillworker": {"autostart": "true"},
        "logrotateworker": {"autostart": "true"},
        "namespacegcworker": {"autostart": "true"},
        "notificationworker": {"autostart": "true"},
        "queuecleanupworker": {"autostart": "true"},
        "repositoryactioncounter": {"autostart": "true"},
        "security_notification_worker": {"autostart": "true"},
        "securityworker": {"autostart": "true"},
        "storagereplication": {"autostart": "true"},
        "tagbackfillworker": {"autostart": "true"},
        "teamsyncworker": {"autostart": "true"},
        "dnsmasq": {"autostart": "true"},
        "gunicorn-registry": {"autostart": "true"},
        "gunicorn-secscan": {"autostart": "true"},
        "gunicorn-verbs": {"autostart": "true"},
        "gunicorn-web": {"autostart": "true"},
        "ip-resolver-update-worker": {"autostart": "true"},
        "jwtproxy": {"autostart": "true"},
        "memcache": {"autostart": "true"},
        "nginx": {"autostart": "true"},
        "prometheus-aggregator": {"autostart": "true"},
        "servicekey": {"autostart": "true"},
        "repomirrorworker": {"autostart": "false"},
    }
@ -114,7 +52,7 @@ def generate_supervisord_config(filename, config):
template = jinja2.Template(f.read()) template = jinja2.Template(f.read())
rendered = template.render(config=config) rendered = template.render(config=config)
with open(filename, 'w') as f: with open(filename, "w") as f:
f.write(rendered) f.write(rendered)
@ -144,4 +82,4 @@ if __name__ == "__main__":
config = default_services() config = default_services()
limit_services(config, QUAY_SERVICES) limit_services(config, QUAY_SERVICES)
override_services(config, QUAY_OVERRIDE_SERVICES) override_services(config, QUAY_OVERRIDE_SERVICES)
generate_supervisord_config(os.path.join(QUAYCONF_DIR, 'supervisord.conf'), config) generate_supervisord_config(os.path.join(QUAYCONF_DIR, "supervisord.conf"), config)
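limit_services and override_services are not shown in this hunk; judging from the call sites (QUAY_SERVICES may be an empty default or a comma-separated list), a sketch of the presumed behavior:

def limit_services_sketch(config, enabled):
    # With no explicit list, leave the defaults untouched.
    if not enabled:
        return
    names = enabled.split(",") if isinstance(enabled, str) else list(enabled)
    for service in config:
        config[service]["autostart"] = "true" if service in names else "false"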

View file

@ -6,11 +6,17 @@ import jinja2
from ..supervisord_conf_create import QUAYCONF_DIR, default_services, limit_services from ..supervisord_conf_create import QUAYCONF_DIR, default_services, limit_services
def render_supervisord_conf(config): def render_supervisord_conf(config):
with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../supervisord.conf.jnj")) as f: with open(
os.path.join(
os.path.dirname(os.path.abspath(__file__)), "../../supervisord.conf.jnj"
)
) as f:
template = jinja2.Template(f.read()) template = jinja2.Template(f.read())
return template.render(config=config) return template.render(config=config)
def test_supervisord_conf_create_defaults(): def test_supervisord_conf_create_defaults():
config = default_services() config = default_services()
limit_services(config, []) limit_services(config, [])
@ -394,6 +400,7 @@ stderr_events_enabled = true
# EOF NO NEWLINE""" # EOF NO NEWLINE"""
assert rendered == expected assert rendered == expected
def test_supervisord_conf_create_all_overrides(): def test_supervisord_conf_create_all_overrides():
config = default_services() config = default_services()
limit_services(config, "servicekey,prometheus-aggregator") limit_services(config, "servicekey,prometheus-aggregator")

config.py
View file

@ -8,36 +8,52 @@ from _init import ROOT_DIR, CONF_DIR
def build_requests_session(): def build_requests_session():
sess = requests.Session() sess = requests.Session()
adapter = requests.adapters.HTTPAdapter(pool_connections=100,
                                        pool_maxsize=100)
sess.mount('http://', adapter)
sess.mount('https://', adapter)

    adapter = requests.adapters.HTTPAdapter(pool_connections=100, pool_maxsize=100)
    sess.mount("http://", adapter)
    sess.mount("https://", adapter)
return sess return sess
# The set of configuration key names that will be accessible in the client. Since these # The set of configuration key names that will be accessible in the client. Since these
# values are sent to the frontend, DO NOT PLACE ANY SECRETS OR KEYS in this list. # values are sent to the frontend, DO NOT PLACE ANY SECRETS OR KEYS in this list.
CLIENT_WHITELIST = ['SERVER_HOSTNAME', 'PREFERRED_URL_SCHEME', 'MIXPANEL_KEY',
                    'STRIPE_PUBLISHABLE_KEY', 'ENTERPRISE_LOGO_URL', 'SENTRY_PUBLIC_DSN',
                    'AUTHENTICATION_TYPE', 'REGISTRY_TITLE', 'REGISTRY_TITLE_SHORT',
                    'CONTACT_INFO', 'AVATAR_KIND', 'LOCAL_OAUTH_HANDLER',
                    'SETUP_COMPLETE', 'DEBUG', 'MARKETO_MUNCHKIN_ID',
                    'STATIC_SITE_BUCKET', 'RECAPTCHA_SITE_KEY', 'CHANNEL_COLORS',
                    'TAG_EXPIRATION_OPTIONS', 'INTERNAL_OIDC_SERVICE_ID',
                    'SEARCH_RESULTS_PER_PAGE', 'SEARCH_MAX_RESULT_PAGE_COUNT', 'BRANDING']

CLIENT_WHITELIST = [
    "SERVER_HOSTNAME",
    "PREFERRED_URL_SCHEME",
    "MIXPANEL_KEY",
    "STRIPE_PUBLISHABLE_KEY",
    "ENTERPRISE_LOGO_URL",
    "SENTRY_PUBLIC_DSN",
    "AUTHENTICATION_TYPE",
    "REGISTRY_TITLE",
    "REGISTRY_TITLE_SHORT",
    "CONTACT_INFO",
    "AVATAR_KIND",
    "LOCAL_OAUTH_HANDLER",
    "SETUP_COMPLETE",
    "DEBUG",
    "MARKETO_MUNCHKIN_ID",
    "STATIC_SITE_BUCKET",
    "RECAPTCHA_SITE_KEY",
    "CHANNEL_COLORS",
    "TAG_EXPIRATION_OPTIONS",
    "INTERNAL_OIDC_SERVICE_ID",
    "SEARCH_RESULTS_PER_PAGE",
    "SEARCH_MAX_RESULT_PAGE_COUNT",
    "BRANDING",
]
def frontend_visible_config(config_dict): def frontend_visible_config(config_dict):
visible_dict = {} visible_dict = {}
for name in CLIENT_WHITELIST: for name in CLIENT_WHITELIST:
if name.lower().find('secret') >= 0: if name.lower().find("secret") >= 0:
raise Exception('Cannot whitelist secrets: %s' % name) raise Exception("Cannot whitelist secrets: %s" % name)
if name in config_dict: if name in config_dict:
visible_dict[name] = config_dict.get(name, None) visible_dict[name] = config_dict.get(name, None)
if 'ENTERPRISE_LOGO_URL' in config_dict: if "ENTERPRISE_LOGO_URL" in config_dict:
visible_dict['BRANDING'] = visible_dict.get('BRANDING', {}) visible_dict["BRANDING"] = visible_dict.get("BRANDING", {})
visible_dict['BRANDING']['logo'] = config_dict['ENTERPRISE_LOGO_URL'] visible_dict["BRANDING"]["logo"] = config_dict["ENTERPRISE_LOGO_URL"]
return visible_dict return visible_dict
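The effect, roughly: only whitelisted keys survive, and a legacy ENTERPRISE_LOGO_URL is folded into BRANDING:

config = {
    "SERVER_HOSTNAME": "quay.example.com",      # whitelisted, passed through
    "SECRET_KEY": "do-not-leak",                # not whitelisted, dropped
    "ENTERPRISE_LOGO_URL": "/static/logo.svg",  # whitelisted and mapped into BRANDING
}
frontend_visible_config(config)
# -> {"SERVER_HOSTNAME": "quay.example.com",
#     "ENTERPRISE_LOGO_URL": "/static/logo.svg",
#     "BRANDING": {"logo": "/static/logo.svg"}}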
@ -50,32 +66,99 @@ class ImmutableConfig(object):
# Status tag config # Status tag config
STATUS_TAGS = {} STATUS_TAGS = {}
for tag_name in ['building', 'failed', 'none', 'ready', 'cancelled']: for tag_name in ["building", "failed", "none", "ready", "cancelled"]:
tag_path = os.path.join(ROOT_DIR, 'buildstatus', tag_name + '.svg') tag_path = os.path.join(ROOT_DIR, "buildstatus", tag_name + ".svg")
with open(tag_path) as tag_svg: with open(tag_path) as tag_svg:
STATUS_TAGS[tag_name] = tag_svg.read() STATUS_TAGS[tag_name] = tag_svg.read()
# Reverse DNS prefixes that are reserved for internal use on labels and should not be allowable # Reverse DNS prefixes that are reserved for internal use on labels and should not be allowable
# to be set via the API. # to be set via the API.
DEFAULT_LABEL_KEY_RESERVED_PREFIXES = ['com.docker.', 'io.docker.', 'org.dockerproject.',
                                       'org.opencontainers.', 'io.cncf.',
                                       'io.kubernetes.', 'io.k8s.',
                                       'io.quay', 'com.coreos', 'com.tectonic',
                                       'internal', 'quay']

    DEFAULT_LABEL_KEY_RESERVED_PREFIXES = [
        "com.docker.",
        "io.docker.",
        "org.dockerproject.",
        "org.opencontainers.",
        "io.cncf.",
        "io.kubernetes.",
        "io.k8s.",
        "io.quay",
        "com.coreos",
        "com.tectonic",
        "internal",
        "quay",
    ]
# Colors for local avatars. # Colors for local avatars.
AVATAR_COLORS = ['#969696', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c', '#98df8a', '#d62728',
                 '#ff9896', '#9467bd', '#c5b0d5', '#8c564b', '#c49c94', '#e377c2', '#f7b6d2',
                 '#7f7f7f', '#c7c7c7', '#bcbd22', '#1f77b4', '#17becf', '#9edae5', '#393b79',
                 '#5254a3', '#6b6ecf', '#9c9ede', '#9ecae1', '#31a354', '#b5cf6b', '#a1d99b',
                 '#8c6d31', '#ad494a', '#e7ba52', '#a55194']

    AVATAR_COLORS = [
        "#969696",
        "#aec7e8",
        "#ff7f0e",
        "#ffbb78",
        "#2ca02c",
        "#98df8a",
        "#d62728",
        "#ff9896",
        "#9467bd",
        "#c5b0d5",
        "#8c564b",
        "#c49c94",
        "#e377c2",
        "#f7b6d2",
        "#7f7f7f",
        "#c7c7c7",
        "#bcbd22",
        "#1f77b4",
        "#17becf",
        "#9edae5",
        "#393b79",
        "#5254a3",
        "#6b6ecf",
        "#9c9ede",
        "#9ecae1",
        "#31a354",
        "#b5cf6b",
        "#a1d99b",
        "#8c6d31",
        "#ad494a",
        "#e7ba52",
        "#a55194",
    ]
# Colors for channels. # Colors for channels.
CHANNEL_COLORS = ['#969696', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c', '#98df8a', '#d62728',
                  '#ff9896', '#9467bd', '#c5b0d5', '#8c564b', '#c49c94', '#e377c2', '#f7b6d2',
                  '#7f7f7f', '#c7c7c7', '#bcbd22', '#1f77b4', '#17becf', '#9edae5', '#393b79',
                  '#5254a3', '#6b6ecf', '#9c9ede', '#9ecae1', '#31a354', '#b5cf6b', '#a1d99b',
                  '#8c6d31', '#ad494a', '#e7ba52', '#a55194']

    CHANNEL_COLORS = [
        "#969696",
        "#aec7e8",
        "#ff7f0e",
        "#ffbb78",
        "#2ca02c",
        "#98df8a",
        "#d62728",
        "#ff9896",
        "#9467bd",
        "#c5b0d5",
        "#8c564b",
        "#c49c94",
        "#e377c2",
        "#f7b6d2",
        "#7f7f7f",
        "#c7c7c7",
        "#bcbd22",
        "#1f77b4",
        "#17becf",
        "#9edae5",
        "#393b79",
        "#5254a3",
        "#6b6ecf",
        "#9c9ede",
        "#9ecae1",
        "#31a354",
        "#b5cf6b",
        "#a1d99b",
        "#8c6d31",
        "#ad494a",
        "#e7ba52",
        "#a55194",
    ]
PROPAGATE_EXCEPTIONS = True PROPAGATE_EXCEPTIONS = True
@ -86,34 +169,31 @@ class DefaultConfig(ImmutableConfig):
SESSION_COOKIE_SECURE = False SESSION_COOKIE_SECURE = False
SESSION_COOKIE_HTTPONLY = True SESSION_COOKIE_HTTPONLY = True
SESSION_COOKIE_SAMESITE = 'Lax' SESSION_COOKIE_SAMESITE = "Lax"
LOGGING_LEVEL = 'DEBUG' LOGGING_LEVEL = "DEBUG"
SEND_FILE_MAX_AGE_DEFAULT = 0 SEND_FILE_MAX_AGE_DEFAULT = 0
PREFERRED_URL_SCHEME = 'http' PREFERRED_URL_SCHEME = "http"
SERVER_HOSTNAME = 'localhost:5000' SERVER_HOSTNAME = "localhost:5000"
REGISTRY_TITLE = 'Project Quay' REGISTRY_TITLE = "Project Quay"
REGISTRY_TITLE_SHORT = 'Project Quay' REGISTRY_TITLE_SHORT = "Project Quay"
CONTACT_INFO = [] CONTACT_INFO = []
# Mail config # Mail config
MAIL_SERVER = '' MAIL_SERVER = ""
MAIL_USE_TLS = True MAIL_USE_TLS = True
MAIL_PORT = 587 MAIL_PORT = 587
MAIL_USERNAME = None MAIL_USERNAME = None
MAIL_PASSWORD = None MAIL_PASSWORD = None
MAIL_DEFAULT_SENDER = 'example@projectquay.io' MAIL_DEFAULT_SENDER = "example@projectquay.io"
MAIL_FAIL_SILENTLY = False MAIL_FAIL_SILENTLY = False
TESTING = True TESTING = True
# DB config # DB config
DB_URI = 'sqlite:///test/data/test.db' DB_URI = "sqlite:///test/data/test.db"
DB_CONNECTION_ARGS = {
    'threadlocals': True,
    'autorollback': True,
}

    DB_CONNECTION_ARGS = {"threadlocals": True, "autorollback": True}
@staticmethod @staticmethod
def create_transaction(db): def create_transaction(db):
@ -123,7 +203,7 @@ class DefaultConfig(ImmutableConfig):
# If set to 'readonly', the entire registry is placed into read only mode and no write operations # If set to 'readonly', the entire registry is placed into read only mode and no write operations
# may be performed against it. # may be performed against it.
REGISTRY_STATE = 'normal' REGISTRY_STATE = "normal"
# If set to true, TLS is used, but is terminated by an external service (such as a load balancer). # If set to true, TLS is used, but is terminated by an external service (such as a load balancer).
# Note that PREFERRED_URL_SCHEME must be `https` when this flag is set or it can lead to undefined # Note that PREFERRED_URL_SCHEME must be `https` when this flag is set or it can lead to undefined
@ -135,27 +215,27 @@ class DefaultConfig(ImmutableConfig):
USE_CDN = False USE_CDN = False
# Authentication # Authentication
AUTHENTICATION_TYPE = 'Database' AUTHENTICATION_TYPE = "Database"
# Build logs # Build logs
BUILDLOGS_REDIS = {'host': 'localhost'} BUILDLOGS_REDIS = {"host": "localhost"}
BUILDLOGS_OPTIONS = [] BUILDLOGS_OPTIONS = []
# Real-time user events # Real-time user events
USER_EVENTS_REDIS = {'host': 'localhost'} USER_EVENTS_REDIS = {"host": "localhost"}
# Stripe config # Stripe config
BILLING_TYPE = 'FakeStripe' BILLING_TYPE = "FakeStripe"
# Analytics # Analytics
ANALYTICS_TYPE = 'FakeAnalytics' ANALYTICS_TYPE = "FakeAnalytics"
# Build Queue Metrics # Build Queue Metrics
QUEUE_METRICS_TYPE = 'Null' QUEUE_METRICS_TYPE = "Null"
QUEUE_WORKER_METRICS_REFRESH_SECONDS = 300 QUEUE_WORKER_METRICS_REFRESH_SECONDS = 300
# Exception logging # Exception logging
EXCEPTION_LOG_TYPE = 'FakeSentry' EXCEPTION_LOG_TYPE = "FakeSentry"
SENTRY_DSN = None SENTRY_DSN = None
SENTRY_PUBLIC_DSN = None SENTRY_PUBLIC_DSN = None
@ -172,13 +252,13 @@ class DefaultConfig(ImmutableConfig):
# Gitlab Config. # Gitlab Config.
GITLAB_TRIGGER_CONFIG = None GITLAB_TRIGGER_CONFIG = None
NOTIFICATION_QUEUE_NAME = 'notification' NOTIFICATION_QUEUE_NAME = "notification"
DOCKERFILE_BUILD_QUEUE_NAME = 'dockerfilebuild' DOCKERFILE_BUILD_QUEUE_NAME = "dockerfilebuild"
REPLICATION_QUEUE_NAME = 'imagestoragereplication' REPLICATION_QUEUE_NAME = "imagestoragereplication"
SECSCAN_NOTIFICATION_QUEUE_NAME = 'security_notification' SECSCAN_NOTIFICATION_QUEUE_NAME = "security_notification"
CHUNK_CLEANUP_QUEUE_NAME = 'chunk_cleanup' CHUNK_CLEANUP_QUEUE_NAME = "chunk_cleanup"
NAMESPACE_GC_QUEUE_NAME = 'namespacegc' NAMESPACE_GC_QUEUE_NAME = "namespacegc"
EXPORT_ACTION_LOGS_QUEUE_NAME = 'exportactionlogs' EXPORT_ACTION_LOGS_QUEUE_NAME = "exportactionlogs"
# Super user config. Note: This MUST BE an empty list for the default config. # Super user config. Note: This MUST BE an empty list for the default config.
SUPER_USERS = [] SUPER_USERS = []
@ -244,7 +324,7 @@ class DefaultConfig(ImmutableConfig):
# Semver spec for which Docker versions we will blacklist # Semver spec for which Docker versions we will blacklist
# Documentation: http://pythonhosted.org/semantic_version/reference.html#semantic_version.Spec # Documentation: http://pythonhosted.org/semantic_version/reference.html#semantic_version.Spec
BLACKLIST_V2_SPEC = '<1.6.0' BLACKLIST_V2_SPEC = "<1.6.0"
# Feature Flag: Whether to restrict V1 pushes to the whitelist. # Feature Flag: Whether to restrict V1 pushes to the whitelist.
FEATURE_RESTRICTED_V1_PUSH = False FEATURE_RESTRICTED_V1_PUSH = False
@ -302,33 +382,33 @@ class DefaultConfig(ImmutableConfig):
# The namespace to use for library repositories. # The namespace to use for library repositories.
# Note: This must remain 'library' until Docker removes their hard-coded namespace for libraries. # Note: This must remain 'library' until Docker removes their hard-coded namespace for libraries.
# See: https://github.com/docker/docker/blob/master/registry/session.go#L320 # See: https://github.com/docker/docker/blob/master/registry/session.go#L320
LIBRARY_NAMESPACE = 'library' LIBRARY_NAMESPACE = "library"
BUILD_MANAGER = ('enterprise', {}) BUILD_MANAGER = ("enterprise", {})
DISTRIBUTED_STORAGE_CONFIG = { DISTRIBUTED_STORAGE_CONFIG = {
'local_eu': ['LocalStorage', {'storage_path': 'test/data/registry/eu'}], "local_eu": ["LocalStorage", {"storage_path": "test/data/registry/eu"}],
'local_us': ['LocalStorage', {'storage_path': 'test/data/registry/us'}], "local_us": ["LocalStorage", {"storage_path": "test/data/registry/us"}],
} }
DISTRIBUTED_STORAGE_PREFERENCE = ['local_us'] DISTRIBUTED_STORAGE_PREFERENCE = ["local_us"]
DISTRIBUTED_STORAGE_DEFAULT_LOCATIONS = ['local_us'] DISTRIBUTED_STORAGE_DEFAULT_LOCATIONS = ["local_us"]
# Health checker. # Health checker.
HEALTH_CHECKER = ('LocalHealthCheck', {}) HEALTH_CHECKER = ("LocalHealthCheck", {})
# Userfiles # Userfiles
USERFILES_LOCATION = 'local_us' USERFILES_LOCATION = "local_us"
USERFILES_PATH = 'userfiles/' USERFILES_PATH = "userfiles/"
# Build logs archive # Build logs archive
LOG_ARCHIVE_LOCATION = 'local_us' LOG_ARCHIVE_LOCATION = "local_us"
LOG_ARCHIVE_PATH = 'logarchive/' LOG_ARCHIVE_PATH = "logarchive/"
# Action logs archive # Action logs archive
ACTION_LOG_ARCHIVE_LOCATION = 'local_us' ACTION_LOG_ARCHIVE_LOCATION = "local_us"
ACTION_LOG_ARCHIVE_PATH = 'actionlogarchive/' ACTION_LOG_ARCHIVE_PATH = "actionlogarchive/"
ACTION_LOG_ROTATION_THRESHOLD = '30d' ACTION_LOG_ROTATION_THRESHOLD = "30d"
# Allow registry pulls when unable to write to the audit log # Allow registry pulls when unable to write to the audit log
ALLOW_PULLS_WITHOUT_STRICT_LOGGING = False ALLOW_PULLS_WITHOUT_STRICT_LOGGING = False
@ -340,19 +420,21 @@ class DefaultConfig(ImmutableConfig):
SIGNED_GRANT_EXPIRATION_SEC = 60 * 60 * 24 # One day to complete a push/pull SIGNED_GRANT_EXPIRATION_SEC = 60 * 60 * 24 # One day to complete a push/pull
# Registry v2 JWT Auth config # Registry v2 JWT Auth config
REGISTRY_JWT_AUTH_MAX_FRESH_S = 60 * 60 + 60 # At most signed one hour, accounting for clock skew REGISTRY_JWT_AUTH_MAX_FRESH_S = (
60 * 60 + 60
) # At most signed one hour, accounting for clock skew
# The URL endpoint to which we redirect OAuth when generating a token locally. # The URL endpoint to which we redirect OAuth when generating a token locally.
LOCAL_OAUTH_HANDLER = '/oauth/localapp' LOCAL_OAUTH_HANDLER = "/oauth/localapp"
# The various avatar background colors. # The various avatar background colors.
AVATAR_KIND = 'local' AVATAR_KIND = "local"
# Custom branding # Custom branding
BRANDING = { BRANDING = {
'logo': '/static/img/quay-horizontal-color.svg', "logo": "/static/img/quay-horizontal-color.svg",
'footer_img': None, "footer_img": None,
'footer_url': None, "footer_url": None,
} }
# How often the Garbage Collection worker runs. # How often the Garbage Collection worker runs.
@ -366,7 +448,7 @@ class DefaultConfig(ImmutableConfig):
FEATURE_SECURITY_NOTIFICATIONS = False FEATURE_SECURITY_NOTIFICATIONS = False
# The endpoint for the security scanner. # The endpoint for the security scanner.
SECURITY_SCANNER_ENDPOINT = 'http://192.168.99.101:6060' SECURITY_SCANNER_ENDPOINT = "http://192.168.99.101:6060"
# The number of seconds between indexing intervals in the security scanner # The number of seconds between indexing intervals in the security scanner
SECURITY_SCANNER_INDEXING_INTERVAL = 30 SECURITY_SCANNER_INDEXING_INTERVAL = 30
@ -384,7 +466,7 @@ class DefaultConfig(ImmutableConfig):
SECURITY_SCANNER_ENGINE_VERSION_TARGET = 3 SECURITY_SCANNER_ENGINE_VERSION_TARGET = 3
# The version of the API to use for the security scanner. # The version of the API to use for the security scanner.
SECURITY_SCANNER_API_VERSION = 'v1' SECURITY_SCANNER_API_VERSION = "v1"
# API call timeout for the security scanner. # API call timeout for the security scanner.
SECURITY_SCANNER_API_TIMEOUT_SECONDS = 10 SECURITY_SCANNER_API_TIMEOUT_SECONDS = 10
@ -393,7 +475,7 @@ class DefaultConfig(ImmutableConfig):
SECURITY_SCANNER_API_TIMEOUT_POST_SECONDS = 480 SECURITY_SCANNER_API_TIMEOUT_POST_SECONDS = 480
# The issuer name for the security scanner. # The issuer name for the security scanner.
SECURITY_SCANNER_ISSUER_NAME = 'security_scanner' SECURITY_SCANNER_ISSUER_NAME = "security_scanner"
# Repository mirror # Repository mirror
FEATURE_REPO_MIRROR = False FEATURE_REPO_MIRROR = False
@ -410,7 +492,7 @@ class DefaultConfig(ImmutableConfig):
# JWTProxy Settings # JWTProxy Settings
# The address (sans schema) to proxy outgoing requests through the jwtproxy # The address (sans schema) to proxy outgoing requests through the jwtproxy
# to be signed # to be signed
JWTPROXY_SIGNER = 'localhost:8081' JWTPROXY_SIGNER = "localhost:8081"
# The audience that jwtproxy should verify on incoming requests # The audience that jwtproxy should verify on incoming requests
# If None, will be calculated off of the SERVER_HOSTNAME (default) # If None, will be calculated off of the SERVER_HOSTNAME (default)
@ -419,7 +501,7 @@ class DefaultConfig(ImmutableConfig):
# Torrent management flags # Torrent management flags
FEATURE_BITTORRENT = False FEATURE_BITTORRENT = False
BITTORRENT_PIECE_SIZE = 512 * 1024 BITTORRENT_PIECE_SIZE = 512 * 1024
BITTORRENT_ANNOUNCE_URL = 'https://localhost:6881/announce' BITTORRENT_ANNOUNCE_URL = "https://localhost:6881/announce"
BITTORRENT_FILENAME_PEPPER = str(uuid4()) BITTORRENT_FILENAME_PEPPER = str(uuid4())
BITTORRENT_WEBSEED_LIFETIME = 3600 BITTORRENT_WEBSEED_LIFETIME = 3600
@ -427,7 +509,7 @@ class DefaultConfig(ImmutableConfig):
# hide the ID range for production (in which this value is overridden). Should *not* # hide the ID range for production (in which this value is overridden). Should *not*
# be relied upon for secure encryption otherwise. # be relied upon for secure encryption otherwise.
# This value is a Fernet key and should be 32bytes URL-safe base64 encoded. # This value is a Fernet key and should be 32bytes URL-safe base64 encoded.
PAGE_TOKEN_KEY = '0OYrc16oBuksR8T3JGB-xxYSlZ2-7I_zzqrLzggBJ58=' PAGE_TOKEN_KEY = "0OYrc16oBuksR8T3JGB-xxYSlZ2-7I_zzqrLzggBJ58="
# The timeout for service key approval. # The timeout for service key approval.
UNAPPROVED_SERVICE_KEY_TTL_SEC = 60 * 60 * 24 # One day UNAPPROVED_SERVICE_KEY_TTL_SEC = 60 * 60 * 24 # One day
@ -441,14 +523,14 @@ class DefaultConfig(ImmutableConfig):
# The service key ID for the instance service. # The service key ID for the instance service.
# NOTE: If changed, jwtproxy_conf.yaml.jnj must also be updated. # NOTE: If changed, jwtproxy_conf.yaml.jnj must also be updated.
INSTANCE_SERVICE_KEY_SERVICE = 'quay' INSTANCE_SERVICE_KEY_SERVICE = "quay"
# The location of the key ID file generated for this instance. # The location of the key ID file generated for this instance.
INSTANCE_SERVICE_KEY_KID_LOCATION = os.path.join(CONF_DIR, 'quay.kid') INSTANCE_SERVICE_KEY_KID_LOCATION = os.path.join(CONF_DIR, "quay.kid")
# The location of the private key generated for this instance. # The location of the private key generated for this instance.
# NOTE: If changed, jwtproxy_conf.yaml.jnj must also be updated. # NOTE: If changed, jwtproxy_conf.yaml.jnj must also be updated.
INSTANCE_SERVICE_KEY_LOCATION = os.path.join(CONF_DIR, 'quay.pem') INSTANCE_SERVICE_KEY_LOCATION = os.path.join(CONF_DIR, "quay.pem")
# This instance's service key expiration in minutes. # This instance's service key expiration in minutes.
INSTANCE_SERVICE_KEY_EXPIRATION = 120 INSTANCE_SERVICE_KEY_EXPIRATION = 120
@ -461,10 +543,10 @@ class DefaultConfig(ImmutableConfig):
DIRECT_OAUTH_CLIENTID_WHITELIST = [] DIRECT_OAUTH_CLIENTID_WHITELIST = []
# URL that specifies the location of the prometheus stats aggregator. # URL that specifies the location of the prometheus stats aggregator.
PROMETHEUS_AGGREGATOR_URL = 'http://localhost:9092' PROMETHEUS_AGGREGATOR_URL = "http://localhost:9092"
# Namespace prefix for all prometheus metrics. # Namespace prefix for all prometheus metrics.
PROMETHEUS_NAMESPACE = 'quay' PROMETHEUS_NAMESPACE = "quay"
# Overridable list of reverse DNS prefixes that are reserved for internal use on labels. # Overridable list of reverse DNS prefixes that are reserved for internal use on labels.
LABEL_KEY_RESERVED_PREFIXES = [] LABEL_KEY_RESERVED_PREFIXES = []
@ -487,22 +569,22 @@ class DefaultConfig(ImmutableConfig):
TUF_GUN_PREFIX = None TUF_GUN_PREFIX = None
# Maximum size allowed for layers in the registry. # Maximum size allowed for layers in the registry.
MAXIMUM_LAYER_SIZE = '20G' MAXIMUM_LAYER_SIZE = "20G"
# Feature Flag: Whether team syncing from the backing auth is enabled. # Feature Flag: Whether team syncing from the backing auth is enabled.
FEATURE_TEAM_SYNCING = False FEATURE_TEAM_SYNCING = False
TEAM_RESYNC_STALE_TIME = '30m' TEAM_RESYNC_STALE_TIME = "30m"
TEAM_SYNC_WORKER_FREQUENCY = 60 # seconds TEAM_SYNC_WORKER_FREQUENCY = 60 # seconds
# Feature Flag: If enabled, non-superusers can setup team syncing. # Feature Flag: If enabled, non-superusers can setup team syncing.
FEATURE_NONSUPERUSER_TEAM_SYNCING_SETUP = False FEATURE_NONSUPERUSER_TEAM_SYNCING_SETUP = False
# The default configurable tag expiration time for time machine. # The default configurable tag expiration time for time machine.
DEFAULT_TAG_EXPIRATION = '2w' DEFAULT_TAG_EXPIRATION = "2w"
# The options to present in namespace settings for the tag expiration. If empty, no option # The options to present in namespace settings for the tag expiration. If empty, no option
# will be given and the default will be displayed read-only. # will be given and the default will be displayed read-only.
TAG_EXPIRATION_OPTIONS = ['0s', '1d', '1w', '2w', '4w'] TAG_EXPIRATION_OPTIONS = ["0s", "1d", "1w", "2w", "4w"]
# Feature Flag: Whether users can view and change their tag expiration. # Feature Flag: Whether users can view and change their tag expiration.
FEATURE_CHANGE_TAG_EXPIRATION = True FEATURE_CHANGE_TAG_EXPIRATION = True
@ -511,7 +593,7 @@ class DefaultConfig(ImmutableConfig):
ENABLE_HEALTH_DEBUG_SECRET = None ENABLE_HEALTH_DEBUG_SECRET = None
# The lifetime for a user recovery token before it becomes invalid. # The lifetime for a user recovery token before it becomes invalid.
USER_RECOVERY_TOKEN_LIFETIME = '30m' USER_RECOVERY_TOKEN_LIFETIME = "30m"
# If specified, when app specific passwords expire by default. # If specified, when app specific passwords expire by default.
APP_SPECIFIC_TOKEN_EXPIRATION = None APP_SPECIFIC_TOKEN_EXPIRATION = None
@ -521,7 +603,7 @@ class DefaultConfig(ImmutableConfig):
# How long expired app specific tokens should remain visible to users before being automatically # How long expired app specific tokens should remain visible to users before being automatically
# deleted. Set to None to turn off garbage collection. # deleted. Set to None to turn off garbage collection.
EXPIRED_APP_SPECIFIC_TOKEN_GC = '1d' EXPIRED_APP_SPECIFIC_TOKEN_GC = "1d"
# The size of pages returned by the Docker V2 API. # The size of pages returned by the Docker V2 API.
V2_PAGINATION_SIZE = 50 V2_PAGINATION_SIZE = 50
@ -545,10 +627,7 @@ class DefaultConfig(ImmutableConfig):
BILLED_NAMESPACE_MAXIMUM_BUILD_COUNT = None BILLED_NAMESPACE_MAXIMUM_BUILD_COUNT = None
# Configuration for the data model cache. # Configuration for the data model cache.
DATA_MODEL_CACHE_CONFIG = {
    'engine': 'memcached',
    'endpoint': ('127.0.0.1', 18080),
}

    DATA_MODEL_CACHE_CONFIG = {"engine": "memcached", "endpoint": ("127.0.0.1", 18080)}
# Defines the number of successive failures of a build trigger's build before the trigger is # Defines the number of successive failures of a build trigger's build before the trigger is
# automatically disabled. # automatically disabled.
@ -584,7 +663,7 @@ class DefaultConfig(ImmutableConfig):
ACTION_LOG_MAX_PAGE = None ACTION_LOG_MAX_PAGE = None
# Log model # Log model
LOGS_MODEL = 'database' LOGS_MODEL = "database"
LOGS_MODEL_CONFIG = {} LOGS_MODEL_CONFIG = {}
# Namespace in which all audit logging is disabled. # Namespace in which all audit logging is disabled.

View file

@ -7,19 +7,19 @@ import subprocess
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
CONF_DIR = os.getenv("QUAYCONF", os.path.join(ROOT_DIR, "conf/")) CONF_DIR = os.getenv("QUAYCONF", os.path.join(ROOT_DIR, "conf/"))
STATIC_DIR = os.path.join(ROOT_DIR, 'static/') STATIC_DIR = os.path.join(ROOT_DIR, "static/")
STATIC_LDN_DIR = os.path.join(STATIC_DIR, 'ldn/') STATIC_LDN_DIR = os.path.join(STATIC_DIR, "ldn/")
STATIC_FONTS_DIR = os.path.join(STATIC_DIR, 'fonts/') STATIC_FONTS_DIR = os.path.join(STATIC_DIR, "fonts/")
TEMPLATE_DIR = os.path.join(ROOT_DIR, 'templates/') TEMPLATE_DIR = os.path.join(ROOT_DIR, "templates/")
IS_KUBERNETES = 'KUBERNETES_SERVICE_HOST' in os.environ IS_KUBERNETES = "KUBERNETES_SERVICE_HOST" in os.environ
def _get_version_number_changelog(): def _get_version_number_changelog():
try: try:
with open(os.path.join(ROOT_DIR, 'CHANGELOG.md')) as f: with open(os.path.join(ROOT_DIR, "CHANGELOG.md")) as f:
return re.search(r'(v[0-9]+\.[0-9]+\.[0-9]+)', f.readline()).group(0) return re.search(r"(v[0-9]+\.[0-9]+\.[0-9]+)", f.readline()).group(0)
except IOError: except IOError:
return '' return ""
def _get_git_sha(): def _get_git_sha():

View file

@ -15,26 +15,27 @@ app = Flask(__name__)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
OVERRIDE_CONFIG_DIRECTORY = os.path.join(ROOT_DIR, 'config_app/conf/stack') OVERRIDE_CONFIG_DIRECTORY = os.path.join(ROOT_DIR, "config_app/conf/stack")
INIT_SCRIPTS_LOCATION = '/conf/init/' INIT_SCRIPTS_LOCATION = "/conf/init/"
is_testing = 'TEST' in os.environ is_testing = "TEST" in os.environ
is_kubernetes = IS_KUBERNETES is_kubernetes = IS_KUBERNETES
logger.debug('Configuration is on a kubernetes deployment: %s' % IS_KUBERNETES) logger.debug("Configuration is on a kubernetes deployment: %s" % IS_KUBERNETES)
config_provider = get_config_provider(OVERRIDE_CONFIG_DIRECTORY, 'config.yaml', 'config.py', config_provider = get_config_provider(
testing=is_testing) OVERRIDE_CONFIG_DIRECTORY, "config.yaml", "config.py", testing=is_testing
)
if is_testing: if is_testing:
from test.testconfig import TestConfig from test.testconfig import TestConfig
logger.debug('Loading test config.') logger.debug("Loading test config.")
app.config.from_object(TestConfig()) app.config.from_object(TestConfig())
else: else:
from config import DefaultConfig from config import DefaultConfig
logger.debug('Loading default config.') logger.debug("Loading default config.")
app.config.from_object(DefaultConfig()) app.config.from_object(DefaultConfig())
app.teardown_request(database.close_db_filter) app.teardown_request(database.close_db_filter)

View file

@ -1,5 +1,6 @@
import sys import sys
import os import os
sys.path.append(os.path.join(os.path.dirname(__file__), "../")) sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging import logging
@ -9,18 +10,24 @@ from config_app.config_util.log import logfile_path
logconfig = logfile_path(debug=True) logconfig = logfile_path(debug=True)
bind = '0.0.0.0:5000' bind = "0.0.0.0:5000"
workers = 1 workers = 1
worker_class = 'gevent' worker_class = "gevent"
daemon = False daemon = False
pythonpath = '.' pythonpath = "."
preload_app = True preload_app = True
def post_fork(server, worker): def post_fork(server, worker):
# Reset the Random library to ensure it won't raise the "PID check failed." error after # Reset the Random library to ensure it won't raise the "PID check failed." error after
# gunicorn forks. # gunicorn forks.
Random.atfork() Random.atfork()
def when_ready(server): def when_ready(server):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.debug('Starting local gunicorn with %s workers and %s worker class', workers, worker_class) logger.debug(
"Starting local gunicorn with %s workers and %s worker class",
workers,
worker_class,
)

View file

@ -1,5 +1,6 @@
import sys import sys
import os import os
sys.path.append(os.path.join(os.path.dirname(__file__), "../")) sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging import logging
@ -10,17 +11,23 @@ from config_app.config_util.log import logfile_path
logconfig = logfile_path(debug=True) logconfig = logfile_path(debug=True)
bind = 'unix:/tmp/gunicorn_web.sock' bind = "unix:/tmp/gunicorn_web.sock"
workers = 1 workers = 1
worker_class = 'gevent' worker_class = "gevent"
pythonpath = '.' pythonpath = "."
preload_app = True preload_app = True
def post_fork(server, worker): def post_fork(server, worker):
# Reset the Random library to ensure it won't raise the "PID check failed." error after # Reset the Random library to ensure it won't raise the "PID check failed." error after
# gunicorn forks. # gunicorn forks.
Random.atfork() Random.atfork()
def when_ready(server): def when_ready(server):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.debug('Starting local gunicorn with %s workers and %s worker class', workers, worker_class) logger.debug(
"Starting local gunicorn with %s workers and %s worker class",
workers,
worker_class,
)

View file

@ -3,6 +3,6 @@ from config_app.c_app import app as application
# Bind all of the blueprints # Bind all of the blueprints
import config_web import config_web
if __name__ == '__main__': if __name__ == "__main__":
logging.config.fileConfig(logfile_path(debug=True), disable_existing_loggers=False) logging.config.fileConfig(logfile_path(debug=True), disable_existing_loggers=False)
application.run(port=5000, debug=True, threaded=True, host='0.0.0.0') application.run(port=5000, debug=True, threaded=True, host="0.0.0.0")

View file

@ -13,15 +13,15 @@ from config_app.c_app import app, IS_KUBERNETES
from config_app.config_endpoints.exception import InvalidResponse, InvalidRequest from config_app.config_endpoints.exception import InvalidResponse, InvalidRequest
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
api_bp = Blueprint('api', __name__) api_bp = Blueprint("api", __name__)
CROSS_DOMAIN_HEADERS = ['Authorization', 'Content-Type', 'X-Requested-With'] CROSS_DOMAIN_HEADERS = ["Authorization", "Content-Type", "X-Requested-With"]
class ApiExceptionHandlingApi(Api): class ApiExceptionHandlingApi(Api):
pass pass
@crossdomain(origin='*', headers=CROSS_DOMAIN_HEADERS) @crossdomain(origin="*", headers=CROSS_DOMAIN_HEADERS)
def handle_error(self, error): def handle_error(self, error):
return super(ApiExceptionHandlingApi, self).handle_error(error) return super(ApiExceptionHandlingApi, self).handle_error(error)
@ -30,6 +30,7 @@ api = ApiExceptionHandlingApi()
api.init_app(api_bp) api.init_app(api_bp)
def log_action(kind, user_or_orgname, metadata=None, repo=None, repo_name=None): def log_action(kind, user_or_orgname, metadata=None, repo=None, repo_name=None):
if not metadata: if not metadata:
metadata = {} metadata = {}
@ -37,7 +38,10 @@ def log_action(kind, user_or_orgname, metadata=None, repo=None, repo_name=None):
if repo: if repo:
repo_name = repo.name repo_name = repo.name
model.log.log_action(kind, user_or_orgname, repo_name, user_or_orgname, request.remote_addr, metadata) model.log.log_action(
kind, user_or_orgname, repo_name, user_or_orgname, request.remote_addr, metadata
)
def format_date(date): def format_date(date):
""" Output an RFC822 date format. """ """ Output an RFC822 date format. """
@ -46,7 +50,6 @@ def format_date(date):
return formatdate(timegm(date.utctimetuple())) return formatdate(timegm(date.utctimetuple()))
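For instance, given a naive UTC datetime:

from datetime import datetime

format_date(datetime(2015, 12, 8, 23, 7, 3))
# -> "Tue, 08 Dec 2015 23:07:03 -0000"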
def resource(*urls, **kwargs): def resource(*urls, **kwargs):
def wrapper(api_resource): def wrapper(api_resource):
if not api_resource: if not api_resource:
@ -72,7 +75,7 @@ def add_method_metadata(name, value):
if func is None: if func is None:
return None return None
if '__api_metadata' not in dir(func): if "__api_metadata" not in dir(func):
func.__api_metadata = {} func.__api_metadata = {}
func.__api_metadata[name] = value func.__api_metadata[name] = value
return func return func
@ -84,7 +87,7 @@ def method_metadata(func, name):
if func is None: if func is None:
return None return None
if '__api_metadata' in dir(func): if "__api_metadata" in dir(func):
return func.__api_metadata.get(name, None) return func.__api_metadata.get(name, None)
return None return None
@ -94,33 +97,36 @@ def no_cache(f):
def add_no_cache(*args, **kwargs): def add_no_cache(*args, **kwargs):
response = f(*args, **kwargs) response = f(*args, **kwargs)
if response is not None: if response is not None:
response.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate' response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
return response return response
return add_no_cache return add_no_cache
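Usage sketch, assuming a Flask view that returns a response object:

from flask import Flask, make_response

app_sketch = Flask(__name__)  # hypothetical app for illustration

@no_cache
def get_status():
    return make_response("ok")

with app_sketch.test_request_context("/"):
    assert get_status().headers["Cache-Control"] == "no-cache, no-store, must-revalidate"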
def define_json_response(schema_name): def define_json_response(schema_name):
def wrapper(func): def wrapper(func):
@add_method_metadata('response_schema', schema_name) @add_method_metadata("response_schema", schema_name)
@wraps(func) @wraps(func)
def wrapped(self, *args, **kwargs): def wrapped(self, *args, **kwargs):
schema = self.schemas[schema_name] schema = self.schemas[schema_name]
resp = func(self, *args, **kwargs) resp = func(self, *args, **kwargs)
if app.config['TESTING']: if app.config["TESTING"]:
try: try:
validate(resp, schema) validate(resp, schema)
except ValidationError as ex: except ValidationError as ex:
raise InvalidResponse(ex.message) raise InvalidResponse(ex.message)
return resp return resp
return wrapped return wrapped
return wrapper return wrapper
def validate_json_request(schema_name, optional=False): def validate_json_request(schema_name, optional=False):
def wrapper(func): def wrapper(func):
@add_method_metadata('request_schema', schema_name) @add_method_metadata("request_schema", schema_name)
@wraps(func) @wraps(func)
def wrapped(self, *args, **kwargs): def wrapped(self, *args, **kwargs):
schema = self.schemas[schema_name] schema = self.schemas[schema_name]
@ -128,26 +134,32 @@ def validate_json_request(schema_name, optional=False):
json_data = request.get_json() json_data = request.get_json()
if json_data is None: if json_data is None:
if not optional: if not optional:
raise InvalidRequest('Missing JSON body') raise InvalidRequest("Missing JSON body")
else: else:
validate(json_data, schema) validate(json_data, schema)
return func(self, *args, **kwargs) return func(self, *args, **kwargs)
except ValidationError as ex: except ValidationError as ex:
raise InvalidRequest(ex.message) raise InvalidRequest(ex.message)
return wrapped return wrapped
return wrapper return wrapper
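Typical use pairs the two decorators on a resource's methods; the resource name and schema bodies below are hypothetical:

class ConfigResource(ApiResource):  # hypothetical resource
    schemas = {
        "GetConfig": {"type": "object"},  # hypothetical JSON schemas
        "SetConfig": {"type": "object"},
    }

    @define_json_response("GetConfig")
    def get(self):
        return {"config": {}}

    @validate_json_request("SetConfig")
    def put(self):
        return {"updated": True}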
def kubernetes_only(f): def kubernetes_only(f):
""" Aborts the request with a 400 if the app is not running on kubernetes """ """ Aborts the request with a 400 if the app is not running on kubernetes """
@wraps(f) @wraps(f)
def abort_if_not_kube(*args, **kwargs): def abort_if_not_kube(*args, **kwargs):
if not IS_KUBERNETES: if not IS_KUBERNETES:
abort(400) abort(400)
return f(*args, **kwargs) return f(*args, **kwargs)
return abort_if_not_kube return abort_if_not_kube
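Usage sketch for the guard, on a hypothetical Kubernetes-only endpoint:

@kubernetes_only
def get_deployment_status():
    # Reached only when KUBERNETES_SERVICE_HOST is set in the environment;
    # otherwise the decorator aborts the request with a 400.
    return {"status": "ok"}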
nickname = partial(add_method_metadata, 'nickname')
nickname = partial(add_method_metadata, "nickname")
import config_app.config_endpoints.api.discovery import config_app.config_endpoints.api.discovery
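For orientation, these decorators compose on ApiResource subclasses throughout the config app, as the later files in this commit show. A minimal sketch, assuming the ApiResource base class exported by this module and Flask's request in scope; the path, nickname, and schema below are hypothetical:

@resource("/v1/example")
class ExampleResource(ApiResource):
    # Hypothetical resource, not part of this commit.
    schemas = {
        "ExampleRequest": {
            "type": "object",
            "required": ["value"],
            "properties": {"value": {"type": "string"}},
        }
    }

    @kubernetes_only  # aborts with 400 unless running under Kubernetes
    @nickname("exampleUpdate")  # recorded as the Swagger operationId
    @validate_json_request("ExampleRequest")  # 400s on schema violations
    def put(self):
        return {"value": request.get_json()["value"]}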
@@ -5,7 +5,11 @@ from collections import OrderedDict
from config_app.c_app import app
from config_app.config_endpoints.api import method_metadata
from config_app.config_endpoints.common import (
    fully_qualified_name,
    PARAM_REGEX,
    TYPE_CONVERTER,
)

logger = logging.getLogger(__name__)
@@ -14,24 +18,25 @@ def generate_route_data():
    include_internal = True
    compact = True

    def swagger_parameter(
        name,
        description,
        kind="path",
        param_type="string",
        required=True,
        enum=None,
        schema=None,
    ):
        # https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#parameterObject
        parameter_info = {"name": name, "in": kind, "required": required}

        if schema:
            parameter_info["schema"] = {"$ref": "#/definitions/%s" % schema}
        else:
            parameter_info["type"] = param_type

            if enum is not None and len(list(enum)) > 0:
                parameter_info["enum"] = list(enum)

        return parameter_info
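For illustration, a hypothetical call and the dict it builds, shown as if the helper were importable; in this module it is local to generate_route_data:

# A query-string boolean parameter; "public" and its description are invented.
info = swagger_parameter(
    "public", "Whether to list public repos", kind="query",
    param_type="boolean", required=False,
)
# info == {"name": "public", "in": "query", "required": False, "type": "boolean"}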
@@ -45,210 +50,208 @@ def generate_route_data():
        endpoint_method = app.view_functions[rule.endpoint]

        # Verify that we have a view class for this API method.
        if not "view_class" in dir(endpoint_method):
            continue

        view_class = endpoint_method.view_class

        # Hide the class if it is internal.
        internal = method_metadata(view_class, "internal")
        if not include_internal and internal:
            continue

        # Build the tag.
        parts = fully_qualified_name(view_class).split(".")
        tag_name = parts[-2]
        if not tag_name in tags_added:
            tags_added.add(tag_name)
            tags.append(
                {
                    "name": tag_name,
                    "description": (
                        sys.modules[view_class.__module__].__doc__ or ""
                    ).strip(),
                }
            )

        # Build the Swagger data for the path.
        swagger_path = PARAM_REGEX.sub(r"{\2}", rule.rule)
        full_name = fully_qualified_name(view_class)
        path_swagger = {"x-name": full_name, "x-path": swagger_path, "x-tag": tag_name}

        related_user_res = method_metadata(view_class, "related_user_resource")
        if related_user_res is not None:
            path_swagger["x-user-related"] = fully_qualified_name(related_user_res)

        paths[swagger_path] = path_swagger

        # Add any global path parameters.
        param_data_map = (
            view_class.__api_path_params
            if "__api_path_params" in dir(view_class)
            else {}
        )
        if param_data_map:
            path_parameters_swagger = []
            for path_parameter in param_data_map:
                description = param_data_map[path_parameter].get("description")
                path_parameters_swagger.append(
                    swagger_parameter(path_parameter, description)
                )

            path_swagger["parameters"] = path_parameters_swagger

        # Add the individual HTTP operations.
        method_names = list(rule.methods.difference(["HEAD", "OPTIONS"]))
        for method_name in method_names:
            # https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#operation-object
            method = getattr(view_class, method_name.lower(), None)
            if method is None:
                logger.debug(
                    "Unable to find method for %s in class %s", method_name, view_class
                )
                continue

            operationId = method_metadata(method, "nickname")
            operation_swagger = {"operationId": operationId, "parameters": []}

            if operationId is None:
                continue

            if operationId in operation_ids:
                raise Exception("Duplicate operation Id: %s" % operationId)

            operation_ids.add(operationId)

            # Mark the method as internal.
            internal = method_metadata(method, "internal")
            if internal is not None:
                operation_swagger["x-internal"] = True

            if include_internal:
                requires_fresh_login = method_metadata(method, "requires_fresh_login")
                if requires_fresh_login is not None:
                    operation_swagger["x-requires-fresh-login"] = True

            # Add the path parameters.
            if rule.arguments:
                for path_parameter in rule.arguments:
                    description = param_data_map.get(path_parameter, {}).get(
                        "description"
                    )
                    operation_swagger["parameters"].append(
                        swagger_parameter(path_parameter, description)
                    )

            # Add the query parameters.
            if "__api_query_params" in dir(method):
                for query_parameter_info in method.__api_query_params:
                    name = query_parameter_info["name"]
                    description = query_parameter_info["help"]
                    param_type = TYPE_CONVERTER[query_parameter_info["type"]]
                    required = query_parameter_info["required"]

                    operation_swagger["parameters"].append(
                        swagger_parameter(
                            name,
                            description,
                            kind="query",
                            param_type=param_type,
                            required=required,
                            enum=query_parameter_info["choices"],
                        )
                    )

            # Add the OAuth security block.
            # https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#securityRequirementObject
            scope = method_metadata(method, "oauth2_scope")
            if scope and not compact:
                operation_swagger["security"] = [{"oauth2_implicit": [scope.scope]}]

            # Add the responses block.
            # https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#responsesObject
            response_schema_name = method_metadata(method, "response_schema")
            if not compact:
                if response_schema_name:
                    models[response_schema_name] = view_class.schemas[
                        response_schema_name
                    ]

                models["ApiError"] = {
                    "type": "object",
                    "properties": {
                        "status": {
                            "type": "integer",
                            "description": "Status code of the response.",
                        },
                        "type": {
                            "type": "string",
                            "description": "Reference to the type of the error.",
                        },
                        "detail": {
                            "type": "string",
                            "description": "Details about the specific instance of the error.",
                        },
                        "title": {
                            "type": "string",
                            "description": "Unique error code to identify the type of error.",
                        },
                        "error_message": {
                            "type": "string",
                            "description": "Deprecated; alias for detail",
                        },
                        "error_type": {
                            "type": "string",
                            "description": "Deprecated; alias for detail",
                        },
                    },
                    "required": ["status", "type", "title"],
                }

                responses = {
                    "400": {"description": "Bad Request"},
                    "401": {"description": "Session required"},
                    "403": {"description": "Unauthorized access"},
                    "404": {"description": "Not found"},
                }

                for _, body in responses.items():
                    body["schema"] = {"$ref": "#/definitions/ApiError"}

                if method_name == "DELETE":
                    responses["204"] = {"description": "Deleted"}
                elif method_name == "POST":
                    responses["201"] = {"description": "Successful creation"}
                else:
                    responses["200"] = {"description": "Successful invocation"}

                    if response_schema_name:
                        responses["200"]["schema"] = {
                            "$ref": "#/definitions/%s" % response_schema_name
                        }

                operation_swagger["responses"] = responses

            # Add the request block.
            request_schema_name = method_metadata(method, "request_schema")
            if request_schema_name and not compact:
                models[request_schema_name] = view_class.schemas[request_schema_name]

                operation_swagger["parameters"].append(
                    swagger_parameter(
                        "body",
                        "Request body contents.",
                        kind="body",
                        schema=request_schema_name,
                    )
                )

            # Add the operation to the parent path.
            if not internal or (internal and include_internal):
                path_swagger[method_name.lower()] = operation_swagger

    tags.sort(key=lambda t: t["name"])
    paths = OrderedDict(sorted(paths.items(), key=lambda p: p[1]["x-tag"]))

    if compact:
        return {"paths": paths}
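A hedged usage sketch: with compact hard-coded to True above, the generator returns only the path map, which a caller might serialize for the config tool's frontend. The print target is illustrative:

import json

route_data = generate_route_data()
print(json.dumps(route_data["paths"], indent=2))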
@@ -6,83 +6,95 @@ from config_app.config_util.config import get_config_as_kube_secret
from data.database import configure

from config_app.c_app import app, config_provider
from config_app.config_endpoints.api import (
    resource,
    ApiResource,
    nickname,
    kubernetes_only,
    validate_json_request,
)
from config_app.config_util.k8saccessor import (
    KubernetesAccessorSingleton,
    K8sApiException,
)

logger = logging.getLogger(__name__)


@resource("/v1/kubernetes/deployments/")
class SuperUserKubernetesDeployment(ApiResource):
    """ Resource for getting the status of Red Hat Quay deployments and cycling them """

    schemas = {
        "ValidateDeploymentNames": {
            "type": "object",
            "description": "Validates deployment names for cycling",
            "required": ["deploymentNames"],
            "properties": {
                "deploymentNames": {
                    "type": "array",
                    "description": "The names of the deployments to cycle",
                }
            },
        }
    }

    @kubernetes_only
    @nickname("scGetNumDeployments")
    def get(self):
        return KubernetesAccessorSingleton.get_instance().get_qe_deployments()

    @kubernetes_only
    @validate_json_request("ValidateDeploymentNames")
    @nickname("scCycleQEDeployments")
    def put(self):
        deployment_names = request.get_json()["deploymentNames"]
        return KubernetesAccessorSingleton.get_instance().cycle_qe_deployments(
            deployment_names
        )


@resource("/v1/kubernetes/deployment/<deployment>/status")
class QEDeploymentRolloutStatus(ApiResource):
    @kubernetes_only
    @nickname("scGetDeploymentRolloutStatus")
    def get(self, deployment):
        deployment_rollout_status = KubernetesAccessorSingleton.get_instance().get_deployment_rollout_status(
            deployment
        )
        return {
            "status": deployment_rollout_status.status,
            "message": deployment_rollout_status.message,
        }
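A hypothetical client call against this endpoint; the host, port, and deployment name are assumptions, not part of this change:

import requests

resp = requests.get(
    "http://localhost:8080/v1/kubernetes/deployment/quay-enterprise/status"
)
print(resp.json())  # {"status": ..., "message": ...}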
@resource("/v1/kubernetes/deployments/rollback")
class QEDeploymentRollback(ApiResource):
    """ Resource for rolling back deployments """

    schemas = {
        "ValidateDeploymentNames": {
            "type": "object",
            "description": "Validates deployment names for rolling back",
            "required": ["deploymentNames"],
            "properties": {
                "deploymentNames": {
                    "type": "array",
                    "description": "The names of the deployments to rollback",
                }
            },
        }
    }

    @kubernetes_only
    @nickname("scRollbackDeployments")
    @validate_json_request("ValidateDeploymentNames")
    def post(self):
        """
        Returns the config to its original state and rolls back deployments
        :return:
        """
        deployment_names = request.get_json()["deploymentNames"]

        # To roll back a deployment, we must do 2 things:
        # 1. Roll back the config secret to its old value (discarding changes we made in this session)
@@ -96,35 +108,37 @@ class QEDeploymentRollback(ApiResource):
            for name in deployment_names:
                kube_accessor.rollback_deployment(name)
        except K8sApiException as e:
            logger.exception("Failed to rollback deployment.")
            return make_response(e.message, 503)

        return make_response("Ok", 204)
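A hypothetical invocation of the rollback endpoint; the deployment name and host are illustrative. A 204 ("Ok") means every requested rollback succeeded:

import requests

resp = requests.post(
    "http://localhost:8080/v1/kubernetes/deployments/rollback",
    json={"deploymentNames": ["quay-enterprise"]},
)
assert resp.status_code == 204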
@resource("/v1/kubernetes/config")
class SuperUserKubernetesConfiguration(ApiResource):
    """ Resource for saving the config files to kubernetes secrets. """

    @kubernetes_only
    @nickname("scDeployConfiguration")
    def post(self):
        try:
            new_secret = get_config_as_kube_secret(
                config_provider.get_config_dir_path()
            )
            KubernetesAccessorSingleton.get_instance().replace_qe_secret(new_secret)
        except K8sApiException as e:
            logger.exception("Failed to deploy qe config secret to kubernetes.")
            return make_response(e.message, 503)

        return make_response("Ok", 201)


@resource("/v1/kubernetes/config/populate")
class KubernetesConfigurationPopulator(ApiResource):
    """ Resource for populating the local configuration from the cluster's kubernetes secrets. """

    @kubernetes_only
    @nickname("scKubePopulateConfig")
    def post(self):
        # Get a clean transient directory to write the config into
        config_provider.new_config_dir()
@@ -2,16 +2,32 @@ import logging
from flask import abort, request

from config_app.config_endpoints.api.suconfig_models_pre_oci import (
    pre_oci_model as model,
)
from config_app.config_endpoints.api import (
    resource,
    ApiResource,
    nickname,
    validate_json_request,
)
from config_app.c_app import (
    app,
    config_provider,
    superusers,
    ip_resolver,
    instance_keys,
    INIT_SCRIPTS_LOCATION,
)

from data.database import configure
from data.runmigration import run_alembic_migration

from util.config.configutil import add_enterprise_config_defaults
from util.config.validator import (
    validate_service_for_config,
    ValidatorContext,
    is_valid_config_upload_filename,
)

logger = logging.getLogger(__name__)
@@ -26,45 +42,38 @@ def database_has_users():
    return model.has_users()


@resource("/v1/superuser/config")
class SuperUserConfig(ApiResource):
    """ Resource for fetching and updating the current configuration, if any. """

    schemas = {
        "UpdateConfig": {
            "type": "object",
            "description": "Updates the YAML config file",
            "required": ["config"],
            "properties": {
                "config": {"type": "object"},
                "password": {"type": "string"},
            },
        }
    }

    @nickname("scGetConfig")
    def get(self):
        """ Returns the currently defined configuration, if any. """
        config_object = config_provider.get_config()
        return {"config": config_object}

    @nickname("scUpdateConfig")
    @validate_json_request("UpdateConfig")
    def put(self):
        """ Updates the config override file. """
        # Note: This method is called to set the database configuration before any super users exist,
        # so we also allow it to be called if there is no valid registry configuration setup.
        config_object = request.get_json()["config"]

        # Add any enterprise defaults missing from the config.
        add_enterprise_config_defaults(config_object, app.config["SECRET_KEY"])

        # Write the configuration changes to the config override file.
        config_provider.save_config(config_object)
@@ -72,44 +81,33 @@ class SuperUserConfig(ApiResource):
        # now try to connect to the db provided in their config to validate it works
        combined = dict(**app.config)
        combined.update(config_provider.get_config())
        configure(combined, testing=app.config["TESTING"])

        return {"exists": True, "config": config_object}
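A hedged example of a payload accepted by the UpdateConfig schema above; the config key and values are illustrative:

import requests

requests.put(
    "http://localhost:8080/v1/superuser/config",
    json={
        "config": {"SERVER_HOSTNAME": "quay.example.com"},
        "password": "hunter2",
    },
)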
@resource("/v1/superuser/registrystatus")
class SuperUserRegistryStatus(ApiResource):
    """ Resource for determining the status of the registry, such as if config exists,
        if a database is configured, and if it has any defined users.
    """

    @nickname("scRegistryStatus")
    def get(self):
        """ Returns the status of the registry. """
        # If there is no config file, we need to setup the database.
        if not config_provider.config_exists():
            return {"status": "config-db"}

        # If the database isn't yet valid, then we need to set it up.
        if not database_is_valid():
            return {"status": "setup-db"}

        config = config_provider.get_config()
        if config and config.get("SETUP_COMPLETE"):
            return {"status": "config"}

        return {"status": "create-superuser" if not database_has_users() else "config"}
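The checks above walk the setup states in a fixed order, so a client can treat the status as a small state machine. A sketch of one possible mapping; the step descriptions are illustrative:

# Hypothetical client-side mapping from status to the next setup step.
SETUP_STEPS = {
    "config-db": "ask for database connection settings",
    "setup-db": "run the database migrations",
    "create-superuser": "prompt for the initial super user",
    "config": "show the main configuration screen",
}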
class _AlembicLogHandler(logging.Handler):
@@ -118,10 +116,7 @@ class _AlembicLogHandler(logging.Handler):
        self.records = []

    def emit(self, record):
        self.records.append({"level": record.levelname, "message": record.getMessage()})


def _reload_config():
@@ -131,11 +126,11 @@ def _reload_config():
    return combined


@resource("/v1/superuser/setupdb")
class SuperUserSetupDatabase(ApiResource):
    """ Resource for invoking alembic to setup the database. """

    @nickname("scSetupDatabase")
    def get(self):
        """ Invokes the alembic upgrade process. """
        # Note: This method is called after the database configuration is saved, but before the
@@ -143,57 +138,50 @@ class SuperUserSetupDatabase(ApiResource):
        if config_provider.config_exists() and not database_is_valid():
            combined = _reload_config()

            app.config["DB_URI"] = combined["DB_URI"]
            db_uri = app.config["DB_URI"]
            escaped_db_uri = db_uri.replace("%", "%%")

            log_handler = _AlembicLogHandler()

            try:
                run_alembic_migration(escaped_db_uri, log_handler, setup_app=False)
            except Exception as ex:
                return {"error": str(ex)}

            return {"logs": log_handler.records}

        abort(403)


@resource("/v1/superuser/config/createsuperuser")
class SuperUserCreateInitialSuperUser(ApiResource):
    """ Resource for creating the initial super user. """

    schemas = {
        "CreateSuperUser": {
            "type": "object",
            "description": "Information for creating the initial super user",
            "required": ["username", "password", "email"],
            "properties": {
                "username": {
                    "type": "string",
                    "description": "The username for the superuser",
                },
                "password": {
                    "type": "string",
                    "description": "The password for the superuser",
                },
                "email": {
                    "type": "string",
                    "description": "The e-mail address for the superuser",
                },
            },
        }
    }

    @nickname("scCreateInitialSuperuser")
    @validate_json_request("CreateSuperUser")
    def post(self):
        """ Creates the initial super user, updates the underlying configuration and
            sets the current session to have that super user. """
@@ -208,83 +196,79 @@ class SuperUserCreateInitialSuperUser(ApiResource):
        # is clean but does not (yet) have any super users for our permissions code to check against.
        if config_provider.config_exists() and not database_has_users():
            data = request.get_json()
            username = data["username"]
            password = data["password"]
            email = data["email"]

            # Create the user in the database.
            superuser_uuid = model.create_superuser(username, password, email)

            # Add the user to the config.
            config_object = config_provider.get_config()
            config_object["SUPER_USERS"] = [username]
            config_provider.save_config(config_object)

            # Update the in-memory config for the new superuser.
            superusers.register_superuser(username)

            return {"status": True}

        abort(403)
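A hypothetical request satisfying the CreateSuperUser schema; all values are illustrative, and the endpoint aborts with 403 once any user exists:

import requests

requests.post(
    "http://localhost:8080/v1/superuser/config/createsuperuser",
    json={
        "username": "admin",
        "password": "changeme",
        "email": "admin@example.com",
    },
)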
@resource("/v1/superuser/config/validate/<service>")
class SuperUserConfigValidate(ApiResource):
    """ Resource for validating a block of configuration against an external service. """

    schemas = {
        "ValidateConfig": {
            "type": "object",
            "description": "Validates configuration",
            "required": ["config"],
            "properties": {
                "config": {"type": "object"},
                "password": {
                    "type": "string",
                    "description": "The user's password, used for auth validation",
                },
            },
        }
    }

    @nickname("scValidateConfig")
    @validate_json_request("ValidateConfig")
    def post(self, service):
        """ Validates the given config for the given service. """
        # Note: This method is called to validate the database configuration before super users exist,
        # so we also allow it to be called if there is no valid registry configuration setup. Note that
        # this is also safe since this method does not access any information not given in the request.
        config = request.get_json()["config"]
        validator_context = ValidatorContext.from_app(
            app,
            config,
            request.get_json().get("password", ""),
            instance_keys=instance_keys,
            ip_resolver=ip_resolver,
            config_provider=config_provider,
            init_scripts_location=INIT_SCRIPTS_LOCATION,
        )

        return validate_service_for_config(service, validator_context)


@resource("/v1/superuser/config/file/<filename>")
class SuperUserConfigFile(ApiResource):
    """ Resource for fetching the status of config files and overriding them. """

    @nickname("scConfigFileExists")
    def get(self, filename):
        """ Returns whether the configuration file with the given name exists. """
        if not is_valid_config_upload_filename(filename):
            abort(404)

        return {"exists": config_provider.volume_file_exists(filename)}

    @nickname("scUpdateConfigFile")
    def post(self, filename):
        """ Updates the configuration file with the given name. """
        if not is_valid_config_upload_filename(filename):
@@ -292,11 +276,9 @@ class SuperUserConfigFile(ApiResource):
        # Note: This method can be called before the configuration exists
        # to upload the database SSL cert.
        uploaded_file = request.files["file"]
        if not uploaded_file:
            abort(400)

        config_provider.save_volume_file(filename, uploaded_file)
        return {"status": True}
@@ -1,6 +1,8 @@
from data import model
from data.database import User

from config_app.config_endpoints.api.suconfig_models_interface import (
    SuperuserConfigDataInterface,
)


class PreOCIModel(SuperuserConfigDataInterface):
@@ -12,7 +12,13 @@ from data.model import ServiceKeyDoesNotExist
from util.config.validator import EXTRA_CA_DIRECTORY

from config_app.config_endpoints.exception import InvalidRequest
from config_app.config_endpoints.api import (
    resource,
    ApiResource,
    nickname,
    log_action,
    validate_json_request,
)
from config_app.config_endpoints.api.superuser_models_pre_oci import pre_oci_model
from config_app.config_util.ssl import load_certificate, CertInvalidException
from config_app.c_app import app, config_provider, INIT_SCRIPTS_LOCATION
@@ -21,228 +27,233 @@ from config_app.c_app import app, config_provider, INIT_SCRIPTS_LOCATION
logger = logging.getLogger(__name__)


@resource("/v1/superuser/customcerts/<certpath>")
class SuperUserCustomCertificate(ApiResource):
    """ Resource for managing a custom certificate. """

    @nickname("uploadCustomCertificate")
    def post(self, certpath):
        uploaded_file = request.files["file"]
        if not uploaded_file:
            raise InvalidRequest("Missing certificate file")

        # Save the certificate.
        certpath = pathvalidate.sanitize_filename(certpath)
        if not certpath.endswith(".crt"):
            raise InvalidRequest("Invalid certificate file: must have suffix `.crt`")

        logger.debug("Saving custom certificate %s", certpath)
        cert_full_path = config_provider.get_volume_path(EXTRA_CA_DIRECTORY, certpath)
        config_provider.save_volume_file(cert_full_path, uploaded_file)
        logger.debug("Saved custom certificate %s", certpath)

        # Validate the certificate.
        try:
            logger.debug("Loading custom certificate %s", certpath)
            with config_provider.get_volume_file(cert_full_path) as f:
                load_certificate(f.read())
        except CertInvalidException:
            logger.exception("Got certificate invalid error for cert %s", certpath)
            return "", 204
        except IOError:
            logger.exception("Got IO error for cert %s", certpath)
            return "", 204

        # Call the update script with config dir location to install the certificate immediately.
        if not app.config["TESTING"]:
            cert_dir = os.path.join(
                config_provider.get_config_dir_path(), EXTRA_CA_DIRECTORY
            )
            if (
                subprocess.call(
                    [os.path.join(INIT_SCRIPTS_LOCATION, "certs_install.sh")],
                    env={"CERTDIR": cert_dir},
                )
                != 0
            ):
                raise Exception("Could not install certificates")

        return "", 204

    @nickname("deleteCustomCertificate")
    def delete(self, certpath):
        cert_full_path = config_provider.get_volume_path(EXTRA_CA_DIRECTORY, certpath)
        config_provider.remove_volume_file(cert_full_path)
        return "", 204


@resource("/v1/superuser/customcerts")
class SuperUserCustomCertificates(ApiResource):
    """ Resource for managing custom certificates. """

    @nickname("getCustomCertificates")
    def get(self):
        has_extra_certs_path = config_provider.volume_file_exists(EXTRA_CA_DIRECTORY)
        extra_certs_found = config_provider.list_volume_directory(EXTRA_CA_DIRECTORY)
        if extra_certs_found is None:
            return {"status": "file" if has_extra_certs_path else "none"}

        cert_views = []
        for extra_cert_path in extra_certs_found:
            try:
                cert_full_path = config_provider.get_volume_path(
                    EXTRA_CA_DIRECTORY, extra_cert_path
                )
                with config_provider.get_volume_file(cert_full_path) as f:
                    certificate = load_certificate(f.read())
                    cert_views.append(
                        {
                            "path": extra_cert_path,
                            "names": list(certificate.names),
                            "expired": certificate.expired,
                        }
                    )
            except CertInvalidException as cie:
                cert_views.append({"path": extra_cert_path, "error": cie.message})
            except IOError as ioe:
                cert_views.append({"path": extra_cert_path, "error": ioe.message})

        return {"status": "directory", "certs": cert_views}
@resource("/v1/superuser/keys")
class SuperUserServiceKeyManagement(ApiResource):
    """ Resource for managing service keys."""

    schemas = {
        "CreateServiceKey": {
            "id": "CreateServiceKey",
            "type": "object",
            "description": "Description of creation of a service key",
            "required": ["service", "expiration"],
            "properties": {
                "service": {
                    "type": "string",
                    "description": "The service authenticating with this key",
                },
                "name": {
                    "type": "string",
                    "description": "The friendly name of a service key",
                },
                "metadata": {
                    "type": "object",
                    "description": "The key/value pairs of this key's metadata",
                },
                "notes": {
                    "type": "string",
                    "description": "If specified, the extra notes for the key",
                },
                "expiration": {
                    "description": "The expiration date as a unix timestamp",
                    "anyOf": [{"type": "number"}, {"type": "null"}],
                },
            },
        }
    }

    @nickname("listServiceKeys")
    def get(self):
        keys = pre_oci_model.list_all_service_keys()

        return jsonify({"keys": [key.to_dict() for key in keys]})

    @nickname("createServiceKey")
    @validate_json_request("CreateServiceKey")
    def post(self):
        body = request.get_json()

        # Ensure we have a valid expiration date if specified.
        expiration_date = body.get("expiration", None)
        if expiration_date is not None:
            try:
                expiration_date = datetime.utcfromtimestamp(float(expiration_date))
            except ValueError as ve:
                raise InvalidRequest("Invalid expiration date: %s" % ve)

            if expiration_date <= datetime.now():
                raise InvalidRequest("Expiration date cannot be in the past")

        # Create the metadata for the key.
        metadata = body.get("metadata", {})
        metadata.update(
            {"created_by": "Quay Superuser Panel", "ip": request.remote_addr}
        )

        # Generate a key with a private key that we *never save*.
        (private_key, key_id) = pre_oci_model.generate_service_key(
            body["service"],
            expiration_date,
            metadata=metadata,
            name=body.get("name", ""),
        )

        # Auto-approve the service key.
        pre_oci_model.approve_service_key(
            key_id, ServiceKeyApprovalType.SUPERUSER, notes=body.get("notes", "")
        )

        # Log the creation and auto-approval of the service key.
        key_log_metadata = {
            "kid": key_id,
            "preshared": True,
            "service": body["service"],
            "name": body.get("name", ""),
            "expiration_date": expiration_date,
            "auto_approved": True,
        }

        log_action("service_key_create", None, key_log_metadata)
        log_action("service_key_approve", None, key_log_metadata)

        return jsonify(
            {
                "kid": key_id,
                "name": body.get("name", ""),
                "service": body["service"],
                "public_key": private_key.publickey().exportKey("PEM"),
                "private_key": private_key.exportKey("PEM"),
            }
        )
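A hypothetical createServiceKey request; the service name and expiration are illustrative (expiration may also be null, per the schema's anyOf):

import time

import requests

requests.post(
    "http://localhost:8080/v1/superuser/keys",
    json={
        "service": "quay",
        "name": "docs-example-key",
        "expiration": time.time() + 86400,  # one day from now
        "notes": "created from the docs example",
    },
)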
@resource("/v1/superuser/approvedkeys/<kid>")
class SuperUserServiceKeyApproval(ApiResource):
    """ Resource for approving service keys. """

    schemas = {
        "ApproveServiceKey": {
            "id": "ApproveServiceKey",
            "type": "object",
            "description": "Information for approving service keys",
            "properties": {
                "notes": {"type": "string", "description": "Optional approval notes"}
            },
        }
    }

    @nickname("approveServiceKey")
    @validate_json_request("ApproveServiceKey")
    def post(self, kid):
        notes = request.get_json().get("notes", "")
        try:
            key = pre_oci_model.approve_service_key(
                kid, ServiceKeyApprovalType.SUPERUSER, notes=notes
            )

            # Log the approval of the service key.
            key_log_metadata = {
                "kid": kid,
                "service": key.service,
                "name": key.name,
                "expiration_date": key.expiration_date,
            }

            # Note: this may not actually be the current person modifying the config, but if they're in the config tool,
            # they have full access to the DB and could pretend to be any user, so pulling any superuser is likely fine
            super_user = app.config.get("SUPER_USERS", [None])[0]
            log_action("service_key_approve", super_user, key_log_metadata)
        except ServiceKeyDoesNotExist:
            raise NotFound()
        except ServiceKeyAlreadyApproved:
            pass

        return make_response("", 201)
@@ -6,20 +6,32 @@ from config_app.config_endpoints.api import format_date


def user_view(user):
    return {"name": user.username, "kind": "user", "is_robot": user.robot}


class RepositoryBuild(
    namedtuple(
        "RepositoryBuild",
        [
            "uuid",
            "logs_archived",
            "repository_namespace_user_username",
            "repository_name",
            "can_write",
            "can_read",
            "pull_robot",
            "resource_key",
            "trigger",
            "display_name",
            "started",
            "job_config",
            "phase",
            "status",
            "error",
            "archive_url",
        ],
    )
):
    """
    RepositoryBuild represents a build associated with a repository
    :type uuid: string
@@ -43,38 +55,42 @@ class RepositoryBuild(namedtuple('RepositoryBuild',
    def to_dict(self):
        resp = {
            "id": self.uuid,
            "phase": self.phase,
            "started": format_date(self.started),
            "display_name": self.display_name,
            "status": self.status or {},
            "subdirectory": self.job_config.get("build_subdir", ""),
            "dockerfile_path": self.job_config.get("build_subdir", ""),
            "context": self.job_config.get("context", ""),
            "tags": self.job_config.get("docker_tags", []),
            "manual_user": self.job_config.get("manual_user", None),
            "is_writer": self.can_write,
            "trigger": self.trigger.to_dict(),
            "trigger_metadata": self.job_config.get("trigger_metadata", None)
            if self.can_read
            else None,
            "resource_key": self.resource_key,
            "pull_robot": user_view(self.pull_robot) if self.pull_robot else None,
            "repository": {
                "namespace": self.repository_namespace_user_username,
                "name": self.repository_name,
            },
            "error": self.error,
        }

        if self.can_write:
            if self.resource_key is not None:
                resp["archive_url"] = self.archive_url
            elif self.job_config.get("archive_url", None):
                resp["archive_url"] = self.job_config["archive_url"]

        return resp


class Approval(
    namedtuple("Approval", ["approver", "approval_type", "approved_date", "notes"])
):
    """
    Approval represents whether a key has been approved or not
    :type approver: User
@@ -85,16 +101,29 @@ class Approval(namedtuple('Approval', ['approver', 'approval_type', 'approved_da
    def to_dict(self):
        return {
            "approver": self.approver.to_dict() if self.approver else None,
            "approval_type": self.approval_type,
            "approved_date": self.approved_date,
            "notes": self.notes,
        }


class ServiceKey(
    namedtuple(
        "ServiceKey",
        [
            "name",
            "kid",
            "service",
            "jwk",
            "metadata",
            "created_date",
            "expiration_date",
            "rotation_duration",
            "approval",
        ],
    )
):
    """
    ServiceKey is an apostille signing key
    :type name: string
@@ -111,19 +140,19 @@ class ServiceKey(
    def to_dict(self):
        return {
            "name": self.name,
            "kid": self.kid,
            "service": self.service,
            "jwk": self.jwk,
            "metadata": self.metadata,
            "created_date": self.created_date,
            "expiration_date": self.expiration_date,
            "rotation_duration": self.rotation_duration,
            "approval": self.approval.to_dict() if self.approval is not None else None,
        }


class User(namedtuple("User", ["username", "email", "verified", "enabled", "robot"])):
    """
    User represents a single user.
    :type username: string
@@ -135,18 +164,18 @@ class User(namedtuple('User', ['username', 'email', 'verified', 'enabled', 'robo
    def to_dict(self):
        user_data = {
            "kind": "user",
            "name": self.username,
            "username": self.username,
            "email": self.email,
            "verified": self.verified,
            "enabled": self.enabled,
        }

        return user_data


class Organization(namedtuple("Organization", ["username", "email"])):
    """
    Organization represents a single org.
    :type username: string
@@ -154,10 +183,7 @@ class Organization(namedtuple('Organization', ['username', 'email'])):
    """

    def to_dict(self):
        return {"name": self.username, "email": self.email}


@add_metaclass(ABCMeta)
@ -1,7 +1,11 @@
from data import model from data import model
from config_app.config_endpoints.api.superuser_models_interface import (SuperuserDataInterface, User, ServiceKey, from config_app.config_endpoints.api.superuser_models_interface import (
Approval) SuperuserDataInterface,
User,
ServiceKey,
Approval,
)
def _create_user(user): def _create_user(user):
@@ -13,13 +17,24 @@ def _create_user(user):
def _create_key(key): def _create_key(key):
approval = None approval = None
if key.approval is not None: if key.approval is not None:
approval = Approval(_create_user(key.approval.approver), key.approval.approval_type, approval = Approval(
_create_user(key.approval.approver),
key.approval.approval_type,
key.approval.approved_date, key.approval.approved_date,
key.approval.notes) key.approval.notes,
)
return ServiceKey(key.name, key.kid, key.service, key.jwk, key.metadata, key.created_date, return ServiceKey(
key.name,
key.kid,
key.service,
key.jwk,
key.metadata,
key.created_date,
key.expiration_date, key.expiration_date,
key.rotation_duration, approval) key.rotation_duration,
approval,
)
class ServiceKeyDoesNotExist(Exception): class ServiceKeyDoesNotExist(Exception):
@@ -40,19 +55,29 @@ class PreOCIModel(SuperuserDataInterface):
keys = model.service_keys.list_all_keys() keys = model.service_keys.list_all_keys()
return [_create_key(key) for key in keys] return [_create_key(key) for key in keys]
def approve_service_key(self, kid, approval_type, notes=''): def approve_service_key(self, kid, approval_type, notes=""):
try: try:
key = model.service_keys.approve_service_key(kid, approval_type, notes=notes) key = model.service_keys.approve_service_key(
kid, approval_type, notes=notes
)
return _create_key(key) return _create_key(key)
except model.ServiceKeyDoesNotExist: except model.ServiceKeyDoesNotExist:
raise ServiceKeyDoesNotExist raise ServiceKeyDoesNotExist
except model.ServiceKeyAlreadyApproved: except model.ServiceKeyAlreadyApproved:
raise ServiceKeyAlreadyApproved raise ServiceKeyAlreadyApproved
def generate_service_key(self, service, expiration_date, kid=None, name='', metadata=None, def generate_service_key(
rotation_duration=None): self,
(private_key, key) = model.service_keys.generate_service_key(service, expiration_date, service,
metadata=metadata, name=name) expiration_date,
kid=None,
name="",
metadata=None,
rotation_duration=None,
):
(private_key, key) = model.service_keys.generate_service_key(
service, expiration_date, metadata=metadata, name=name
)
return private_key, key.kid return private_key, key.kid
View file
@@ -10,29 +10,32 @@ from data.database import configure
from config_app.c_app import app, config_provider from config_app.c_app import app, config_provider
from config_app.config_endpoints.api import resource, ApiResource, nickname from config_app.config_endpoints.api import resource, ApiResource, nickname
from config_app.config_util.tar import tarinfo_filter_partial, strip_absolute_path_and_add_trailing_dir from config_app.config_util.tar import (
tarinfo_filter_partial,
strip_absolute_path_and_add_trailing_dir,
)
@resource('/v1/configapp/initialization') @resource("/v1/configapp/initialization")
class ConfigInitialization(ApiResource): class ConfigInitialization(ApiResource):
""" """
Resource for dealing with any initialization logic for the config app Resource for dealing with any initialization logic for the config app
""" """
@nickname('scStartNewConfig') @nickname("scStartNewConfig")
def post(self): def post(self):
config_provider.new_config_dir() config_provider.new_config_dir()
return make_response('OK') return make_response("OK")
@resource('/v1/configapp/tarconfig') @resource("/v1/configapp/tarconfig")
class TarConfigLoader(ApiResource): class TarConfigLoader(ApiResource):
""" """
Resource for dealing with configuration as a tarball, Resource for dealing with configuration as a tarball,
including loading and generating functions including loading and generating functions
""" """
@nickname('scGetConfigTarball') @nickname("scGetConfigTarball")
def get(self): def get(self):
config_path = config_provider.get_config_dir_path() config_path = config_provider.get_config_dir_path()
tar_dir_prefix = strip_absolute_path_and_add_trailing_dir(config_path) tar_dir_prefix = strip_absolute_path_and_add_trailing_dir(config_path)
@@ -40,10 +43,13 @@ class TarConfigLoader(ApiResource):
with closing(tarfile.open(temp.name, mode="w|gz")) as tar: with closing(tarfile.open(temp.name, mode="w|gz")) as tar:
for name in os.listdir(config_path): for name in os.listdir(config_path):
tar.add(os.path.join(config_path, name), filter=tarinfo_filter_partial(tar_dir_prefix)) tar.add(
return send_file(temp.name, mimetype='application/gzip') os.path.join(config_path, name),
filter=tarinfo_filter_partial(tar_dir_prefix),
)
return send_file(temp.name, mimetype="application/gzip")
@nickname('scUploadTarballConfig') @nickname("scUploadTarballConfig")
def put(self): def put(self):
""" Loads tarball config into the config provider """ """ Loads tarball config into the config provider """
# Generate a new empty dir to load the config into # Generate a new empty dir to load the config into
@@ -59,4 +65,4 @@ class TarConfigLoader(ApiResource):
combined.update(config_provider.get_config()) combined.update(config_provider.get_config())
configure(combined) configure(combined)
return make_response('OK') return make_response("OK")
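The get() handler above packs the live config directory into a gzipped tarball before streaming it back. A standalone sketch of the same tarfile pattern, assuming a config directory path and omitting the tarinfo_filter_partial path rewriting:

    import os
    import tarfile
    from contextlib import closing
    from tempfile import NamedTemporaryFile

    config_path = "/tmp/quay-config"  # assumed for illustration

    temp = NamedTemporaryFile(suffix=".tar.gz", delete=False)
    # Mode "w|gz" writes a gzip-compressed stream; closing() guarantees the
    # archive is finalized before the temp file is sent to the client.
    with closing(tarfile.open(temp.name, mode="w|gz")) as tar:
        for name in os.listdir(config_path):
            tar.add(os.path.join(config_path, name))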
View file
@@ -3,11 +3,11 @@ from config_app.config_endpoints.api import resource, ApiResource, nickname
from config_app.config_endpoints.api.superuser_models_interface import user_view from config_app.config_endpoints.api.superuser_models_interface import user_view
@resource('/v1/user/') @resource("/v1/user/")
class User(ApiResource): class User(ApiResource):
""" Operations related to users. """ """ Operations related to users. """
@nickname('getLoggedInUser') @nickname("getLoggedInUser")
def get(self): def get(self):
""" Get user information for the authenticated user. """ """ Get user information for the authenticated user. """
user = get_authenticated_user() user = get_authenticated_user()
View file
@@ -14,18 +14,18 @@ from config_app.config_util.k8sconfig import get_k8s_namespace
def truthy_bool(param): def truthy_bool(param):
return param not in {False, 'false', 'False', '0', 'FALSE', '', 'null'} return param not in {False, "false", "False", "0", "FALSE", "", "null"}
DEFAULT_JS_BUNDLE_NAME = 'configapp' DEFAULT_JS_BUNDLE_NAME = "configapp"
PARAM_REGEX = re.compile(r'<([^:>]+:)*([\w]+)>') PARAM_REGEX = re.compile(r"<([^:>]+:)*([\w]+)>")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
TYPE_CONVERTER = { TYPE_CONVERTER = {
truthy_bool: 'boolean', truthy_bool: "boolean",
str: 'string', str: "string",
basestring: 'string', basestring: "string",
reqparse.text_type: 'string', reqparse.text_type: "string",
int: 'integer', int: "integer",
} }
@@ -33,41 +33,53 @@ def _list_files(path, extension, contains=""):
""" Returns a list of all the files with the given extension found under the given path. """ """ Returns a list of all the files with the given extension found under the given path. """
def matches(f): def matches(f):
return os.path.splitext(f)[1] == '.' + extension and contains in os.path.splitext(f)[0] return (
os.path.splitext(f)[1] == "." + extension
and contains in os.path.splitext(f)[0]
)
def join_path(dp, f): def join_path(dp, f):
# Remove the static/ prefix. It is added in the template. # Remove the static/ prefix. It is added in the template.
return os.path.join(dp, f)[len(ROOT_DIR) + 1 + len('config_app/static/'):] return os.path.join(dp, f)[len(ROOT_DIR) + 1 + len("config_app/static/") :]
filepath = os.path.join(os.path.join(ROOT_DIR, 'config_app/static/'), path) filepath = os.path.join(os.path.join(ROOT_DIR, "config_app/static/"), path)
return [join_path(dp, f) for dp, _, files in os.walk(filepath) for f in files if matches(f)] return [
join_path(dp, f)
for dp, _, files in os.walk(filepath)
for f in files
if matches(f)
]
FONT_AWESOME_4 = 'netdna.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.css' FONT_AWESOME_4 = "netdna.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.css"
def render_page_template(name, route_data=None, js_bundle_name=DEFAULT_JS_BUNDLE_NAME, **kwargs): def render_page_template(
name, route_data=None, js_bundle_name=DEFAULT_JS_BUNDLE_NAME, **kwargs
):
""" Renders the page template with the given name as the response and returns its contents. """ """ Renders the page template with the given name as the response and returns its contents. """
main_scripts = _list_files('build', 'js', js_bundle_name) main_scripts = _list_files("build", "js", js_bundle_name)
use_cdn = os.getenv('TESTING') == 'true' use_cdn = os.getenv("TESTING") == "true"
external_styles = get_external_css(local=not use_cdn, exclude=FONT_AWESOME_4) external_styles = get_external_css(local=not use_cdn, exclude=FONT_AWESOME_4)
external_scripts = get_external_javascript(local=not use_cdn) external_scripts = get_external_javascript(local=not use_cdn)
contents = render_template(name, contents = render_template(
name,
route_data=route_data, route_data=route_data,
main_scripts=main_scripts, main_scripts=main_scripts,
external_styles=external_styles, external_styles=external_styles,
external_scripts=external_scripts, external_scripts=external_scripts,
config_set=frontend_visible_config(app.config), config_set=frontend_visible_config(app.config),
kubernetes_namespace=IS_KUBERNETES and get_k8s_namespace(), kubernetes_namespace=IS_KUBERNETES and get_k8s_namespace(),
**kwargs) **kwargs
)
resp = make_response(contents) resp = make_response(contents)
resp.headers['X-FRAME-OPTIONS'] = 'DENY' resp.headers["X-FRAME-OPTIONS"] = "DENY"
return resp return resp
def fully_qualified_name(method_view_class): def fully_qualified_name(method_view_class):
return '%s.%s' % (method_view_class.__module__, method_view_class.__name__) return "%s.%s" % (method_view_class.__module__, method_view_class.__name__)
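Two helpers above are easy to misread in diff form: truthy_bool treats anything outside a fixed set of falsy markers as true, and PARAM_REGEX splits a Flask-style route parameter into converter and name. A few illustrative checks (importing from this module, config_app.config_endpoints.common):

    from config_app.config_endpoints.common import PARAM_REGEX, truthy_bool

    assert truthy_bool("true") is True    # not in the falsy set
    assert truthy_bool("false") is False
    assert truthy_bool(0) is False        # 0 == False, so it hits the set too

    match = PARAM_REGEX.match("<int:user_id>")
    assert match.group(1) == "int:" and match.group(2) == "user_id"
    assert PARAM_REGEX.match("<page>").group(2) == "page"  # converter optional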
View file
@@ -5,7 +5,7 @@ from werkzeug.exceptions import HTTPException
class ApiErrorType(Enum): class ApiErrorType(Enum):
invalid_request = 'invalid_request' invalid_request = "invalid_request"
class ApiException(HTTPException): class ApiException(HTTPException):
@@ -45,22 +45,28 @@ class ApiException(HTTPException):
rv = dict(self.payload or ()) rv = dict(self.payload or ())
if self.error_description is not None: if self.error_description is not None:
rv['detail'] = self.error_description rv["detail"] = self.error_description
rv['error_message'] = self.error_description # TODO: deprecate rv["error_message"] = self.error_description # TODO: deprecate
rv['error_type'] = self.error_type.value # TODO: deprecate rv["error_type"] = self.error_type.value # TODO: deprecate
rv['title'] = self.error_type.value rv["title"] = self.error_type.value
rv['type'] = url_for('api.error', error_type=self.error_type.value, _external=True) rv["type"] = url_for(
rv['status'] = self.code "api.error", error_type=self.error_type.value, _external=True
)
rv["status"] = self.code
return rv return rv
class InvalidRequest(ApiException): class InvalidRequest(ApiException):
def __init__(self, error_description, payload=None): def __init__(self, error_description, payload=None):
ApiException.__init__(self, ApiErrorType.invalid_request, 400, error_description, payload) ApiException.__init__(
self, ApiErrorType.invalid_request, 400, error_description, payload
)
class InvalidResponse(ApiException): class InvalidResponse(ApiException):
def __init__(self, error_description, payload=None): def __init__(self, error_description, payload=None):
ApiException.__init__(self, ApiErrorType.invalid_response, 400, error_description, payload) ApiException.__init__(
self, ApiErrorType.invalid_response, 400, error_description, payload
)
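The dict built above is a problem-details-style payload (type, title, detail, status, plus the two fields marked for deprecation). New error classes follow the same two-line pattern as InvalidRequest and InvalidResponse; a hypothetical example, where ApiErrorType.not_found is an invented member that would have to be added to the enum:

    class NotFound(ApiException):
        # Hypothetical subclass for illustration only; not part of this diff.
        def __init__(self, error_description, payload=None):
            ApiException.__init__(
                self, ApiErrorType.not_found, 404, error_description, payload
            )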
View file
@@ -5,7 +5,7 @@ from config_app.config_endpoints.common import render_page_template
from config_app.config_endpoints.api.discovery import generate_route_data from config_app.config_endpoints.api.discovery import generate_route_data
from config_app.config_endpoints.api import no_cache from config_app.config_endpoints.api import no_cache
setup_web = Blueprint('setup_web', __name__, template_folder='templates') setup_web = Blueprint("setup_web", __name__, template_folder="templates")
@lru_cache(maxsize=1) @lru_cache(maxsize=1)
@@ -18,6 +18,8 @@ def render_page_template_with_routedata(name, *args, **kwargs):
@no_cache @no_cache
@setup_web.route('/', methods=['GET'], defaults={'path': ''}) @setup_web.route("/", methods=["GET"], defaults={"path": ""})
def index(path, **kwargs): def index(path, **kwargs):
return render_page_template_with_routedata('index.html', js_bundle_name='configapp', **kwargs) return render_page_template_with_routedata(
"index.html", js_bundle_name="configapp", **kwargs
)
View file
@@ -5,14 +5,26 @@ from data import database, model
from util.security.test.test_ssl_util import generate_test_cert from util.security.test.test_ssl_util import generate_test_cert
from config_app.c_app import app from config_app.c_app import app
from config_app.config_test import ApiTestCase, all_queues, ADMIN_ACCESS_USER, ADMIN_ACCESS_EMAIL from config_app.config_test import (
ApiTestCase,
all_queues,
ADMIN_ACCESS_USER,
ADMIN_ACCESS_EMAIL,
)
from config_app.config_endpoints.api import api_bp from config_app.config_endpoints.api import api_bp
from config_app.config_endpoints.api.superuser import SuperUserCustomCertificate, SuperUserCustomCertificates from config_app.config_endpoints.api.superuser import (
from config_app.config_endpoints.api.suconfig import SuperUserConfig, SuperUserCreateInitialSuperUser, \ SuperUserCustomCertificate,
SuperUserConfigFile, SuperUserRegistryStatus SuperUserCustomCertificates,
)
from config_app.config_endpoints.api.suconfig import (
SuperUserConfig,
SuperUserCreateInitialSuperUser,
SuperUserConfigFile,
SuperUserRegistryStatus,
)
try: try:
app.register_blueprint(api_bp, url_prefix='/api') app.register_blueprint(api_bp, url_prefix="/api")
except ValueError: except ValueError:
# This blueprint was already registered # This blueprint was already registered
pass pass
@@ -21,18 +33,17 @@ except ValueError:
class TestSuperUserCreateInitialSuperUser(ApiTestCase): class TestSuperUserCreateInitialSuperUser(ApiTestCase):
def test_create_superuser(self): def test_create_superuser(self):
data = { data = {
'username': 'newsuper', "username": "newsuper",
'password': 'password', "password": "password",
'email': 'jschorr+fake@devtable.com', "email": "jschorr+fake@devtable.com",
} }
# Add some fake config. # Add some fake config.
fake_config = { fake_config = {"AUTHENTICATION_TYPE": "Database", "SECRET_KEY": "fakekey"}
'AUTHENTICATION_TYPE': 'Database',
'SECRET_KEY': 'fakekey',
}
self.putJsonResponse(SuperUserConfig, data=dict(config=fake_config, hostname='fakehost')) self.putJsonResponse(
SuperUserConfig, data=dict(config=fake_config, hostname="fakehost")
)
# Try to write with config. Should 403 since there are users in the DB. # Try to write with config. Should 403 since there are users in the DB.
self.postResponse(SuperUserCreateInitialSuperUser, data=data, expected_code=403) self.postResponse(SuperUserCreateInitialSuperUser, data=data, expected_code=403)
@@ -45,34 +56,32 @@ class TestSuperUserCreateInitialSuperUser(ApiTestCase):
self.postJsonResponse(SuperUserCreateInitialSuperUser, data=data) self.postJsonResponse(SuperUserCreateInitialSuperUser, data=data)
# Ensure the user exists in the DB. # Ensure the user exists in the DB.
self.assertIsNotNone(model.user.get_user('newsuper')) self.assertIsNotNone(model.user.get_user("newsuper"))
# Ensure that the current user is a superuser in the config. # Ensure that the current user is a superuser in the config.
json = self.getJsonResponse(SuperUserConfig) json = self.getJsonResponse(SuperUserConfig)
self.assertEquals(['newsuper'], json['config']['SUPER_USERS']) self.assertEquals(["newsuper"], json["config"]["SUPER_USERS"])
# Ensure that the current user is a superuser in memory by trying to call an API # Ensure that the current user is a superuser in memory by trying to call an API
# that will fail otherwise. # that will fail otherwise.
self.getResponse(SuperUserConfigFile, params=dict(filename='ssl.cert')) self.getResponse(SuperUserConfigFile, params=dict(filename="ssl.cert"))
class TestSuperUserConfig(ApiTestCase): class TestSuperUserConfig(ApiTestCase):
def test_get_status_update_config(self): def test_get_status_update_config(self):
# With no config the status should be 'config-db'. # With no config the status should be 'config-db'.
json = self.getJsonResponse(SuperUserRegistryStatus) json = self.getJsonResponse(SuperUserRegistryStatus)
self.assertEquals('config-db', json['status']) self.assertEquals("config-db", json["status"])
# Add some fake config. # Add some fake config.
fake_config = { fake_config = {"AUTHENTICATION_TYPE": "Database", "SECRET_KEY": "fakekey"}
'AUTHENTICATION_TYPE': 'Database',
'SECRET_KEY': 'fakekey',
}
json = self.putJsonResponse(SuperUserConfig, data=dict(config=fake_config, json = self.putJsonResponse(
hostname='fakehost')) SuperUserConfig, data=dict(config=fake_config, hostname="fakehost")
self.assertEquals('fakekey', json['config']['SECRET_KEY']) )
self.assertEquals('fakehost', json['config']['SERVER_HOSTNAME']) self.assertEquals("fakekey", json["config"]["SECRET_KEY"])
self.assertEquals('Database', json['config']['AUTHENTICATION_TYPE']) self.assertEquals("fakehost", json["config"]["SERVER_HOSTNAME"])
self.assertEquals("Database", json["config"]["AUTHENTICATION_TYPE"])
# With config the status should be 'setup-db'. # With config the status should be 'setup-db'.
# TODO: fix this test # TODO: fix this test
@@ -81,56 +90,65 @@ class TestSuperUserConfig(ApiTestCase):
def test_config_file(self): def test_config_file(self):
# Try for an invalid file. Should 404. # Try for an invalid file. Should 404.
self.getResponse(SuperUserConfigFile, params=dict(filename='foobar'), expected_code=404) self.getResponse(
SuperUserConfigFile, params=dict(filename="foobar"), expected_code=404
)
# Try for a valid filename. Should not exist. # Try for a valid filename. Should not exist.
json = self.getJsonResponse(SuperUserConfigFile, params=dict(filename='ssl.cert')) json = self.getJsonResponse(
self.assertFalse(json['exists']) SuperUserConfigFile, params=dict(filename="ssl.cert")
)
self.assertFalse(json["exists"])
# Add the file. # Add the file.
self.postResponse(SuperUserConfigFile, params=dict(filename='ssl.cert'), self.postResponse(
file=(StringIO('my file contents'), 'ssl.cert')) SuperUserConfigFile,
params=dict(filename="ssl.cert"),
file=(StringIO("my file contents"), "ssl.cert"),
)
# Should now exist. # Should now exist.
json = self.getJsonResponse(SuperUserConfigFile, params=dict(filename='ssl.cert')) json = self.getJsonResponse(
self.assertTrue(json['exists']) SuperUserConfigFile, params=dict(filename="ssl.cert")
)
self.assertTrue(json["exists"])
def test_update_with_external_auth(self): def test_update_with_external_auth(self):
# Run a mock LDAP. # Run a mock LDAP.
mockldap = MockLdap({ mockldap = MockLdap(
'dc=quay,dc=io': { {
'dc': ['quay', 'io'] "dc=quay,dc=io": {"dc": ["quay", "io"]},
"ou=employees,dc=quay,dc=io": {"dc": ["quay", "io"], "ou": "employees"},
"uid="
+ ADMIN_ACCESS_USER
+ ",ou=employees,dc=quay,dc=io": {
"dc": ["quay", "io"],
"ou": "employees",
"uid": [ADMIN_ACCESS_USER],
"userPassword": ["password"],
"mail": [ADMIN_ACCESS_EMAIL],
}, },
'ou=employees,dc=quay,dc=io': { }
'dc': ['quay', 'io'], )
'ou': 'employees'
},
'uid=' + ADMIN_ACCESS_USER + ',ou=employees,dc=quay,dc=io': {
'dc': ['quay', 'io'],
'ou': 'employees',
'uid': [ADMIN_ACCESS_USER],
'userPassword': ['password'],
'mail': [ADMIN_ACCESS_EMAIL],
},
})
config = { config = {
'AUTHENTICATION_TYPE': 'LDAP', "AUTHENTICATION_TYPE": "LDAP",
'LDAP_BASE_DN': ['dc=quay', 'dc=io'], "LDAP_BASE_DN": ["dc=quay", "dc=io"],
'LDAP_ADMIN_DN': 'uid=devtable,ou=employees,dc=quay,dc=io', "LDAP_ADMIN_DN": "uid=devtable,ou=employees,dc=quay,dc=io",
'LDAP_ADMIN_PASSWD': 'password', "LDAP_ADMIN_PASSWD": "password",
'LDAP_USER_RDN': ['ou=employees'], "LDAP_USER_RDN": ["ou=employees"],
'LDAP_UID_ATTR': 'uid', "LDAP_UID_ATTR": "uid",
'LDAP_EMAIL_ATTR': 'mail', "LDAP_EMAIL_ATTR": "mail",
} }
mockldap.start() mockldap.start()
try: try:
# Write the config with the valid password. # Write the config with the valid password.
self.putResponse(SuperUserConfig, self.putResponse(
data={'config': config, SuperUserConfig,
'password': 'password', data={"config": config, "password": "password", "hostname": "foo"},
'hostname': 'foo'}, expected_code=200) expected_code=200,
)
# Ensure that the user row has been linked. # Ensure that the user row has been linked.
# TODO: fix this test # TODO: fix this test
@@ -139,70 +157,90 @@ class TestSuperUserConfig(ApiTestCase):
finally: finally:
mockldap.stop() mockldap.stop()
class TestSuperUserCustomCertificates(ApiTestCase): class TestSuperUserCustomCertificates(ApiTestCase):
def test_custom_certificates(self): def test_custom_certificates(self):
# Upload a certificate. # Upload a certificate.
cert_contents, _ = generate_test_cert(hostname='somecoolhost', san_list=['DNS:bar', 'DNS:baz']) cert_contents, _ = generate_test_cert(
self.postResponse(SuperUserCustomCertificate, params=dict(certpath='testcert.crt'), hostname="somecoolhost", san_list=["DNS:bar", "DNS:baz"]
file=(StringIO(cert_contents), 'testcert.crt'), expected_code=204) )
self.postResponse(
SuperUserCustomCertificate,
params=dict(certpath="testcert.crt"),
file=(StringIO(cert_contents), "testcert.crt"),
expected_code=204,
)
# Make sure it is present. # Make sure it is present.
json = self.getJsonResponse(SuperUserCustomCertificates) json = self.getJsonResponse(SuperUserCustomCertificates)
self.assertEquals(1, len(json['certs'])) self.assertEquals(1, len(json["certs"]))
cert_info = json['certs'][0] cert_info = json["certs"][0]
self.assertEquals('testcert.crt', cert_info['path']) self.assertEquals("testcert.crt", cert_info["path"])
self.assertEquals(set(['somecoolhost', 'bar', 'baz']), set(cert_info['names'])) self.assertEquals(set(["somecoolhost", "bar", "baz"]), set(cert_info["names"]))
self.assertFalse(cert_info['expired']) self.assertFalse(cert_info["expired"])
# Remove the certificate. # Remove the certificate.
self.deleteResponse(SuperUserCustomCertificate, params=dict(certpath='testcert.crt')) self.deleteResponse(
SuperUserCustomCertificate, params=dict(certpath="testcert.crt")
)
# Make sure it is gone. # Make sure it is gone.
json = self.getJsonResponse(SuperUserCustomCertificates) json = self.getJsonResponse(SuperUserCustomCertificates)
self.assertEquals(0, len(json['certs'])) self.assertEquals(0, len(json["certs"]))
def test_expired_custom_certificate(self): def test_expired_custom_certificate(self):
# Upload a certificate. # Upload a certificate.
cert_contents, _ = generate_test_cert(hostname='somecoolhost', expires=-10) cert_contents, _ = generate_test_cert(hostname="somecoolhost", expires=-10)
self.postResponse(SuperUserCustomCertificate, params=dict(certpath='testcert.crt'), self.postResponse(
file=(StringIO(cert_contents), 'testcert.crt'), expected_code=204) SuperUserCustomCertificate,
params=dict(certpath="testcert.crt"),
file=(StringIO(cert_contents), "testcert.crt"),
expected_code=204,
)
# Make sure it is present. # Make sure it is present.
json = self.getJsonResponse(SuperUserCustomCertificates) json = self.getJsonResponse(SuperUserCustomCertificates)
self.assertEquals(1, len(json['certs'])) self.assertEquals(1, len(json["certs"]))
cert_info = json['certs'][0] cert_info = json["certs"][0]
self.assertEquals('testcert.crt', cert_info['path']) self.assertEquals("testcert.crt", cert_info["path"])
self.assertEquals(set(['somecoolhost']), set(cert_info['names'])) self.assertEquals(set(["somecoolhost"]), set(cert_info["names"]))
self.assertTrue(cert_info['expired']) self.assertTrue(cert_info["expired"])
def test_invalid_custom_certificate(self): def test_invalid_custom_certificate(self):
# Upload an invalid certificate. # Upload an invalid certificate.
self.postResponse(SuperUserCustomCertificate, params=dict(certpath='testcert.crt'), self.postResponse(
file=(StringIO('some contents'), 'testcert.crt'), expected_code=204) SuperUserCustomCertificate,
params=dict(certpath="testcert.crt"),
file=(StringIO("some contents"), "testcert.crt"),
expected_code=204,
)
# Make sure it is present but invalid. # Make sure it is present but invalid.
json = self.getJsonResponse(SuperUserCustomCertificates) json = self.getJsonResponse(SuperUserCustomCertificates)
self.assertEquals(1, len(json['certs'])) self.assertEquals(1, len(json["certs"]))
cert_info = json['certs'][0] cert_info = json["certs"][0]
self.assertEquals('testcert.crt', cert_info['path']) self.assertEquals("testcert.crt", cert_info["path"])
self.assertEquals('no start line', cert_info['error']) self.assertEquals("no start line", cert_info["error"])
def test_path_sanitization(self): def test_path_sanitization(self):
# Upload a certificate. # Upload a certificate.
cert_contents, _ = generate_test_cert(hostname='somecoolhost', expires=-10) cert_contents, _ = generate_test_cert(hostname="somecoolhost", expires=-10)
self.postResponse(SuperUserCustomCertificate, params=dict(certpath='testcert/../foobar.crt'), self.postResponse(
file=(StringIO(cert_contents), 'testcert/../foobar.crt'), expected_code=204) SuperUserCustomCertificate,
params=dict(certpath="testcert/../foobar.crt"),
file=(StringIO(cert_contents), "testcert/../foobar.crt"),
expected_code=204,
)
# Make sure it is present. # Make sure it is present.
json = self.getJsonResponse(SuperUserCustomCertificates) json = self.getJsonResponse(SuperUserCustomCertificates)
self.assertEquals(1, len(json['certs'])) self.assertEquals(1, len(json["certs"]))
cert_info = json['certs'][0]
self.assertEquals('foobar.crt', cert_info['path'])
cert_info = json["certs"][0]
self.assertEquals("foobar.crt", cert_info["path"])
View file
@@ -4,14 +4,19 @@ import mock
from data.database import User from data.database import User
from data import model from data import model
from config_app.config_endpoints.api.suconfig import SuperUserConfig, SuperUserConfigValidate, SuperUserConfigFile, \ from config_app.config_endpoints.api.suconfig import (
SuperUserRegistryStatus, SuperUserCreateInitialSuperUser SuperUserConfig,
SuperUserConfigValidate,
SuperUserConfigFile,
SuperUserRegistryStatus,
SuperUserCreateInitialSuperUser,
)
from config_app.config_endpoints.api import api_bp from config_app.config_endpoints.api import api_bp
from config_app.config_test import ApiTestCase, READ_ACCESS_USER, ADMIN_ACCESS_USER from config_app.config_test import ApiTestCase, READ_ACCESS_USER, ADMIN_ACCESS_USER
from config_app.c_app import app, config_provider from config_app.c_app import app, config_provider
try: try:
app.register_blueprint(api_bp, url_prefix='/api') app.register_blueprint(api_bp, url_prefix="/api")
except ValueError: except ValueError:
# This blueprint was already registered # This blueprint was already registered
pass pass
@@ -19,6 +24,7 @@ except ValueError:
# OVERRIDES FROM PORTING FROM OLD APP: # OVERRIDES FROM PORTING FROM OLD APP:
all_queues = [] # the config app doesn't have any queues all_queues = [] # the config app doesn't have any queues
class FreshConfigProvider(object): class FreshConfigProvider(object):
def __enter__(self): def __enter__(self):
config_provider.reset_for_test() config_provider.reset_for_test()
@@ -32,122 +38,171 @@ class TestSuperUserRegistryStatus(ApiTestCase):
def test_registry_status_no_config(self): def test_registry_status_no_config(self):
with FreshConfigProvider(): with FreshConfigProvider():
json = self.getJsonResponse(SuperUserRegistryStatus) json = self.getJsonResponse(SuperUserRegistryStatus)
self.assertEquals('config-db', json['status']) self.assertEquals("config-db", json["status"])
@mock.patch("config_app.config_endpoints.api.suconfig.database_is_valid", mock.Mock(return_value=False)) @mock.patch(
"config_app.config_endpoints.api.suconfig.database_is_valid",
mock.Mock(return_value=False),
)
def test_registry_status_no_database(self): def test_registry_status_no_database(self):
with FreshConfigProvider(): with FreshConfigProvider():
config_provider.save_config({'key': 'value'}) config_provider.save_config({"key": "value"})
json = self.getJsonResponse(SuperUserRegistryStatus) json = self.getJsonResponse(SuperUserRegistryStatus)
self.assertEquals('setup-db', json['status']) self.assertEquals("setup-db", json["status"])
@mock.patch("config_app.config_endpoints.api.suconfig.database_is_valid", mock.Mock(return_value=True)) @mock.patch(
"config_app.config_endpoints.api.suconfig.database_is_valid",
mock.Mock(return_value=True),
)
def test_registry_status_db_has_superuser(self): def test_registry_status_db_has_superuser(self):
with FreshConfigProvider(): with FreshConfigProvider():
config_provider.save_config({'key': 'value'}) config_provider.save_config({"key": "value"})
json = self.getJsonResponse(SuperUserRegistryStatus) json = self.getJsonResponse(SuperUserRegistryStatus)
self.assertEquals('config', json['status']) self.assertEquals("config", json["status"])
@mock.patch("config_app.config_endpoints.api.suconfig.database_is_valid", mock.Mock(return_value=True)) @mock.patch(
@mock.patch("config_app.config_endpoints.api.suconfig.database_has_users", mock.Mock(return_value=False)) "config_app.config_endpoints.api.suconfig.database_is_valid",
mock.Mock(return_value=True),
)
@mock.patch(
"config_app.config_endpoints.api.suconfig.database_has_users",
mock.Mock(return_value=False),
)
def test_registry_status_db_no_superuser(self): def test_registry_status_db_no_superuser(self):
with FreshConfigProvider(): with FreshConfigProvider():
config_provider.save_config({'key': 'value'}) config_provider.save_config({"key": "value"})
json = self.getJsonResponse(SuperUserRegistryStatus) json = self.getJsonResponse(SuperUserRegistryStatus)
self.assertEquals('create-superuser', json['status']) self.assertEquals("create-superuser", json["status"])
@mock.patch("config_app.config_endpoints.api.suconfig.database_is_valid", mock.Mock(return_value=True)) @mock.patch(
@mock.patch("config_app.config_endpoints.api.suconfig.database_has_users", mock.Mock(return_value=True)) "config_app.config_endpoints.api.suconfig.database_is_valid",
mock.Mock(return_value=True),
)
@mock.patch(
"config_app.config_endpoints.api.suconfig.database_has_users",
mock.Mock(return_value=True),
)
def test_registry_status_setup_complete(self): def test_registry_status_setup_complete(self):
with FreshConfigProvider(): with FreshConfigProvider():
config_provider.save_config({'key': 'value', 'SETUP_COMPLETE': True}) config_provider.save_config({"key": "value", "SETUP_COMPLETE": True})
json = self.getJsonResponse(SuperUserRegistryStatus) json = self.getJsonResponse(SuperUserRegistryStatus)
self.assertEquals('config', json['status']) self.assertEquals("config", json["status"])
class TestSuperUserConfigFile(ApiTestCase): class TestSuperUserConfigFile(ApiTestCase):
def test_get_superuser_invalid_filename(self): def test_get_superuser_invalid_filename(self):
with FreshConfigProvider(): with FreshConfigProvider():
self.getResponse(SuperUserConfigFile, params=dict(filename='somefile'), expected_code=404) self.getResponse(
SuperUserConfigFile, params=dict(filename="somefile"), expected_code=404
)
def test_get_superuser(self): def test_get_superuser(self):
with FreshConfigProvider(): with FreshConfigProvider():
result = self.getJsonResponse(SuperUserConfigFile, params=dict(filename='ssl.cert')) result = self.getJsonResponse(
self.assertFalse(result['exists']) SuperUserConfigFile, params=dict(filename="ssl.cert")
)
self.assertFalse(result["exists"])
def test_post_no_file(self): def test_post_no_file(self):
with FreshConfigProvider(): with FreshConfigProvider():
# No file # No file
self.postResponse(SuperUserConfigFile, params=dict(filename='ssl.cert'), expected_code=400) self.postResponse(
SuperUserConfigFile, params=dict(filename="ssl.cert"), expected_code=400
)
def test_post_superuser_invalid_filename(self): def test_post_superuser_invalid_filename(self):
with FreshConfigProvider(): with FreshConfigProvider():
self.postResponse(SuperUserConfigFile, params=dict(filename='somefile'), expected_code=404) self.postResponse(
SuperUserConfigFile, params=dict(filename="somefile"), expected_code=404
)
def test_post_superuser(self): def test_post_superuser(self):
with FreshConfigProvider(): with FreshConfigProvider():
self.postResponse(SuperUserConfigFile, params=dict(filename='ssl.cert'), expected_code=400) self.postResponse(
SuperUserConfigFile, params=dict(filename="ssl.cert"), expected_code=400
)
class TestSuperUserCreateInitialSuperUser(ApiTestCase): class TestSuperUserCreateInitialSuperUser(ApiTestCase):
def test_no_config_file(self): def test_no_config_file(self):
with FreshConfigProvider(): with FreshConfigProvider():
# If there is no config.yaml, then this method should security fail. # If there is no config.yaml, then this method should security fail.
data = dict(username='cooluser', password='password', email='fake@example.com') data = dict(
self.postResponse(SuperUserCreateInitialSuperUser, data=data, expected_code=403) username="cooluser", password="password", email="fake@example.com"
)
self.postResponse(
SuperUserCreateInitialSuperUser, data=data, expected_code=403
)
def test_config_file_with_db_users(self): def test_config_file_with_db_users(self):
with FreshConfigProvider(): with FreshConfigProvider():
# Write some config. # Write some config.
self.putJsonResponse(SuperUserConfig, data=dict(config={}, hostname='foobar')) self.putJsonResponse(
SuperUserConfig, data=dict(config={}, hostname="foobar")
)
# If there is a config.yaml, but existing DB users exist, then this method should security # If there is a config.yaml, but existing DB users exist, then this method should security
# fail. # fail.
data = dict(username='cooluser', password='password', email='fake@example.com') data = dict(
self.postResponse(SuperUserCreateInitialSuperUser, data=data, expected_code=403) username="cooluser", password="password", email="fake@example.com"
)
self.postResponse(
SuperUserCreateInitialSuperUser, data=data, expected_code=403
)
def test_config_file_with_no_db_users(self): def test_config_file_with_no_db_users(self):
with FreshConfigProvider(): with FreshConfigProvider():
# Write some config. # Write some config.
self.putJsonResponse(SuperUserConfig, data=dict(config={}, hostname='foobar')) self.putJsonResponse(
SuperUserConfig, data=dict(config={}, hostname="foobar")
)
# Delete all the users in the DB. # Delete all the users in the DB.
for user in list(User.select()): for user in list(User.select()):
model.user.delete_user(user, all_queues) model.user.delete_user(user, all_queues)
# This method should now succeed. # This method should now succeed.
data = dict(username='cooluser', password='password', email='fake@example.com') data = dict(
username="cooluser", password="password", email="fake@example.com"
)
result = self.postJsonResponse(SuperUserCreateInitialSuperUser, data=data) result = self.postJsonResponse(SuperUserCreateInitialSuperUser, data=data)
self.assertTrue(result['status']) self.assertTrue(result["status"])
# Verify the superuser was created. # Verify the superuser was created.
User.get(User.username == 'cooluser') User.get(User.username == "cooluser")
# Verify the superuser was placed into the config. # Verify the superuser was placed into the config.
result = self.getJsonResponse(SuperUserConfig) result = self.getJsonResponse(SuperUserConfig)
self.assertEquals(['cooluser'], result['config']['SUPER_USERS']) self.assertEquals(["cooluser"], result["config"]["SUPER_USERS"])
class TestSuperUserConfigValidate(ApiTestCase): class TestSuperUserConfigValidate(ApiTestCase):
def test_nonsuperuser_noconfig(self): def test_nonsuperuser_noconfig(self):
with FreshConfigProvider(): with FreshConfigProvider():
result = self.postJsonResponse(SuperUserConfigValidate, params=dict(service='someservice'), result = self.postJsonResponse(
data=dict(config={})) SuperUserConfigValidate,
params=dict(service="someservice"),
self.assertFalse(result['status']) data=dict(config={}),
)
self.assertFalse(result["status"])
def test_nonsuperuser_config(self): def test_nonsuperuser_config(self):
with FreshConfigProvider(): with FreshConfigProvider():
# The validate config call works if there is no config.yaml OR the user is a superuser. # The validate config call works if there is no config.yaml OR the user is a superuser.
# Add a config, and verify it breaks when unauthenticated. # Add a config, and verify it breaks when unauthenticated.
json = self.putJsonResponse(SuperUserConfig, data=dict(config={}, hostname='foobar')) json = self.putJsonResponse(
self.assertTrue(json['exists']) SuperUserConfig, data=dict(config={}, hostname="foobar")
)
self.assertTrue(json["exists"])
result = self.postJsonResponse(
SuperUserConfigValidate,
params=dict(service="someservice"),
data=dict(config={}),
)
result = self.postJsonResponse(SuperUserConfigValidate, params=dict(service='someservice'), self.assertFalse(result["status"])
data=dict(config={}))
self.assertFalse(result['status'])
class TestSuperUserConfig(ApiTestCase): class TestSuperUserConfig(ApiTestCase):
@@ -157,23 +212,27 @@ class TestSuperUserConfig(ApiTestCase):
# Note: We expect the config to be none because a config.yaml should never be checked into # Note: We expect the config to be none because a config.yaml should never be checked into
# the directory. # the directory.
self.assertIsNone(json['config']) self.assertIsNone(json["config"])
def test_put(self): def test_put(self):
with FreshConfigProvider() as config: with FreshConfigProvider() as config:
json = self.putJsonResponse(SuperUserConfig, data=dict(config={}, hostname='foobar')) json = self.putJsonResponse(
self.assertTrue(json['exists']) SuperUserConfig, data=dict(config={}, hostname="foobar")
)
self.assertTrue(json["exists"])
# Verify the config file exists. # Verify the config file exists.
self.assertTrue(config.config_exists()) self.assertTrue(config.config_exists())
# This should succeed. # This should succeed.
json = self.putJsonResponse(SuperUserConfig, data=dict(config={}, hostname='barbaz')) json = self.putJsonResponse(
self.assertTrue(json['exists']) SuperUserConfig, data=dict(config={}, hostname="barbaz")
)
self.assertTrue(json["exists"])
json = self.getJsonResponse(SuperUserConfig) json = self.getJsonResponse(SuperUserConfig)
self.assertIsNotNone(json['config']) self.assertIsNotNone(json["config"])
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()
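A reading note on the stacked patch decorators above: because each mock.patch call is given an explicit replacement object (the mock.Mock(return_value=...) second argument), no extra mock argument is injected into the test methods, which is why their signatures stay (self). Compare the two forms (dotted target path invented for illustration):

    import mock

    @mock.patch("pkg.mod.helper", mock.Mock(return_value=True))
    def test_with_replacement(self):
        pass  # nothing extra is passed in

    @mock.patch("pkg.mod.helper")
    def test_without_replacement(self, mock_helper):
        pass  # the created MagicMock arrives as a parameter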
View file
@@ -5,7 +5,8 @@ from backports.tempfile import TemporaryDirectory
from config_app.config_util.config.fileprovider import FileConfigProvider from config_app.config_util.config.fileprovider import FileConfigProvider
OLD_CONFIG_SUBDIR = 'old/' OLD_CONFIG_SUBDIR = "old/"
class TransientDirectoryProvider(FileConfigProvider): class TransientDirectoryProvider(FileConfigProvider):
""" Implementation of the config provider that reads and writes the data """ Implementation of the config provider that reads and writes the data
@@ -20,11 +21,13 @@ class TransientDirectoryProvider(FileConfigProvider):
temp_dir = TemporaryDirectory() temp_dir = TemporaryDirectory()
self.temp_dir = temp_dir self.temp_dir = temp_dir
self.old_config_dir = None self.old_config_dir = None
super(TransientDirectoryProvider, self).__init__(temp_dir.name, yaml_filename, py_filename) super(TransientDirectoryProvider, self).__init__(
temp_dir.name, yaml_filename, py_filename
)
@property @property
def provider_id(self): def provider_id(self):
return 'transient' return "transient"
def new_config_dir(self): def new_config_dir(self):
""" """
@@ -57,6 +60,8 @@ class TransientDirectoryProvider(FileConfigProvider):
def get_old_config_dir(self): def get_old_config_dir(self):
if self.old_config_dir is None: if self.old_config_dir is None:
raise Exception('Cannot return an old configuration when none exists') raise Exception(
"Cannot return an old configuration when none exists"
)
return os.path.join(self.old_config_dir.name, OLD_CONFIG_SUBDIR) return os.path.join(self.old_config_dir.name, OLD_CONFIG_SUBDIR)
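TransientDirectoryProvider parks its config in a TemporaryDirectory whose lifetime it manages by hand (hence the backports.tempfile import for Python 2). A minimal sketch of that object's lifecycle:

    from backports.tempfile import TemporaryDirectory

    temp_dir = TemporaryDirectory()
    print(temp_dir.name)  # the directory exists and is writable here
    temp_dir.cleanup()    # removed explicitly; also cleaned up on garbage collection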
View file
@@ -3,7 +3,9 @@ import os
from config_app.config_util.config.fileprovider import FileConfigProvider from config_app.config_util.config.fileprovider import FileConfigProvider
from config_app.config_util.config.testprovider import TestConfigProvider from config_app.config_util.config.testprovider import TestConfigProvider
from config_app.config_util.config.TransientDirectoryProvider import TransientDirectoryProvider from config_app.config_util.config.TransientDirectoryProvider import (
TransientDirectoryProvider,
)
from util.config.validator import EXTRA_CA_DIRECTORY, EXTRA_CA_DIRECTORY_PREFIX from util.config.validator import EXTRA_CA_DIRECTORY, EXTRA_CA_DIRECTORY_PREFIX
@@ -27,8 +29,9 @@ def get_config_as_kube_secret(config_path):
if os.path.exists(certs_dir): if os.path.exists(certs_dir):
for extra_cert in os.listdir(certs_dir): for extra_cert in os.listdir(certs_dir):
with open(os.path.join(certs_dir, extra_cert)) as f: with open(os.path.join(certs_dir, extra_cert)) as f:
data[EXTRA_CA_DIRECTORY_PREFIX + extra_cert] = base64.b64encode(f.read()) data[EXTRA_CA_DIRECTORY_PREFIX + extra_cert] = base64.b64encode(
f.read()
)
for name in os.listdir(config_path): for name in os.listdir(config_path):
file_path = os.path.join(config_path, name) file_path = os.path.join(config_path, name)
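get_config_as_kube_secret base64-encodes every file it gathers because Kubernetes Secret data values must be base64 strings. A condensed sketch of the same loop, in the Python 2 style of this codebase, with error handling omitted:

    import base64
    import os

    def files_to_secret_data(config_path):
        data = {}
        for name in os.listdir(config_path):
            file_path = os.path.join(config_path, name)
            if os.path.isfile(file_path):
                with open(file_path) as f:
                    data[name] = base64.b64encode(f.read())
        return data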
View file
@@ -1,8 +1,12 @@
import os import os
import logging import logging
from config_app.config_util.config.baseprovider import (BaseProvider, import_yaml, export_yaml, from config_app.config_util.config.baseprovider import (
CannotWriteConfigException) BaseProvider,
import_yaml,
export_yaml,
CannotWriteConfigException,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -20,11 +24,11 @@ class BaseFileProvider(BaseProvider):
def update_app_config(self, app_config): def update_app_config(self, app_config):
if os.path.exists(self.py_path): if os.path.exists(self.py_path):
logger.debug('Applying config file: %s', self.py_path) logger.debug("Applying config file: %s", self.py_path)
app_config.from_pyfile(self.py_path) app_config.from_pyfile(self.py_path)
if os.path.exists(self.yaml_path): if os.path.exists(self.yaml_path):
logger.debug('Applying config file: %s', self.yaml_path) logger.debug("Applying config file: %s", self.yaml_path)
import_yaml(app_config, self.yaml_path) import_yaml(app_config, self.yaml_path)
def get_config(self): def get_config(self):
@@ -44,7 +48,7 @@ class BaseFileProvider(BaseProvider):
def volume_file_exists(self, filename): def volume_file_exists(self, filename):
return os.path.exists(os.path.join(self.config_volume, filename)) return os.path.exists(os.path.join(self.config_volume, filename))
def get_volume_file(self, filename, mode='r'): def get_volume_file(self, filename, mode="r"):
return open(os.path.join(self.config_volume, filename), mode=mode) return open(os.path.join(self.config_volume, filename), mode=mode)
def get_volume_path(self, directory, filename): def get_volume_path(self, directory, filename):
Some files were not shown because too many files have changed in this diff