Reformat Python code with psf/black

Author: cclauss
Date:   2019-11-20 08:58:47 +01:00
Parent: f915352138
Commit: aa13f95ca5
746 changed files with 103596 additions and 76051 deletions
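For context: a sweep like this is normally produced by running Black over the whole tree (typically `black .` from the repository root; the exact flags used for this commit are not recorded here). As a minimal illustration of the string-quote normalization visible throughout the diff, Black can also be driven from Python. The sketch below is not part of the commit; it only assumes the `black` package is installed (older releases spell the mode class `FileMode` rather than `Mode`):

    # Illustrative only: reproduce Black's quote normalization on one line.
    import black

    src = "STATIC_DIR = os.path.join(ROOT_DIR, 'static/')\n"
    print(black.format_str(src, mode=black.Mode()))
    # -> STATIC_DIR = os.path.join(ROOT_DIR, "static/")

The same default mode also rewraps long call sites at 88 columns and adds trailing commas, which accounts for nearly all of the changes below.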


@@ -7,40 +7,45 @@ from util.config.provider import get_config_provider
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
CONF_DIR = os.getenv("QUAYCONF", os.path.join(ROOT_DIR, "conf/"))
STATIC_DIR = os.path.join(ROOT_DIR, 'static/')
STATIC_LDN_DIR = os.path.join(STATIC_DIR, 'ldn/')
STATIC_FONTS_DIR = os.path.join(STATIC_DIR, 'fonts/')
STATIC_WEBFONTS_DIR = os.path.join(STATIC_DIR, 'webfonts/')
TEMPLATE_DIR = os.path.join(ROOT_DIR, 'templates/')
STATIC_DIR = os.path.join(ROOT_DIR, "static/")
STATIC_LDN_DIR = os.path.join(STATIC_DIR, "ldn/")
STATIC_FONTS_DIR = os.path.join(STATIC_DIR, "fonts/")
STATIC_WEBFONTS_DIR = os.path.join(STATIC_DIR, "webfonts/")
TEMPLATE_DIR = os.path.join(ROOT_DIR, "templates/")
IS_TESTING = 'TEST' in os.environ
IS_BUILDING = 'BUILDING' in os.environ
IS_KUBERNETES = 'KUBERNETES_SERVICE_HOST' in os.environ
OVERRIDE_CONFIG_DIRECTORY = os.path.join(CONF_DIR, 'stack/')
IS_TESTING = "TEST" in os.environ
IS_BUILDING = "BUILDING" in os.environ
IS_KUBERNETES = "KUBERNETES_SERVICE_HOST" in os.environ
OVERRIDE_CONFIG_DIRECTORY = os.path.join(CONF_DIR, "stack/")
config_provider = get_config_provider(OVERRIDE_CONFIG_DIRECTORY, 'config.yaml', 'config.py',
testing=IS_TESTING, kubernetes=IS_KUBERNETES)
config_provider = get_config_provider(
OVERRIDE_CONFIG_DIRECTORY,
"config.yaml",
"config.py",
testing=IS_TESTING,
kubernetes=IS_KUBERNETES,
)
def _get_version_number_changelog():
try:
with open(os.path.join(ROOT_DIR, 'CHANGELOG.md')) as f:
return re.search(r'(v[0-9]+\.[0-9]+\.[0-9]+)', f.readline()).group(0)
except IOError:
return ''
try:
with open(os.path.join(ROOT_DIR, "CHANGELOG.md")) as f:
return re.search(r"(v[0-9]+\.[0-9]+\.[0-9]+)", f.readline()).group(0)
except IOError:
return ""
def _get_git_sha():
if os.path.exists("GIT_HEAD"):
with open(os.path.join(ROOT_DIR, "GIT_HEAD")) as f:
return f.read()
else:
try:
return subprocess.check_output(["git", "rev-parse", "HEAD"]).strip()[0:8]
except (OSError, subprocess.CalledProcessError, Exception):
pass
return "unknown"
if os.path.exists("GIT_HEAD"):
with open(os.path.join(ROOT_DIR, "GIT_HEAD")) as f:
return f.read()
else:
try:
return subprocess.check_output(["git", "rev-parse", "HEAD"]).strip()[0:8]
except (OSError, subprocess.CalledProcessError, Exception):
pass
return "unknown"
__version__ = _get_version_number_changelog()


@@ -1,22 +1,30 @@
from enum import Enum, unique
from data.migrationutil import DefinedDataMigration, MigrationPhase
@unique
class ERTMigrationFlags(Enum):
""" Flags for the encrypted robot token migration. """
READ_OLD_FIELDS = 'read-old'
WRITE_OLD_FIELDS = 'write-old'
""" Flags for the encrypted robot token migration. """
READ_OLD_FIELDS = "read-old"
WRITE_OLD_FIELDS = "write-old"
ActiveDataMigration = DefinedDataMigration(
'encrypted_robot_tokens',
'ENCRYPTED_ROBOT_TOKEN_MIGRATION_PHASE',
[
MigrationPhase('add-new-fields', 'c13c8052f7a6', [ERTMigrationFlags.READ_OLD_FIELDS,
ERTMigrationFlags.WRITE_OLD_FIELDS]),
MigrationPhase('backfill-then-read-only-new',
'703298a825c2', [ERTMigrationFlags.WRITE_OLD_FIELDS]),
MigrationPhase('stop-writing-both', '703298a825c2', []),
MigrationPhase('remove-old-fields', 'c059b952ed76', []),
]
"encrypted_robot_tokens",
"ENCRYPTED_ROBOT_TOKEN_MIGRATION_PHASE",
[
MigrationPhase(
"add-new-fields",
"c13c8052f7a6",
[ERTMigrationFlags.READ_OLD_FIELDS, ERTMigrationFlags.WRITE_OLD_FIELDS],
),
MigrationPhase(
"backfill-then-read-only-new",
"703298a825c2",
[ERTMigrationFlags.WRITE_OLD_FIELDS],
),
MigrationPhase("stop-writing-both", "703298a825c2", []),
MigrationPhase("remove-old-fields", "c059b952ed76", []),
],
)

app.py

@@ -16,8 +16,14 @@ from werkzeug.exceptions import HTTPException
import features
from _init import (config_provider, CONF_DIR, IS_KUBERNETES, IS_TESTING, OVERRIDE_CONFIG_DIRECTORY,
IS_BUILDING)
from _init import (
config_provider,
CONF_DIR,
IS_KUBERNETES,
IS_TESTING,
OVERRIDE_CONFIG_DIRECTORY,
IS_BUILDING,
)
from auth.auth_context import get_authenticated_user
from avatars.avatars import Avatar
@@ -35,7 +41,11 @@ from data.userevent import UserEventsBuilderModule
from data.userfiles import Userfiles
from data.users import UserAuthentication
from data.registry_model import registry_model
from path_converters import RegexConverter, RepositoryPathConverter, APIRepositoryPathConverter
from path_converters import (
RegexConverter,
RepositoryPathConverter,
APIRepositoryPathConverter,
)
from oauth.services.github import GithubOAuthService
from oauth.services.gitlab import GitLabOAuthService
from oauth.loginmanager import OAuthLoginManager
@@ -62,13 +72,13 @@ from util.security.instancekeys import InstanceKeys
from util.security.signing import Signer
OVERRIDE_CONFIG_YAML_FILENAME = os.path.join(CONF_DIR, 'stack/config.yaml')
OVERRIDE_CONFIG_PY_FILENAME = os.path.join(CONF_DIR, 'stack/config.py')
OVERRIDE_CONFIG_YAML_FILENAME = os.path.join(CONF_DIR, "stack/config.yaml")
OVERRIDE_CONFIG_PY_FILENAME = os.path.join(CONF_DIR, "stack/config.py")
OVERRIDE_CONFIG_KEY = 'QUAY_OVERRIDE_CONFIG'
OVERRIDE_CONFIG_KEY = "QUAY_OVERRIDE_CONFIG"
DOCKER_V2_SIGNINGKEY_FILENAME = 'docker_v2.pem'
INIT_SCRIPTS_LOCATION = '/conf/init/'
DOCKER_V2_SIGNINGKEY_FILENAME = "docker_v2.pem"
INIT_SCRIPTS_LOCATION = "/conf/init/"
app = Flask(__name__)
logger = logging.getLogger(__name__)
@@ -79,62 +89,75 @@ is_kubernetes = IS_KUBERNETES
is_building = IS_BUILDING
if is_testing:
from test.testconfig import TestConfig
logger.debug('Loading test config.')
app.config.from_object(TestConfig())
from test.testconfig import TestConfig
logger.debug("Loading test config.")
app.config.from_object(TestConfig())
else:
from config import DefaultConfig
logger.debug('Loading default config.')
app.config.from_object(DefaultConfig())
app.teardown_request(database.close_db_filter)
from config import DefaultConfig
logger.debug("Loading default config.")
app.config.from_object(DefaultConfig())
app.teardown_request(database.close_db_filter)
# Load the override config via the provider.
config_provider.update_app_config(app.config)
# Update any configuration found in the override environment variable.
environ_config = json.loads(os.environ.get(OVERRIDE_CONFIG_KEY, '{}'))
environ_config = json.loads(os.environ.get(OVERRIDE_CONFIG_KEY, "{}"))
app.config.update(environ_config)
# Fix remote address handling for Flask.
if app.config.get('PROXY_COUNT', 1):
app.wsgi_app = ProxyFix(app.wsgi_app, num_proxies=app.config.get('PROXY_COUNT', 1))
if app.config.get("PROXY_COUNT", 1):
app.wsgi_app = ProxyFix(app.wsgi_app, num_proxies=app.config.get("PROXY_COUNT", 1))
# Ensure the V3 upgrade key is specified correctly. If not, simply fail.
# TODO: Remove for V3.1.
if not is_testing and not is_building and app.config.get('SETUP_COMPLETE', False):
v3_upgrade_mode = app.config.get('V3_UPGRADE_MODE')
if v3_upgrade_mode is None:
raise Exception('Configuration flag `V3_UPGRADE_MODE` must be set. Please check the upgrade docs')
if not is_testing and not is_building and app.config.get("SETUP_COMPLETE", False):
v3_upgrade_mode = app.config.get("V3_UPGRADE_MODE")
if v3_upgrade_mode is None:
raise Exception(
"Configuration flag `V3_UPGRADE_MODE` must be set. Please check the upgrade docs"
)
if (v3_upgrade_mode != 'background'
and v3_upgrade_mode != 'complete'
and v3_upgrade_mode != 'production-transition'
and v3_upgrade_mode != 'post-oci-rollout'
and v3_upgrade_mode != 'post-oci-roll-back-compat'):
raise Exception('Invalid value for config `V3_UPGRADE_MODE`. Please check the upgrade docs')
if (
v3_upgrade_mode != "background"
and v3_upgrade_mode != "complete"
and v3_upgrade_mode != "production-transition"
and v3_upgrade_mode != "post-oci-rollout"
and v3_upgrade_mode != "post-oci-roll-back-compat"
):
raise Exception(
"Invalid value for config `V3_UPGRADE_MODE`. Please check the upgrade docs"
)
# Split the registry model based on config.
# TODO: Remove once we are fully on the OCI data model.
registry_model.setup_split(app.config.get('OCI_NAMESPACE_PROPORTION') or 0,
app.config.get('OCI_NAMESPACE_WHITELIST') or set(),
app.config.get('V22_NAMESPACE_WHITELIST') or set(),
app.config.get('V3_UPGRADE_MODE'))
registry_model.setup_split(
app.config.get("OCI_NAMESPACE_PROPORTION") or 0,
app.config.get("OCI_NAMESPACE_WHITELIST") or set(),
app.config.get("V22_NAMESPACE_WHITELIST") or set(),
app.config.get("V3_UPGRADE_MODE"),
)
# Allow user to define a custom storage preference for the local instance.
_distributed_storage_preference = os.environ.get('QUAY_DISTRIBUTED_STORAGE_PREFERENCE', '').split()
_distributed_storage_preference = os.environ.get(
"QUAY_DISTRIBUTED_STORAGE_PREFERENCE", ""
).split()
if _distributed_storage_preference:
app.config['DISTRIBUTED_STORAGE_PREFERENCE'] = _distributed_storage_preference
app.config["DISTRIBUTED_STORAGE_PREFERENCE"] = _distributed_storage_preference
# Generate a secret key if none was specified.
if app.config['SECRET_KEY'] is None:
logger.debug('Generating in-memory secret key')
app.config['SECRET_KEY'] = generate_secret_key()
if app.config["SECRET_KEY"] is None:
logger.debug("Generating in-memory secret key")
app.config["SECRET_KEY"] = generate_secret_key()
# If the "preferred" scheme is https, then http is not allowed. Therefore, ensure we have a secure
# session cookie.
if (app.config['PREFERRED_URL_SCHEME'] == 'https' and
not app.config.get('FORCE_NONSECURE_SESSION_COOKIE', False)):
app.config['SESSION_COOKIE_SECURE'] = True
if app.config["PREFERRED_URL_SCHEME"] == "https" and not app.config.get(
"FORCE_NONSECURE_SESSION_COOKIE", False
):
app.config["SESSION_COOKIE_SECURE"] = True
# Load features from config.
features.import_features(app.config)
@@ -145,65 +168,77 @@ logger.debug("Loaded config", extra={"config": app.config})
class RequestWithId(Request):
request_gen = staticmethod(urn_generator(['request']))
request_gen = staticmethod(urn_generator(["request"]))
def __init__(self, *args, **kwargs):
super(RequestWithId, self).__init__(*args, **kwargs)
self.request_id = self.request_gen()
def __init__(self, *args, **kwargs):
super(RequestWithId, self).__init__(*args, **kwargs)
self.request_id = self.request_gen()
@app.before_request
def _request_start():
if os.getenv('PYDEV_DEBUG', None):
import pydevd
host, port = os.getenv('PYDEV_DEBUG').split(':')
pydevd.settrace(host, port=int(port), stdoutToServer=True, stderrToServer=True, suspend=False)
if os.getenv("PYDEV_DEBUG", None):
import pydevd
logger.debug('Starting request: %s (%s)', request.request_id, request.path,
extra={"request_id": request.request_id})
host, port = os.getenv("PYDEV_DEBUG").split(":")
pydevd.settrace(
host,
port=int(port),
stdoutToServer=True,
stderrToServer=True,
suspend=False,
)
logger.debug(
"Starting request: %s (%s)",
request.request_id,
request.path,
extra={"request_id": request.request_id},
)
DEFAULT_FILTER = lambda x: '[FILTERED]'
DEFAULT_FILTER = lambda x: "[FILTERED]"
FILTERED_VALUES = [
{'key': ['password'], 'fn': DEFAULT_FILTER},
{'key': ['user', 'password'], 'fn': DEFAULT_FILTER},
{'key': ['blob'], 'fn': lambda x: x[0:8]}
{"key": ["password"], "fn": DEFAULT_FILTER},
{"key": ["user", "password"], "fn": DEFAULT_FILTER},
{"key": ["blob"], "fn": lambda x: x[0:8]},
]
@app.after_request
def _request_end(resp):
try:
jsonbody = request.get_json(force=True, silent=True)
except HTTPException:
jsonbody = None
try:
jsonbody = request.get_json(force=True, silent=True)
except HTTPException:
jsonbody = None
values = request.values.to_dict()
values = request.values.to_dict()
if jsonbody and not isinstance(jsonbody, dict):
jsonbody = {'_parsererror': jsonbody}
if jsonbody and not isinstance(jsonbody, dict):
jsonbody = {"_parsererror": jsonbody}
if isinstance(values, dict):
filter_logs(values, FILTERED_VALUES)
if isinstance(values, dict):
filter_logs(values, FILTERED_VALUES)
extra = {
"endpoint": request.endpoint,
"request_id" : request.request_id,
"remote_addr": request.remote_addr,
"http_method": request.method,
"original_url": request.url,
"path": request.path,
"parameters": values,
"json_body": jsonbody,
"confsha": CONFIG_DIGEST,
}
extra = {
"endpoint": request.endpoint,
"request_id": request.request_id,
"remote_addr": request.remote_addr,
"http_method": request.method,
"original_url": request.url,
"path": request.path,
"parameters": values,
"json_body": jsonbody,
"confsha": CONFIG_DIGEST,
}
if request.user_agent is not None:
extra["user-agent"] = request.user_agent.string
logger.debug('Ending request: %s (%s)', request.request_id, request.path, extra=extra)
return resp
if request.user_agent is not None:
extra["user-agent"] = request.user_agent.string
logger.debug(
"Ending request: %s (%s)", request.request_id, request.path, extra=extra
)
return resp
root_logger = logging.getLogger()
@@ -211,13 +246,13 @@ root_logger = logging.getLogger()
app.request_class = RequestWithId
# Register custom converters.
app.url_map.converters['regex'] = RegexConverter
app.url_map.converters['repopath'] = RepositoryPathConverter
app.url_map.converters['apirepopath'] = APIRepositoryPathConverter
app.url_map.converters["regex"] = RegexConverter
app.url_map.converters["repopath"] = RepositoryPathConverter
app.url_map.converters["apirepopath"] = APIRepositoryPathConverter
Principal(app, use_sessions=False)
tf = app.config['DB_TRANSACTION_FACTORY']
tf = app.config["DB_TRANSACTION_FACTORY"]
model_cache = get_model_cache(app.config)
avatar = Avatar(app)
@@ -225,10 +260,14 @@ login_manager = LoginManager(app)
mail = Mail(app)
prometheus = PrometheusPlugin(app)
metric_queue = MetricQueue(prometheus)
chunk_cleanup_queue = WorkQueue(app.config['CHUNK_CLEANUP_QUEUE_NAME'], tf, metric_queue=metric_queue)
chunk_cleanup_queue = WorkQueue(
app.config["CHUNK_CLEANUP_QUEUE_NAME"], tf, metric_queue=metric_queue
)
instance_keys = InstanceKeys(app)
ip_resolver = IPResolver(app)
storage = Storage(app, metric_queue, chunk_cleanup_queue, instance_keys, config_provider, ip_resolver)
storage = Storage(
app, metric_queue, chunk_cleanup_queue, instance_keys, config_provider, ip_resolver
)
userfiles = Userfiles(app, storage)
log_archive = LogArchive(app, storage)
analytics = Analytics(app)
@@ -246,55 +285,99 @@ build_canceller = BuildCanceller(app)
start_cloudwatch_sender(metric_queue, app)
github_trigger = GithubOAuthService(app.config, 'GITHUB_TRIGGER_CONFIG')
gitlab_trigger = GitLabOAuthService(app.config, 'GITLAB_TRIGGER_CONFIG')
github_trigger = GithubOAuthService(app.config, "GITHUB_TRIGGER_CONFIG")
gitlab_trigger = GitLabOAuthService(app.config, "GITLAB_TRIGGER_CONFIG")
oauth_login = OAuthLoginManager(app.config)
oauth_apps = [github_trigger, gitlab_trigger]
image_replication_queue = WorkQueue(app.config['REPLICATION_QUEUE_NAME'], tf,
has_namespace=False, metric_queue=metric_queue)
dockerfile_build_queue = WorkQueue(app.config['DOCKERFILE_BUILD_QUEUE_NAME'], tf,
metric_queue=metric_queue,
reporter=BuildMetricQueueReporter(metric_queue),
has_namespace=True)
notification_queue = WorkQueue(app.config['NOTIFICATION_QUEUE_NAME'], tf, has_namespace=True,
metric_queue=metric_queue)
secscan_notification_queue = WorkQueue(app.config['SECSCAN_NOTIFICATION_QUEUE_NAME'], tf,
has_namespace=False,
metric_queue=metric_queue)
export_action_logs_queue = WorkQueue(app.config['EXPORT_ACTION_LOGS_QUEUE_NAME'], tf,
has_namespace=True,
metric_queue=metric_queue)
image_replication_queue = WorkQueue(
app.config["REPLICATION_QUEUE_NAME"],
tf,
has_namespace=False,
metric_queue=metric_queue,
)
dockerfile_build_queue = WorkQueue(
app.config["DOCKERFILE_BUILD_QUEUE_NAME"],
tf,
metric_queue=metric_queue,
reporter=BuildMetricQueueReporter(metric_queue),
has_namespace=True,
)
notification_queue = WorkQueue(
app.config["NOTIFICATION_QUEUE_NAME"],
tf,
has_namespace=True,
metric_queue=metric_queue,
)
secscan_notification_queue = WorkQueue(
app.config["SECSCAN_NOTIFICATION_QUEUE_NAME"],
tf,
has_namespace=False,
metric_queue=metric_queue,
)
export_action_logs_queue = WorkQueue(
app.config["EXPORT_ACTION_LOGS_QUEUE_NAME"],
tf,
has_namespace=True,
metric_queue=metric_queue,
)
# Note: We set `has_namespace` to `False` here, as we explicitly want this queue to not be emptied
# when a namespace is marked for deletion.
namespace_gc_queue = WorkQueue(app.config['NAMESPACE_GC_QUEUE_NAME'], tf, has_namespace=False,
metric_queue=metric_queue)
namespace_gc_queue = WorkQueue(
app.config["NAMESPACE_GC_QUEUE_NAME"],
tf,
has_namespace=False,
metric_queue=metric_queue,
)
all_queues = [image_replication_queue, dockerfile_build_queue, notification_queue,
secscan_notification_queue, chunk_cleanup_queue, namespace_gc_queue]
all_queues = [
image_replication_queue,
dockerfile_build_queue,
notification_queue,
secscan_notification_queue,
chunk_cleanup_queue,
namespace_gc_queue,
]
url_scheme_and_hostname = URLSchemeAndHostname(app.config['PREFERRED_URL_SCHEME'], app.config['SERVER_HOSTNAME'])
secscan_api = SecurityScannerAPI(app.config, storage, app.config['SERVER_HOSTNAME'], app.config['HTTPCLIENT'],
uri_creator=get_blob_download_uri_getter(app.test_request_context('/'), url_scheme_and_hostname),
instance_keys=instance_keys)
url_scheme_and_hostname = URLSchemeAndHostname(
app.config["PREFERRED_URL_SCHEME"], app.config["SERVER_HOSTNAME"]
)
secscan_api = SecurityScannerAPI(
app.config,
storage,
app.config["SERVER_HOSTNAME"],
app.config["HTTPCLIENT"],
uri_creator=get_blob_download_uri_getter(
app.test_request_context("/"), url_scheme_and_hostname
),
instance_keys=instance_keys,
)
repo_mirror_api = RepoMirrorAPI(app.config, app.config['SERVER_HOSTNAME'], app.config['HTTPCLIENT'],
instance_keys=instance_keys)
repo_mirror_api = RepoMirrorAPI(
app.config,
app.config["SERVER_HOSTNAME"],
app.config["HTTPCLIENT"],
instance_keys=instance_keys,
)
tuf_metadata_api = TUFMetadataAPI(app, app.config)
# Check for a key in config. If none found, generate a new signing key for Docker V2 manifests.
_v2_key_path = os.path.join(OVERRIDE_CONFIG_DIRECTORY, DOCKER_V2_SIGNINGKEY_FILENAME)
if os.path.exists(_v2_key_path):
docker_v2_signing_key = RSAKey().load(_v2_key_path)
docker_v2_signing_key = RSAKey().load(_v2_key_path)
else:
docker_v2_signing_key = RSAKey(key=RSA.generate(2048))
docker_v2_signing_key = RSAKey(key=RSA.generate(2048))
# Configure the database.
if app.config.get('DATABASE_SECRET_KEY') is None and app.config.get('SETUP_COMPLETE', False):
raise Exception('Missing DATABASE_SECRET_KEY in config; did you perhaps forget to add it?')
if app.config.get("DATABASE_SECRET_KEY") is None and app.config.get(
"SETUP_COMPLETE", False
):
raise Exception(
"Missing DATABASE_SECRET_KEY in config; did you perhaps forget to add it?"
)
database.configure(app.config)
@@ -306,8 +389,9 @@ model.config.register_repo_cleanup_callback(tuf_metadata_api.delete_metadata)
@login_manager.user_loader
def load_user(user_uuid):
logger.debug('User loader loading deferred user with uuid: %s', user_uuid)
return LoginWrappedDBUser(user_uuid)
logger.debug("User loader loading deferred user with uuid: %s", user_uuid)
return LoginWrappedDBUser(user_uuid)
logs_model.configure(app.config)


@@ -1,5 +1,6 @@
# NOTE: Must be before we import or call anything that may be synchronous.
from gevent import monkey
monkey.patch_all()
import os
@@ -17,6 +18,6 @@ import registry
import secscan
if __name__ == '__main__':
logging.config.fileConfig(logfile_path(debug=True), disable_existing_loggers=False)
application.run(port=5000, debug=True, threaded=True, host='0.0.0.0')
if __name__ == "__main__":
logging.config.fileConfig(logfile_path(debug=True), disable_existing_loggers=False)
application.run(port=5000, debug=True, threaded=True, host="0.0.0.0")


@@ -1,21 +1,25 @@
from flask import _request_ctx_stack
def get_authenticated_context():
""" Returns the auth context for the current request context, if any. """
return getattr(_request_ctx_stack.top, 'authenticated_context', None)
""" Returns the auth context for the current request context, if any. """
return getattr(_request_ctx_stack.top, "authenticated_context", None)
def get_authenticated_user():
""" Returns the authenticated user, if any, or None if none. """
context = get_authenticated_context()
return context.authed_user if context else None
""" Returns the authenticated user, if any, or None if none. """
context = get_authenticated_context()
return context.authed_user if context else None
def get_validated_oauth_token():
""" Returns the authenticated and validated OAuth access token, if any, or None if none. """
context = get_authenticated_context()
return context.authed_oauth_token if context else None
""" Returns the authenticated and validated OAuth access token, if any, or None if none. """
context = get_authenticated_context()
return context.authed_oauth_token if context else None
def set_authenticated_context(auth_context):
""" Sets the auth context for the current request context to that given. """
ctx = _request_ctx_stack.top
ctx.authenticated_context = auth_context
return auth_context
""" Sets the auth context for the current request context to that given. """
ctx = _request_ctx_stack.top
ctx.authenticated_context = auth_context
return auth_context


@@ -16,422 +16,446 @@ from auth.scopes import scopes_from_scope_string
logger = logging.getLogger(__name__)
@add_metaclass(ABCMeta)
class AuthContext(object):
"""
"""
Interface that represents the current context of authentication.
"""
@property
@abstractmethod
def entity_kind(self):
""" Returns the kind of the entity in this auth context. """
pass
@property
@abstractmethod
def entity_kind(self):
""" Returns the kind of the entity in this auth context. """
pass
@property
@abstractmethod
def is_anonymous(self):
""" Returns true if this is an anonymous context. """
pass
@property
@abstractmethod
def is_anonymous(self):
""" Returns true if this is an anonymous context. """
pass
@property
@abstractmethod
def authed_oauth_token(self):
""" Returns the authenticated OAuth token, if any. """
pass
@property
@abstractmethod
def authed_oauth_token(self):
""" Returns the authenticated OAuth token, if any. """
pass
@property
@abstractmethod
def authed_user(self):
""" Returns the authenticated user, whether directly, or via an OAuth or access token. Note that
@property
@abstractmethod
def authed_user(self):
""" Returns the authenticated user, whether directly, or via an OAuth or access token. Note that
this property will also return robot accounts.
"""
pass
pass
@property
@abstractmethod
def has_nonrobot_user(self):
""" Returns whether a user (not a robot) was authenticated successfully. """
pass
@property
@abstractmethod
def has_nonrobot_user(self):
""" Returns whether a user (not a robot) was authenticated successfully. """
pass
@property
@abstractmethod
def identity(self):
""" Returns the identity for the auth context. """
pass
@property
@abstractmethod
def identity(self):
""" Returns the identity for the auth context. """
pass
@property
@abstractmethod
def description(self):
""" Returns a human-readable and *public* description of the current auth context. """
pass
@property
@abstractmethod
def description(self):
""" Returns a human-readable and *public* description of the current auth context. """
pass
@property
@abstractmethod
def credential_username(self):
""" Returns the username to create credentials for this context's entity, if any. """
pass
@property
@abstractmethod
def credential_username(self):
""" Returns the username to create credentials for this context's entity, if any. """
pass
@abstractmethod
def analytics_id_and_public_metadata(self):
""" Returns the analytics ID and public log metadata for this auth context. """
pass
@abstractmethod
def analytics_id_and_public_metadata(self):
""" Returns the analytics ID and public log metadata for this auth context. """
pass
@abstractmethod
def apply_to_request_context(self):
""" Applies this auth result to the auth context and Flask-Principal. """
pass
@abstractmethod
def apply_to_request_context(self):
""" Applies this auth result to the auth context and Flask-Principal. """
pass
@abstractmethod
def to_signed_dict(self):
""" Serializes the auth context into a dictionary suitable for inclusion in a JWT or other
@abstractmethod
def to_signed_dict(self):
""" Serializes the auth context into a dictionary suitable for inclusion in a JWT or other
form of signed serialization.
"""
pass
pass
@property
@abstractmethod
def unique_key(self):
""" Returns a key that is unique to this auth context type and its data. For example, an
@property
@abstractmethod
def unique_key(self):
""" Returns a key that is unique to this auth context type and its data. For example, an
instance of the auth context type for the user might be a string of the form
`user-{user-uuid}`. Callers should treat this key as opaque and not rely on the contents
for anything besides uniqueness. This is typically used by callers when they'd like to
check cache but not hit the database to get a fully validated auth context.
"""
pass
pass
class ValidatedAuthContext(AuthContext):
""" ValidatedAuthContext represents the loaded, authenticated and validated auth information
""" ValidatedAuthContext represents the loaded, authenticated and validated auth information
for the current request context.
"""
def __init__(self, user=None, token=None, oauthtoken=None, robot=None, appspecifictoken=None,
signed_data=None):
# Note: These field names *MUST* match the string values of the kinds defined in
# ContextEntityKind.
self.user = user
self.robot = robot
self.token = token
self.oauthtoken = oauthtoken
self.appspecifictoken = appspecifictoken
self.signed_data = signed_data
def tuple(self):
return vars(self).values()
def __init__(
self,
user=None,
token=None,
oauthtoken=None,
robot=None,
appspecifictoken=None,
signed_data=None,
):
# Note: These field names *MUST* match the string values of the kinds defined in
# ContextEntityKind.
self.user = user
self.robot = robot
self.token = token
self.oauthtoken = oauthtoken
self.appspecifictoken = appspecifictoken
self.signed_data = signed_data
def __eq__(self, other):
return self.tuple() == other.tuple()
def tuple(self):
return vars(self).values()
@property
def entity_kind(self):
""" Returns the kind of the entity in this auth context. """
for kind in ContextEntityKind:
if hasattr(self, kind.value) and getattr(self, kind.value):
return kind
def __eq__(self, other):
return self.tuple() == other.tuple()
return ContextEntityKind.anonymous
@property
def entity_kind(self):
""" Returns the kind of the entity in this auth context. """
for kind in ContextEntityKind:
if hasattr(self, kind.value) and getattr(self, kind.value):
return kind
@property
def authed_user(self):
""" Returns the authenticated user, whether directly, or via an OAuth token. Note that this
return ContextEntityKind.anonymous
@property
def authed_user(self):
""" Returns the authenticated user, whether directly, or via an OAuth token. Note that this
will also return robot accounts.
"""
authed_user = self._authed_user()
if authed_user is not None and not authed_user.enabled:
logger.warning('Attempt to reference a disabled user/robot: %s', authed_user.username)
return None
authed_user = self._authed_user()
if authed_user is not None and not authed_user.enabled:
logger.warning(
"Attempt to reference a disabled user/robot: %s", authed_user.username
)
return None
return authed_user
return authed_user
@property
def authed_oauth_token(self):
return self.oauthtoken
@property
def authed_oauth_token(self):
return self.oauthtoken
def _authed_user(self):
if self.oauthtoken:
return self.oauthtoken.authorized_user
def _authed_user(self):
if self.oauthtoken:
return self.oauthtoken.authorized_user
if self.appspecifictoken:
return self.appspecifictoken.user
if self.appspecifictoken:
return self.appspecifictoken.user
if self.signed_data:
return model.user.get_user(self.signed_data['user_context'])
if self.signed_data:
return model.user.get_user(self.signed_data["user_context"])
return self.user if self.user else self.robot
return self.user if self.user else self.robot
@property
def is_anonymous(self):
""" Returns true if this is an anonymous context. """
return not self.authed_user and not self.token and not self.signed_data
@property
def is_anonymous(self):
""" Returns true if this is an anonymous context. """
return not self.authed_user and not self.token and not self.signed_data
@property
def has_nonrobot_user(self):
""" Returns whether a user (not a robot) was authenticated successfully. """
return bool(self.authed_user and not self.robot)
@property
def has_nonrobot_user(self):
""" Returns whether a user (not a robot) was authenticated successfully. """
return bool(self.authed_user and not self.robot)
@property
def identity(self):
""" Returns the identity for the auth context. """
if self.oauthtoken:
scope_set = scopes_from_scope_string(self.oauthtoken.scope)
return QuayDeferredPermissionUser.for_user(self.oauthtoken.authorized_user, scope_set)
@property
def identity(self):
""" Returns the identity for the auth context. """
if self.oauthtoken:
scope_set = scopes_from_scope_string(self.oauthtoken.scope)
return QuayDeferredPermissionUser.for_user(
self.oauthtoken.authorized_user, scope_set
)
if self.authed_user:
return QuayDeferredPermissionUser.for_user(self.authed_user)
if self.authed_user:
return QuayDeferredPermissionUser.for_user(self.authed_user)
if self.token:
return Identity(self.token.get_code(), 'token')
if self.token:
return Identity(self.token.get_code(), "token")
if self.signed_data:
identity = Identity(None, 'signed_grant')
identity.provides.update(self.signed_data['grants'])
return identity
if self.signed_data:
identity = Identity(None, "signed_grant")
identity.provides.update(self.signed_data["grants"])
return identity
return None
return None
@property
def entity_reference(self):
""" Returns the DB object reference for this context's entity. """
if self.entity_kind == ContextEntityKind.anonymous:
return None
@property
def entity_reference(self):
""" Returns the DB object reference for this context's entity. """
if self.entity_kind == ContextEntityKind.anonymous:
return None
return getattr(self, self.entity_kind.value)
return getattr(self, self.entity_kind.value)
@property
def description(self):
""" Returns a human-readable and *public* description of the current auth context. """
handler = CONTEXT_ENTITY_HANDLERS[self.entity_kind]()
return handler.description(self.entity_reference)
@property
def description(self):
""" Returns a human-readable and *public* description of the current auth context. """
handler = CONTEXT_ENTITY_HANDLERS[self.entity_kind]()
return handler.description(self.entity_reference)
@property
def credential_username(self):
""" Returns the username to create credentials for this context's entity, if any. """
handler = CONTEXT_ENTITY_HANDLERS[self.entity_kind]()
return handler.credential_username(self.entity_reference)
@property
def credential_username(self):
""" Returns the username to create credentials for this context's entity, if any. """
handler = CONTEXT_ENTITY_HANDLERS[self.entity_kind]()
return handler.credential_username(self.entity_reference)
def analytics_id_and_public_metadata(self):
""" Returns the analytics ID and public log metadata for this auth context. """
handler = CONTEXT_ENTITY_HANDLERS[self.entity_kind]()
return handler.analytics_id_and_public_metadata(self.entity_reference)
def analytics_id_and_public_metadata(self):
""" Returns the analytics ID and public log metadata for this auth context. """
handler = CONTEXT_ENTITY_HANDLERS[self.entity_kind]()
return handler.analytics_id_and_public_metadata(self.entity_reference)
def apply_to_request_context(self):
""" Applies this auth result to the auth context and Flask-Principal. """
# Save to the request context.
set_authenticated_context(self)
def apply_to_request_context(self):
""" Applies this auth result to the auth context and Flask-Principal. """
# Save to the request context.
set_authenticated_context(self)
# Set the identity for Flask-Principal.
if self.identity:
identity_changed.send(app, identity=self.identity)
# Set the identity for Flask-Principal.
if self.identity:
identity_changed.send(app, identity=self.identity)
@property
def unique_key(self):
signed_dict = self.to_signed_dict()
return '%s-%s' % (signed_dict['entity_kind'], signed_dict.get('entity_reference', '(anon)'))
@property
def unique_key(self):
signed_dict = self.to_signed_dict()
return "%s-%s" % (
signed_dict["entity_kind"],
signed_dict.get("entity_reference", "(anon)"),
)
def to_signed_dict(self):
""" Serializes the auth context into a dictionary suitable for inclusion in a JWT or other
def to_signed_dict(self):
""" Serializes the auth context into a dictionary suitable for inclusion in a JWT or other
form of signed serialization.
"""
dict_data = {
'version': 2,
'entity_kind': self.entity_kind.value,
}
dict_data = {"version": 2, "entity_kind": self.entity_kind.value}
if self.entity_kind != ContextEntityKind.anonymous:
handler = CONTEXT_ENTITY_HANDLERS[self.entity_kind]()
dict_data.update({
'entity_reference': handler.get_serialized_entity_reference(self.entity_reference),
})
if self.entity_kind != ContextEntityKind.anonymous:
handler = CONTEXT_ENTITY_HANDLERS[self.entity_kind]()
dict_data.update(
{
"entity_reference": handler.get_serialized_entity_reference(
self.entity_reference
)
}
)
# Add legacy information.
# TODO: Remove this all once the new code is fully deployed.
if self.token:
dict_data.update({
'kind': 'token',
'token': self.token.code,
})
# Add legacy information.
# TODO: Remove this all once the new code is fully deployed.
if self.token:
dict_data.update({"kind": "token", "token": self.token.code})
if self.oauthtoken:
dict_data.update({
'kind': 'oauth',
'oauth': self.oauthtoken.uuid,
'user': self.authed_user.username,
})
if self.oauthtoken:
dict_data.update(
{
"kind": "oauth",
"oauth": self.oauthtoken.uuid,
"user": self.authed_user.username,
}
)
if self.user or self.robot:
dict_data.update({
'kind': 'user',
'user': self.authed_user.username,
})
if self.user or self.robot:
dict_data.update({"kind": "user", "user": self.authed_user.username})
if self.appspecifictoken:
dict_data.update({
'kind': 'user',
'user': self.authed_user.username,
})
if self.appspecifictoken:
dict_data.update({"kind": "user", "user": self.authed_user.username})
if self.is_anonymous:
dict_data.update({
'kind': 'anonymous',
})
if self.is_anonymous:
dict_data.update({"kind": "anonymous"})
# End of legacy information.
return dict_data
# End of legacy information.
return dict_data
class SignedAuthContext(AuthContext):
""" SignedAuthContext represents an auth context loaded from a signed token of some kind,
""" SignedAuthContext represents an auth context loaded from a signed token of some kind,
such as a JWT. Unlike ValidatedAuthContext, SignedAuthContext operates lazily, only loading
the actual {user, robot, token, etc} when requested. This allows registry operations that
only need to check if *some* entity is present to do so, without hitting the database.
"""
def __init__(self, kind, signed_data, v1_dict_format):
self.kind = kind
self.signed_data = signed_data
self.v1_dict_format = v1_dict_format
@property
def unique_key(self):
if self.v1_dict_format:
# Since V1 data format is verbose, just use the validated version to get the key.
return self._get_validated().unique_key
def __init__(self, kind, signed_data, v1_dict_format):
self.kind = kind
self.signed_data = signed_data
self.v1_dict_format = v1_dict_format
signed_dict = self.signed_data
return '%s-%s' % (signed_dict['entity_kind'], signed_dict.get('entity_reference', '(anon)'))
@property
def unique_key(self):
if self.v1_dict_format:
# Since V1 data format is verbose, just use the validated version to get the key.
return self._get_validated().unique_key
@classmethod
def build_from_signed_dict(cls, dict_data, v1_dict_format=False):
if not v1_dict_format:
entity_kind = ContextEntityKind(dict_data.get('entity_kind', 'anonymous'))
return SignedAuthContext(entity_kind, dict_data, v1_dict_format)
signed_dict = self.signed_data
return "%s-%s" % (
signed_dict["entity_kind"],
signed_dict.get("entity_reference", "(anon)"),
)
# Legacy handling.
# TODO: Remove this all once the new code is fully deployed.
kind_string = dict_data.get('kind', 'anonymous')
if kind_string == 'oauth':
kind_string = 'oauthtoken'
@classmethod
def build_from_signed_dict(cls, dict_data, v1_dict_format=False):
if not v1_dict_format:
entity_kind = ContextEntityKind(dict_data.get("entity_kind", "anonymous"))
return SignedAuthContext(entity_kind, dict_data, v1_dict_format)
kind = ContextEntityKind(kind_string)
return SignedAuthContext(kind, dict_data, v1_dict_format)
# Legacy handling.
# TODO: Remove this all once the new code is fully deployed.
kind_string = dict_data.get("kind", "anonymous")
if kind_string == "oauth":
kind_string = "oauthtoken"
@lru_cache(maxsize=1)
def _get_validated(self):
""" Returns a ValidatedAuthContext for this signed context, resolving all the necessary
kind = ContextEntityKind(kind_string)
return SignedAuthContext(kind, dict_data, v1_dict_format)
@lru_cache(maxsize=1)
def _get_validated(self):
""" Returns a ValidatedAuthContext for this signed context, resolving all the necessary
references.
"""
if not self.v1_dict_format:
if self.kind == ContextEntityKind.anonymous:
return ValidatedAuthContext()
if not self.v1_dict_format:
if self.kind == ContextEntityKind.anonymous:
return ValidatedAuthContext()
serialized_entity_reference = self.signed_data['entity_reference']
handler = CONTEXT_ENTITY_HANDLERS[self.kind]()
entity_reference = handler.deserialize_entity_reference(serialized_entity_reference)
if entity_reference is None:
logger.debug('Could not deserialize entity reference `%s` under kind `%s`',
serialized_entity_reference, self.kind)
return ValidatedAuthContext()
serialized_entity_reference = self.signed_data["entity_reference"]
handler = CONTEXT_ENTITY_HANDLERS[self.kind]()
entity_reference = handler.deserialize_entity_reference(
serialized_entity_reference
)
if entity_reference is None:
logger.debug(
"Could not deserialize entity reference `%s` under kind `%s`",
serialized_entity_reference,
self.kind,
)
return ValidatedAuthContext()
return ValidatedAuthContext(**{self.kind.value: entity_reference})
return ValidatedAuthContext(**{self.kind.value: entity_reference})
# Legacy handling.
# TODO: Remove this all once the new code is fully deployed.
kind_string = self.signed_data.get('kind', 'anonymous')
if kind_string == 'oauth':
kind_string = 'oauthtoken'
# Legacy handling.
# TODO: Remove this all once the new code is fully deployed.
kind_string = self.signed_data.get("kind", "anonymous")
if kind_string == "oauth":
kind_string = "oauthtoken"
kind = ContextEntityKind(kind_string)
if kind == ContextEntityKind.anonymous:
return ValidatedAuthContext()
kind = ContextEntityKind(kind_string)
if kind == ContextEntityKind.anonymous:
return ValidatedAuthContext()
if kind == ContextEntityKind.user or kind == ContextEntityKind.robot:
user = model.user.get_user(self.signed_data.get('user', ''))
if not user:
return None
if kind == ContextEntityKind.user or kind == ContextEntityKind.robot:
user = model.user.get_user(self.signed_data.get("user", ""))
if not user:
return None
return ValidatedAuthContext(robot=user) if user.robot else ValidatedAuthContext(user=user)
return (
ValidatedAuthContext(robot=user)
if user.robot
else ValidatedAuthContext(user=user)
)
if kind == ContextEntityKind.token:
token = model.token.load_token_data(self.signed_data.get('token'))
if not token:
return None
if kind == ContextEntityKind.token:
token = model.token.load_token_data(self.signed_data.get("token"))
if not token:
return None
return ValidatedAuthContext(token=token)
return ValidatedAuthContext(token=token)
if kind == ContextEntityKind.oauthtoken:
user = model.user.get_user(self.signed_data.get('user', ''))
if not user:
return None
if kind == ContextEntityKind.oauthtoken:
user = model.user.get_user(self.signed_data.get("user", ""))
if not user:
return None
token_uuid = self.signed_data.get('oauth', '')
oauthtoken = model.oauth.lookup_access_token_for_user(user, token_uuid)
if not oauthtoken:
return None
token_uuid = self.signed_data.get("oauth", "")
oauthtoken = model.oauth.lookup_access_token_for_user(user, token_uuid)
if not oauthtoken:
return None
return ValidatedAuthContext(oauthtoken=oauthtoken)
return ValidatedAuthContext(oauthtoken=oauthtoken)
raise Exception('Unknown auth context kind `%s` when deserializing %s' % (kind,
self.signed_data))
# End of legacy handling.
raise Exception(
"Unknown auth context kind `%s` when deserializing %s"
% (kind, self.signed_data)
)
# End of legacy handling.
@property
def entity_kind(self):
""" Returns the kind of the entity in this auth context. """
return self.kind
@property
def entity_kind(self):
""" Returns the kind of the entity in this auth context. """
return self.kind
@property
def is_anonymous(self):
""" Returns true if this is an anonymous context. """
return self.kind == ContextEntityKind.anonymous
@property
def is_anonymous(self):
""" Returns true if this is an anonymous context. """
return self.kind == ContextEntityKind.anonymous
@property
def authed_user(self):
""" Returns the authenticated user, whether directly, or via an OAuth or access token. Note that
@property
def authed_user(self):
""" Returns the authenticated user, whether directly, or via an OAuth or access token. Note that
this property will also return robot accounts.
"""
if self.kind == ContextEntityKind.anonymous:
return None
if self.kind == ContextEntityKind.anonymous:
return None
return self._get_validated().authed_user
return self._get_validated().authed_user
@property
def authed_oauth_token(self):
if self.kind == ContextEntityKind.anonymous:
return None
@property
def authed_oauth_token(self):
if self.kind == ContextEntityKind.anonymous:
return None
return self._get_validated().authed_oauth_token
return self._get_validated().authed_oauth_token
@property
def has_nonrobot_user(self):
""" Returns whether a user (not a robot) was authenticated successfully. """
if self.kind == ContextEntityKind.anonymous:
return False
@property
def has_nonrobot_user(self):
""" Returns whether a user (not a robot) was authenticated successfully. """
if self.kind == ContextEntityKind.anonymous:
return False
return self._get_validated().has_nonrobot_user
return self._get_validated().has_nonrobot_user
@property
def identity(self):
""" Returns the identity for the auth context. """
return self._get_validated().identity
@property
def identity(self):
""" Returns the identity for the auth context. """
return self._get_validated().identity
@property
def description(self):
""" Returns a human-readable and *public* description of the current auth context. """
return self._get_validated().description
@property
def description(self):
""" Returns a human-readable and *public* description of the current auth context. """
return self._get_validated().description
@property
def credential_username(self):
""" Returns the username to create credentials for this context's entity, if any. """
return self._get_validated().credential_username
@property
def credential_username(self):
""" Returns the username to create credentials for this context's entity, if any. """
return self._get_validated().credential_username
def analytics_id_and_public_metadata(self):
""" Returns the analytics ID and public log metadata for this auth context. """
return self._get_validated().analytics_id_and_public_metadata()
def analytics_id_and_public_metadata(self):
""" Returns the analytics ID and public log metadata for this auth context. """
return self._get_validated().analytics_id_and_public_metadata()
def apply_to_request_context(self):
""" Applies this auth result to the auth context and Flask-Principal. """
return self._get_validated().apply_to_request_context()
def apply_to_request_context(self):
""" Applies this auth result to the auth context and Flask-Principal. """
return self._get_validated().apply_to_request_context()
def to_signed_dict(self):
""" Serializes the auth context into a dictionary suitable for inclusion in a JWT or other
def to_signed_dict(self):
""" Serializes the auth context into a dictionary suitable for inclusion in a JWT or other
form of signed serialization.
"""
return self.signed_data
return self.signed_data


@@ -8,51 +8,54 @@ from auth.validateresult import ValidateResult, AuthKind
logger = logging.getLogger(__name__)
def has_basic_auth(username):
""" Returns true if a basic auth header exists with a username and password pair that validates
""" Returns true if a basic auth header exists with a username and password pair that validates
against the internal authentication system. Returns True on full success and False on any
failure (missing header, invalid header, invalid credentials, etc).
"""
auth_header = request.headers.get('authorization', '')
result = validate_basic_auth(auth_header)
return result.has_nonrobot_user and result.context.user.username == username
auth_header = request.headers.get("authorization", "")
result = validate_basic_auth(auth_header)
return result.has_nonrobot_user and result.context.user.username == username
def validate_basic_auth(auth_header):
""" Validates the specified basic auth header, returning whether its credentials point
""" Validates the specified basic auth header, returning whether its credentials point
to a valid user or token.
"""
if not auth_header:
return ValidateResult(AuthKind.basic, missing=True)
if not auth_header:
return ValidateResult(AuthKind.basic, missing=True)
logger.debug('Attempt to process basic auth header')
logger.debug("Attempt to process basic auth header")
# Parse the basic auth header.
assert isinstance(auth_header, basestring)
credentials, err = _parse_basic_auth_header(auth_header)
if err is not None:
logger.debug('Got invalid basic auth header: %s', auth_header)
return ValidateResult(AuthKind.basic, missing=True)
# Parse the basic auth header.
assert isinstance(auth_header, basestring)
credentials, err = _parse_basic_auth_header(auth_header)
if err is not None:
logger.debug("Got invalid basic auth header: %s", auth_header)
return ValidateResult(AuthKind.basic, missing=True)
auth_username, auth_password_or_token = credentials
result, _ = validate_credentials(auth_username, auth_password_or_token)
return result.with_kind(AuthKind.basic)
auth_username, auth_password_or_token = credentials
result, _ = validate_credentials(auth_username, auth_password_or_token)
return result.with_kind(AuthKind.basic)
def _parse_basic_auth_header(auth):
""" Parses the given basic auth header, returning the credentials found inside.
""" Parses the given basic auth header, returning the credentials found inside.
"""
normalized = [part.strip() for part in auth.split(' ') if part]
if normalized[0].lower() != 'basic' or len(normalized) != 2:
return None, 'Invalid basic auth header'
normalized = [part.strip() for part in auth.split(" ") if part]
if normalized[0].lower() != "basic" or len(normalized) != 2:
return None, "Invalid basic auth header"
try:
credentials = [part.decode('utf-8') for part in b64decode(normalized[1]).split(':', 1)]
except (TypeError, UnicodeDecodeError, ValueError):
logger.exception('Exception when parsing basic auth header: %s', auth)
return None, 'Could not parse basic auth header'
try:
credentials = [
part.decode("utf-8") for part in b64decode(normalized[1]).split(":", 1)
]
except (TypeError, UnicodeDecodeError, ValueError):
logger.exception("Exception when parsing basic auth header: %s", auth)
return None, "Could not parse basic auth header"
if len(credentials) != 2:
return None, 'Unexpected number of credentials found in basic auth header'
if len(credentials) != 2:
return None, "Unexpected number of credentials found in basic auth header"
return credentials, None
return credentials, None


@@ -4,200 +4,210 @@ from enum import Enum
from data import model
from auth.credential_consts import (ACCESS_TOKEN_USERNAME, OAUTH_TOKEN_USERNAME,
APP_SPECIFIC_TOKEN_USERNAME)
from auth.credential_consts import (
ACCESS_TOKEN_USERNAME,
OAUTH_TOKEN_USERNAME,
APP_SPECIFIC_TOKEN_USERNAME,
)
class ContextEntityKind(Enum):
""" Defines the various kinds of entities in an auth context. Note that the string values of
""" Defines the various kinds of entities in an auth context. Note that the string values of
these fields *must* match the names of the fields in the ValidatedAuthContext class, as
we fill them in directly based on the string names here.
"""
anonymous = 'anonymous'
user = 'user'
robot = 'robot'
token = 'token'
oauthtoken = 'oauthtoken'
appspecifictoken = 'appspecifictoken'
signed_data = 'signed_data'
anonymous = "anonymous"
user = "user"
robot = "robot"
token = "token"
oauthtoken = "oauthtoken"
appspecifictoken = "appspecifictoken"
signed_data = "signed_data"
@add_metaclass(ABCMeta)
class ContextEntityHandler(object):
"""
"""
Interface that represents handling specific kinds of entities under an auth context.
"""
@abstractmethod
def credential_username(self, entity_reference):
""" Returns the username to create credentials for this entity, if any. """
pass
@abstractmethod
def credential_username(self, entity_reference):
""" Returns the username to create credentials for this entity, if any. """
pass
@abstractmethod
def get_serialized_entity_reference(self, entity_reference):
""" Returns the entity reference for this kind of auth context, serialized into a form that can
@abstractmethod
def get_serialized_entity_reference(self, entity_reference):
""" Returns the entity reference for this kind of auth context, serialized into a form that can
be placed into a JSON object and put into a JWT. This is typically a DB UUID or another
unique identifier for the object in the DB.
"""
pass
pass
@abstractmethod
def deserialize_entity_reference(self, serialized_entity_reference):
""" Returns the deserialized reference to the entity in the database, or None if none. """
pass
@abstractmethod
def deserialize_entity_reference(self, serialized_entity_reference):
""" Returns the deserialized reference to the entity in the database, or None if none. """
pass
@abstractmethod
def description(self, entity_reference):
""" Returns a human-readable and *public* description of the current entity. """
pass
@abstractmethod
def description(self, entity_reference):
""" Returns a human-readable and *public* description of the current entity. """
pass
@abstractmethod
def analytics_id_and_public_metadata(self, entity_reference):
""" Returns the analyitics ID and a dict of public metadata for the current entity. """
pass
@abstractmethod
def analytics_id_and_public_metadata(self, entity_reference):
""" Returns the analyitics ID and a dict of public metadata for the current entity. """
pass
class AnonymousEntityHandler(ContextEntityHandler):
def credential_username(self, entity_reference):
return None
def credential_username(self, entity_reference):
return None
def get_serialized_entity_reference(self, entity_reference):
return None
def get_serialized_entity_reference(self, entity_reference):
return None
def deserialize_entity_reference(self, serialized_entity_reference):
return None
def deserialize_entity_reference(self, serialized_entity_reference):
return None
def description(self, entity_reference):
return "anonymous"
def description(self, entity_reference):
return "anonymous"
def analytics_id_and_public_metadata(self, entity_reference):
return "anonymous", {}
def analytics_id_and_public_metadata(self, entity_reference):
return "anonymous", {}
class UserEntityHandler(ContextEntityHandler):
def credential_username(self, entity_reference):
return entity_reference.username
def credential_username(self, entity_reference):
return entity_reference.username
def get_serialized_entity_reference(self, entity_reference):
return entity_reference.uuid
def get_serialized_entity_reference(self, entity_reference):
return entity_reference.uuid
def deserialize_entity_reference(self, serialized_entity_reference):
return model.user.get_user_by_uuid(serialized_entity_reference)
def deserialize_entity_reference(self, serialized_entity_reference):
return model.user.get_user_by_uuid(serialized_entity_reference)
def description(self, entity_reference):
return "user %s" % entity_reference.username
def description(self, entity_reference):
return "user %s" % entity_reference.username
def analytics_id_and_public_metadata(self, entity_reference):
return entity_reference.username, {
'username': entity_reference.username,
}
def analytics_id_and_public_metadata(self, entity_reference):
return entity_reference.username, {"username": entity_reference.username}
class RobotEntityHandler(ContextEntityHandler):
def credential_username(self, entity_reference):
return entity_reference.username
def credential_username(self, entity_reference):
return entity_reference.username
def get_serialized_entity_reference(self, entity_reference):
return entity_reference.username
def get_serialized_entity_reference(self, entity_reference):
return entity_reference.username
def deserialize_entity_reference(self, serialized_entity_reference):
return model.user.lookup_robot(serialized_entity_reference)
def deserialize_entity_reference(self, serialized_entity_reference):
return model.user.lookup_robot(serialized_entity_reference)
def description(self, entity_reference):
return "robot %s" % entity_reference.username
def description(self, entity_reference):
return "robot %s" % entity_reference.username
def analytics_id_and_public_metadata(self, entity_reference):
return entity_reference.username, {
'username': entity_reference.username,
'is_robot': True,
}
def analytics_id_and_public_metadata(self, entity_reference):
return (
entity_reference.username,
{"username": entity_reference.username, "is_robot": True},
)
class TokenEntityHandler(ContextEntityHandler):
def credential_username(self, entity_reference):
return ACCESS_TOKEN_USERNAME
def credential_username(self, entity_reference):
return ACCESS_TOKEN_USERNAME
def get_serialized_entity_reference(self, entity_reference):
return entity_reference.get_code()
def get_serialized_entity_reference(self, entity_reference):
return entity_reference.get_code()
def deserialize_entity_reference(self, serialized_entity_reference):
return model.token.load_token_data(serialized_entity_reference)
def deserialize_entity_reference(self, serialized_entity_reference):
return model.token.load_token_data(serialized_entity_reference)
def description(self, entity_reference):
return "token %s" % entity_reference.friendly_name
def description(self, entity_reference):
return "token %s" % entity_reference.friendly_name
def analytics_id_and_public_metadata(self, entity_reference):
return 'token:%s' % entity_reference.id, {
'token': entity_reference.friendly_name,
}
def analytics_id_and_public_metadata(self, entity_reference):
return (
"token:%s" % entity_reference.id,
{"token": entity_reference.friendly_name},
)
class OAuthTokenEntityHandler(ContextEntityHandler):
def credential_username(self, entity_reference):
return OAUTH_TOKEN_USERNAME
def credential_username(self, entity_reference):
return OAUTH_TOKEN_USERNAME
def get_serialized_entity_reference(self, entity_reference):
return entity_reference.uuid
def get_serialized_entity_reference(self, entity_reference):
return entity_reference.uuid
def deserialize_entity_reference(self, serialized_entity_reference):
return model.oauth.lookup_access_token_by_uuid(serialized_entity_reference)
def deserialize_entity_reference(self, serialized_entity_reference):
return model.oauth.lookup_access_token_by_uuid(serialized_entity_reference)
def description(self, entity_reference):
return "oauthtoken for user %s" % entity_reference.authorized_user.username
def description(self, entity_reference):
return "oauthtoken for user %s" % entity_reference.authorized_user.username
def analytics_id_and_public_metadata(self, entity_reference):
return 'oauthtoken:%s' % entity_reference.id, {
'oauth_token_id': entity_reference.id,
'oauth_token_application_id': entity_reference.application.client_id,
'oauth_token_application': entity_reference.application.name,
'username': entity_reference.authorized_user.username,
}
def analytics_id_and_public_metadata(self, entity_reference):
return (
"oauthtoken:%s" % entity_reference.id,
{
"oauth_token_id": entity_reference.id,
"oauth_token_application_id": entity_reference.application.client_id,
"oauth_token_application": entity_reference.application.name,
"username": entity_reference.authorized_user.username,
},
)
class AppSpecificTokenEntityHandler(ContextEntityHandler):
def credential_username(self, entity_reference):
return APP_SPECIFIC_TOKEN_USERNAME
def credential_username(self, entity_reference):
return APP_SPECIFIC_TOKEN_USERNAME
def get_serialized_entity_reference(self, entity_reference):
return entity_reference.uuid
def get_serialized_entity_reference(self, entity_reference):
return entity_reference.uuid
def deserialize_entity_reference(self, serialized_entity_reference):
return model.appspecifictoken.get_token_by_uuid(serialized_entity_reference)
def deserialize_entity_reference(self, serialized_entity_reference):
return model.appspecifictoken.get_token_by_uuid(serialized_entity_reference)
def description(self, entity_reference):
tpl = (entity_reference.title, entity_reference.user.username)
return "app specific token %s for user %s" % tpl
def description(self, entity_reference):
tpl = (entity_reference.title, entity_reference.user.username)
return "app specific token %s for user %s" % tpl
def analytics_id_and_public_metadata(self, entity_reference):
return 'appspecifictoken:%s' % entity_reference.id, {
'app_specific_token': entity_reference.uuid,
'app_specific_token_title': entity_reference.title,
'username': entity_reference.user.username,
}
def analytics_id_and_public_metadata(self, entity_reference):
return (
"appspecifictoken:%s" % entity_reference.id,
{
"app_specific_token": entity_reference.uuid,
"app_specific_token_title": entity_reference.title,
"username": entity_reference.user.username,
},
)
class SignedDataEntityHandler(ContextEntityHandler):
def credential_username(self, entity_reference):
return None
def credential_username(self, entity_reference):
return None
def get_serialized_entity_reference(self, entity_reference):
raise NotImplementedError
def get_serialized_entity_reference(self, entity_reference):
raise NotImplementedError
def deserialize_entity_reference(self, serialized_entity_reference):
raise NotImplementedError
def deserialize_entity_reference(self, serialized_entity_reference):
raise NotImplementedError
def description(self, entity_reference):
return "signed"
def description(self, entity_reference):
return "signed"
def analytics_id_and_public_metadata(self, entity_reference):
return 'signed', {'signed': entity_reference}
def analytics_id_and_public_metadata(self, entity_reference):
return "signed", {"signed": entity_reference}
CONTEXT_ENTITY_HANDLERS = {
ContextEntityKind.anonymous: AnonymousEntityHandler,
ContextEntityKind.user: UserEntityHandler,
ContextEntityKind.robot: RobotEntityHandler,
ContextEntityKind.token: TokenEntityHandler,
ContextEntityKind.oauthtoken: OAuthTokenEntityHandler,
ContextEntityKind.appspecifictoken: AppSpecificTokenEntityHandler,
ContextEntityKind.signed_data: SignedDataEntityHandler,
ContextEntityKind.anonymous: AnonymousEntityHandler,
ContextEntityKind.user: UserEntityHandler,
ContextEntityKind.robot: RobotEntityHandler,
ContextEntityKind.token: TokenEntityHandler,
ContextEntityKind.oauthtoken: OAuthTokenEntityHandler,
ContextEntityKind.appspecifictoken: AppSpecificTokenEntityHandler,
ContextEntityKind.signed_data: SignedDataEntityHandler,
}
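Every handler above implements the same small interface, and CONTEXT_ENTITY_HANDLERS maps each ContextEntityKind to its class. A minimal sketch of how a caller can dispatch through this table; the wrapper function and the example values are illustrative, not part of the module:

def describe_entity(kind, entity_reference):
    # Handlers are stateless, so instantiating one per call is cheap.
    handler = CONTEXT_ENTITY_HANDLERS[kind]()
    return handler.description(entity_reference)

# e.g. describe_entity(ContextEntityKind.robot, robot) -> "robot someorg+builder"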


@@ -7,31 +7,40 @@ from auth.validateresult import AuthKind, ValidateResult
logger = logging.getLogger(__name__)
def validate_session_cookie(auth_header_unused=None):
""" Attempts to load a user from a session cookie. """
if current_user.is_anonymous:
return ValidateResult(AuthKind.cookie, missing=True)
""" Attempts to load a user from a session cookie. """
if current_user.is_anonymous:
return ValidateResult(AuthKind.cookie, missing=True)
try:
# Attempt to parse the user uuid to make sure the cookie has the right value type
UUID(current_user.get_id())
except ValueError:
logger.debug('Got non-UUID for session cookie user: %s', current_user.get_id())
return ValidateResult(AuthKind.cookie, error_message='Invalid session cookie format')
try:
# Attempt to parse the user uuid to make sure the cookie has the right value type
UUID(current_user.get_id())
except ValueError:
logger.debug("Got non-UUID for session cookie user: %s", current_user.get_id())
return ValidateResult(
AuthKind.cookie, error_message="Invalid session cookie format"
)
logger.debug('Loading user from cookie: %s', current_user.get_id())
db_user = current_user.db_user()
if db_user is None:
return ValidateResult(AuthKind.cookie, error_message='Could not find matching user')
logger.debug("Loading user from cookie: %s", current_user.get_id())
db_user = current_user.db_user()
if db_user is None:
return ValidateResult(
AuthKind.cookie, error_message="Could not find matching user"
)
# Don't allow disabled users to login.
if not db_user.enabled:
logger.debug('User %s in session cookie is disabled', db_user.username)
return ValidateResult(AuthKind.cookie, error_message='User account is disabled')
# Don't allow disabled users to login.
if not db_user.enabled:
logger.debug("User %s in session cookie is disabled", db_user.username)
return ValidateResult(AuthKind.cookie, error_message="User account is disabled")
# Don't allow organizations to "login".
if db_user.organization:
logger.debug('User %s in session cookie is in fact an organization', db_user.username)
return ValidateResult(AuthKind.cookie, error_message='Cannot login to organization')
# Don't allow organizations to "login".
if db_user.organization:
logger.debug(
"User %s in session cookie is in fact an organization", db_user.username
)
return ValidateResult(
AuthKind.cookie, error_message="Cannot login to organization"
)
return ValidateResult(AuthKind.cookie, user=db_user)
return ValidateResult(AuthKind.cookie, user=db_user)
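Every exit path in validate_session_cookie returns a ValidateResult instead of raising, so callers branch on the outcome. A hedged usage sketch (the reject helper is hypothetical; apply_to_context mirrors how the decorators in this commit consume results):

result = validate_session_cookie()
if result.auth_valid:
    result.apply_to_context()  # bind the user into the request's auth context
elif not result.missing:
    # A cookie was present but invalid; error_message says why.
    reject(result.error_message)  # hypothetical helper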


@@ -1,3 +1,3 @@
ACCESS_TOKEN_USERNAME = '$token'
OAUTH_TOKEN_USERNAME = '$oauthtoken'
APP_SPECIFIC_TOKEN_USERNAME = '$app'
ACCESS_TOKEN_USERNAME = "$token"
OAUTH_TOKEN_USERNAME = "$oauthtoken"
APP_SPECIFIC_TOKEN_USERNAME = "$app"


@@ -7,8 +7,11 @@ import features
from app import authentication
from auth.oauth import validate_oauth_token
from auth.validateresult import ValidateResult, AuthKind
from auth.credential_consts import (ACCESS_TOKEN_USERNAME, OAUTH_TOKEN_USERNAME,
APP_SPECIFIC_TOKEN_USERNAME)
from auth.credential_consts import (
ACCESS_TOKEN_USERNAME,
OAUTH_TOKEN_USERNAME,
APP_SPECIFIC_TOKEN_USERNAME,
)
from data import model
from util.names import parse_robot_username
@@ -16,70 +19,116 @@ logger = logging.getLogger(__name__)
class CredentialKind(Enum):
user = 'user'
robot = 'robot'
token = ACCESS_TOKEN_USERNAME
oauth_token = OAUTH_TOKEN_USERNAME
app_specific_token = APP_SPECIFIC_TOKEN_USERNAME
user = "user"
robot = "robot"
token = ACCESS_TOKEN_USERNAME
oauth_token = OAUTH_TOKEN_USERNAME
app_specific_token = APP_SPECIFIC_TOKEN_USERNAME
def validate_credentials(auth_username, auth_password_or_token):
""" Validates a pair of auth username and password/token credentials. """
# Check for access tokens.
if auth_username == ACCESS_TOKEN_USERNAME:
logger.debug('Found credentials for access token')
try:
token = model.token.load_token_data(auth_password_or_token)
logger.debug('Successfully validated credentials for access token %s', token.id)
return ValidateResult(AuthKind.credentials, token=token), CredentialKind.token
except model.DataModelException:
logger.warning('Failed to validate credentials for access token %s', auth_password_or_token)
return (ValidateResult(AuthKind.credentials, error_message='Invalid access token'),
CredentialKind.token)
""" Validates a pair of auth username and password/token credentials. """
# Check for access tokens.
if auth_username == ACCESS_TOKEN_USERNAME:
logger.debug("Found credentials for access token")
try:
token = model.token.load_token_data(auth_password_or_token)
logger.debug(
"Successfully validated credentials for access token %s", token.id
)
return (
ValidateResult(AuthKind.credentials, token=token),
CredentialKind.token,
)
except model.DataModelException:
logger.warning(
"Failed to validate credentials for access token %s",
auth_password_or_token,
)
return (
ValidateResult(
AuthKind.credentials, error_message="Invalid access token"
),
CredentialKind.token,
)
# Check for App Specific tokens.
if features.APP_SPECIFIC_TOKENS and auth_username == APP_SPECIFIC_TOKEN_USERNAME:
logger.debug('Found credentials for app specific auth token')
token = model.appspecifictoken.access_valid_token(auth_password_or_token)
if token is None:
logger.debug('Failed to validate credentials for app specific token: %s',
auth_password_or_token)
return (ValidateResult(AuthKind.credentials, error_message='Invalid token'),
CredentialKind.app_specific_token)
# Check for App Specific tokens.
if features.APP_SPECIFIC_TOKENS and auth_username == APP_SPECIFIC_TOKEN_USERNAME:
logger.debug("Found credentials for app specific auth token")
token = model.appspecifictoken.access_valid_token(auth_password_or_token)
if token is None:
logger.debug(
"Failed to validate credentials for app specific token: %s",
auth_password_or_token,
)
return (
ValidateResult(AuthKind.credentials, error_message="Invalid token"),
CredentialKind.app_specific_token,
)
if not token.user.enabled:
logger.debug('Tried to use an app specific token for a disabled user: %s',
token.uuid)
return (ValidateResult(AuthKind.credentials,
error_message='This user has been disabled. Please contact your administrator.'),
CredentialKind.app_specific_token)
if not token.user.enabled:
logger.debug(
"Tried to use an app specific token for a disabled user: %s", token.uuid
)
return (
ValidateResult(
AuthKind.credentials,
error_message="This user has been disabled. Please contact your administrator.",
),
CredentialKind.app_specific_token,
)
logger.debug('Successfully validated credentials for app specific token %s', token.id)
return (ValidateResult(AuthKind.credentials, appspecifictoken=token),
CredentialKind.app_specific_token)
logger.debug(
"Successfully validated credentials for app specific token %s", token.id
)
return (
ValidateResult(AuthKind.credentials, appspecifictoken=token),
CredentialKind.app_specific_token,
)
# Check for OAuth tokens.
if auth_username == OAUTH_TOKEN_USERNAME:
return validate_oauth_token(auth_password_or_token), CredentialKind.oauth_token
# Check for OAuth tokens.
if auth_username == OAUTH_TOKEN_USERNAME:
return validate_oauth_token(auth_password_or_token), CredentialKind.oauth_token
# Check for robots and users.
is_robot = parse_robot_username(auth_username)
if is_robot:
logger.debug('Found credentials header for robot %s', auth_username)
try:
robot = model.user.verify_robot(auth_username, auth_password_or_token)
logger.debug('Successfully validated credentials for robot %s', auth_username)
return ValidateResult(AuthKind.credentials, robot=robot), CredentialKind.robot
except model.InvalidRobotException as ire:
logger.warning('Failed to validate credentials for robot %s: %s', auth_username, ire)
return ValidateResult(AuthKind.credentials, error_message=str(ire)), CredentialKind.robot
# Check for robots and users.
is_robot = parse_robot_username(auth_username)
if is_robot:
logger.debug("Found credentials header for robot %s", auth_username)
try:
robot = model.user.verify_robot(auth_username, auth_password_or_token)
logger.debug(
"Successfully validated credentials for robot %s", auth_username
)
return (
ValidateResult(AuthKind.credentials, robot=robot),
CredentialKind.robot,
)
except model.InvalidRobotException as ire:
logger.warning(
"Failed to validate credentials for robot %s: %s", auth_username, ire
)
return (
ValidateResult(AuthKind.credentials, error_message=str(ire)),
CredentialKind.robot,
)
# Otherwise, treat as a standard user.
(authenticated, err) = authentication.verify_and_link_user(auth_username, auth_password_or_token,
basic_auth=True)
if authenticated:
logger.debug('Successfully validated credentials for user %s', authenticated.username)
return ValidateResult(AuthKind.credentials, user=authenticated), CredentialKind.user
else:
logger.warning('Failed to validate credentials for user %s: %s', auth_username, err)
return ValidateResult(AuthKind.credentials, error_message=err), CredentialKind.user
# Otherwise, treat as a standard user.
(authenticated, err) = authentication.verify_and_link_user(
auth_username, auth_password_or_token, basic_auth=True
)
if authenticated:
logger.debug(
"Successfully validated credentials for user %s", authenticated.username
)
return (
ValidateResult(AuthKind.credentials, user=authenticated),
CredentialKind.user,
)
else:
logger.warning(
"Failed to validate credentials for user %s: %s", auth_username, err
)
return (
ValidateResult(AuthKind.credentials, error_message=err),
CredentialKind.user,
)
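validate_credentials always returns a (ValidateResult, CredentialKind) pair and never raises on bad credentials, which keeps the basic-auth layer free of try/except. A sketch of consuming it (the deny helper is hypothetical):

result, kind = validate_credentials(auth_username, auth_password_or_token)
if result.auth_valid:
    result.apply_to_context()
elif kind == CredentialKind.robot:
    # Robot failures surface the model's error text verbatim.
    deny(result.error_message)  # hypothetical helper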


@@ -14,83 +14,101 @@ from util.http import abort
logger = logging.getLogger(__name__)
def _auth_decorator(pass_result=False, handlers=None):
""" Builds an auth decorator that runs the given handlers and, if any return successfully,
""" Builds an auth decorator that runs the given handlers and, if any return successfully,
sets up the auth context. The wrapped function will be invoked *regardless of success or
failure of the auth handler(s)*
"""
def processor(func):
@wraps(func)
def wrapper(*args, **kwargs):
auth_header = request.headers.get('authorization', '')
result = None
for handler in handlers:
result = handler(auth_header)
# If the handler was missing the necessary information, skip it and try the next one.
if result.missing:
continue
def processor(func):
@wraps(func)
def wrapper(*args, **kwargs):
auth_header = request.headers.get("authorization", "")
result = None
# Check for a valid result.
if result.auth_valid:
logger.debug('Found valid auth result: %s', result.tuple())
for handler in handlers:
result = handler(auth_header)
# If the handler was missing the necessary information, skip it and try the next one.
if result.missing:
continue
# Set the various pieces of the auth context.
result.apply_to_context()
# Check for a valid result.
if result.auth_valid:
logger.debug("Found valid auth result: %s", result.tuple())
# Log the metric.
metric_queue.authentication_count.Inc(labelvalues=[result.kind, True])
break
# Set the various pieces of the auth context.
result.apply_to_context()
# Otherwise, report the error.
if result.error_message is not None:
# Log the failure.
metric_queue.authentication_count.Inc(labelvalues=[result.kind, False])
break
# Log the metric.
metric_queue.authentication_count.Inc(
labelvalues=[result.kind, True]
)
break
if pass_result:
kwargs['auth_result'] = result
# Otherwise, report the error.
if result.error_message is not None:
# Log the failure.
metric_queue.authentication_count.Inc(
labelvalues=[result.kind, False]
)
break
return func(*args, **kwargs)
return wrapper
return processor
if pass_result:
kwargs["auth_result"] = result
return func(*args, **kwargs)
return wrapper
return processor
process_oauth = _auth_decorator(handlers=[validate_bearer_auth, validate_session_cookie])
process_oauth = _auth_decorator(
handlers=[validate_bearer_auth, validate_session_cookie]
)
process_auth = _auth_decorator(handlers=[validate_signed_grant, validate_basic_auth])
process_auth_or_cookie = _auth_decorator(handlers=[validate_basic_auth, validate_session_cookie])
process_auth_or_cookie = _auth_decorator(
handlers=[validate_basic_auth, validate_session_cookie]
)
process_basic_auth = _auth_decorator(handlers=[validate_basic_auth], pass_result=True)
process_basic_auth_no_pass = _auth_decorator(handlers=[validate_basic_auth])
def require_session_login(func):
""" Decorates a function and ensures that a valid session cookie exists or a 401 is raised. If
""" Decorates a function and ensures that a valid session cookie exists or a 401 is raised. If
a valid session cookie does exist, the authenticated user and identity are also set.
"""
@wraps(func)
def wrapper(*args, **kwargs):
result = validate_session_cookie()
if result.has_nonrobot_user:
result.apply_to_context()
metric_queue.authentication_count.Inc(labelvalues=[result.kind, True])
return func(*args, **kwargs)
elif not result.missing:
metric_queue.authentication_count.Inc(labelvalues=[result.kind, False])
abort(401, message='Method requires login and no valid login could be loaded.')
return wrapper
@wraps(func)
def wrapper(*args, **kwargs):
result = validate_session_cookie()
if result.has_nonrobot_user:
result.apply_to_context()
metric_queue.authentication_count.Inc(labelvalues=[result.kind, True])
return func(*args, **kwargs)
elif not result.missing:
metric_queue.authentication_count.Inc(labelvalues=[result.kind, False])
abort(401, message="Method requires login and no valid login could be loaded.")
return wrapper
def extract_namespace_repo_from_session(func):
""" Extracts the namespace and repository name from the current session (which must exist)
""" Extracts the namespace and repository name from the current session (which must exist)
and passes them into the decorated function as the first and second arguments. If the
session doesn't exist or does not contain these arguments, a 400 error is raised.
"""
@wraps(func)
def wrapper(*args, **kwargs):
if 'namespace' not in session or 'repository' not in session:
logger.error('Unable to load namespace or repository from session: %s', session)
abort(400, message='Missing namespace in request')
return func(session['namespace'], session['repository'], *args, **kwargs)
return wrapper
@wraps(func)
def wrapper(*args, **kwargs):
if "namespace" not in session or "repository" not in session:
logger.error(
"Unable to load namespace or repository from session: %s", session
)
abort(400, message="Missing namespace in request")
return func(session["namespace"], session["repository"], *args, **kwargs)
return wrapper
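All of these decorators follow the same contract: the handlers run first, the view runs regardless, and the view inspects the auth context (or, with pass_result=True, the ValidateResult itself). A minimal sketch of process_basic_auth on a Flask route; the route path and response shape are illustrative:

from flask import jsonify

@app.route("/api/whoami")
@process_basic_auth
def whoami(auth_result):
    # process_basic_auth sets pass_result=True, so the result arrives as a kwarg.
    if auth_result.auth_valid:
        return jsonify(kind=str(auth_result.kind))
    return jsonify(error=auth_result.error_message), 401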


@@ -8,41 +8,47 @@ from data import model
logger = logging.getLogger(__name__)
def validate_bearer_auth(auth_header):
""" Validates an OAuth token found inside a basic auth `Bearer` token, returning whether it
""" Validates an OAuth token found inside a basic auth `Bearer` token, returning whether it
points to a valid OAuth token.
"""
if not auth_header:
return ValidateResult(AuthKind.oauth, missing=True)
if not auth_header:
return ValidateResult(AuthKind.oauth, missing=True)
normalized = [part.strip() for part in auth_header.split(' ') if part]
if normalized[0].lower() != 'bearer' or len(normalized) != 2:
logger.debug('Got invalid bearer token format: %s', auth_header)
return ValidateResult(AuthKind.oauth, missing=True)
normalized = [part.strip() for part in auth_header.split(" ") if part]
if normalized[0].lower() != "bearer" or len(normalized) != 2:
logger.debug("Got invalid bearer token format: %s", auth_header)
return ValidateResult(AuthKind.oauth, missing=True)
(_, oauth_token) = normalized
return validate_oauth_token(oauth_token)
(_, oauth_token) = normalized
return validate_oauth_token(oauth_token)
def validate_oauth_token(token):
""" Validates the specified OAuth token, returning whether it points to a valid OAuth token.
""" Validates the specified OAuth token, returning whether it points to a valid OAuth token.
"""
validated = model.oauth.validate_access_token(token)
if not validated:
logger.warning('OAuth access token could not be validated: %s', token)
return ValidateResult(AuthKind.oauth,
error_message='OAuth access token could not be validated')
validated = model.oauth.validate_access_token(token)
if not validated:
logger.warning("OAuth access token could not be validated: %s", token)
return ValidateResult(
AuthKind.oauth, error_message="OAuth access token could not be validated"
)
if validated.expires_at <= datetime.utcnow():
logger.warning('OAuth access with an expired token: %s', token)
return ValidateResult(AuthKind.oauth, error_message='OAuth access token has expired')
if validated.expires_at <= datetime.utcnow():
logger.warning("OAuth access with an expired token: %s", token)
return ValidateResult(
AuthKind.oauth, error_message="OAuth access token has expired"
)
# Don't allow disabled users to login.
if not validated.authorized_user.enabled:
return ValidateResult(AuthKind.oauth,
error_message='Granter of the oauth access token is disabled')
# Don't allow disabled users to login.
if not validated.authorized_user.enabled:
return ValidateResult(
AuthKind.oauth,
error_message="Granter of the oauth access token is disabled",
)
# We have a valid token
scope_set = scopes_from_scope_string(validated.scope)
logger.debug('Successfully validated oauth access token with scope: %s', scope_set)
return ValidateResult(AuthKind.oauth, oauthtoken=validated)
# We have a valid token
scope_set = scopes_from_scope_string(validated.scope)
logger.debug("Successfully validated oauth access token with scope: %s", scope_set)
return ValidateResult(AuthKind.oauth, oauthtoken=validated)
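validate_bearer_auth only strips the Bearer prefix; every real check lives in validate_oauth_token. A few illustrative header shapes and the branch each one hits:

validate_bearer_auth(None)               # -> missing=True (no header at all)
validate_bearer_auth("Basic abc123")     # -> missing=True (wrong scheme)
validate_bearer_auth("Bearer  sometok")  # doubled spaces are filtered out, then
                                         # "sometok" goes to validate_oauth_token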


@@ -14,351 +14,399 @@ from data import model
logger = logging.getLogger(__name__)
_ResourceNeed = namedtuple('resource', ['type', 'namespace', 'name', 'role'])
_RepositoryNeed = partial(_ResourceNeed, 'repository')
_NamespaceWideNeed = namedtuple('namespacewide', ['type', 'namespace', 'role'])
_OrganizationNeed = partial(_NamespaceWideNeed, 'organization')
_OrganizationRepoNeed = partial(_NamespaceWideNeed, 'organizationrepo')
_TeamTypeNeed = namedtuple('teamwideneed', ['type', 'orgname', 'teamname', 'role'])
_TeamNeed = partial(_TeamTypeNeed, 'orgteam')
_UserTypeNeed = namedtuple('userspecificneed', ['type', 'username', 'role'])
_UserNeed = partial(_UserTypeNeed, 'user')
_SuperUserNeed = partial(namedtuple('superuserneed', ['type']), 'superuser')
_ResourceNeed = namedtuple("resource", ["type", "namespace", "name", "role"])
_RepositoryNeed = partial(_ResourceNeed, "repository")
_NamespaceWideNeed = namedtuple("namespacewide", ["type", "namespace", "role"])
_OrganizationNeed = partial(_NamespaceWideNeed, "organization")
_OrganizationRepoNeed = partial(_NamespaceWideNeed, "organizationrepo")
_TeamTypeNeed = namedtuple("teamwideneed", ["type", "orgname", "teamname", "role"])
_TeamNeed = partial(_TeamTypeNeed, "orgteam")
_UserTypeNeed = namedtuple("userspecificneed", ["type", "username", "role"])
_UserNeed = partial(_UserTypeNeed, "user")
_SuperUserNeed = partial(namedtuple("superuserneed", ["type"]), "superuser")
REPO_ROLES = [None, 'read', 'write', 'admin']
TEAM_ROLES = [None, 'member', 'creator', 'admin']
USER_ROLES = [None, 'read', 'admin']
REPO_ROLES = [None, "read", "write", "admin"]
TEAM_ROLES = [None, "member", "creator", "admin"]
USER_ROLES = [None, "read", "admin"]
TEAM_ORGWIDE_REPO_ROLES = {
'admin': 'admin',
'creator': None,
'member': None,
}
TEAM_ORGWIDE_REPO_ROLES = {"admin": "admin", "creator": None, "member": None}
SCOPE_MAX_REPO_ROLES = defaultdict(lambda: None)
SCOPE_MAX_REPO_ROLES.update({
scopes.READ_REPO: 'read',
scopes.WRITE_REPO: 'write',
scopes.ADMIN_REPO: 'admin',
scopes.DIRECT_LOGIN: 'admin',
})
SCOPE_MAX_REPO_ROLES.update(
{
scopes.READ_REPO: "read",
scopes.WRITE_REPO: "write",
scopes.ADMIN_REPO: "admin",
scopes.DIRECT_LOGIN: "admin",
}
)
SCOPE_MAX_TEAM_ROLES = defaultdict(lambda: None)
SCOPE_MAX_TEAM_ROLES.update({
scopes.CREATE_REPO: 'creator',
scopes.DIRECT_LOGIN: 'admin',
scopes.ORG_ADMIN: 'admin',
})
SCOPE_MAX_TEAM_ROLES.update(
{
scopes.CREATE_REPO: "creator",
scopes.DIRECT_LOGIN: "admin",
scopes.ORG_ADMIN: "admin",
}
)
SCOPE_MAX_USER_ROLES = defaultdict(lambda: None)
SCOPE_MAX_USER_ROLES.update({
scopes.READ_USER: 'read',
scopes.DIRECT_LOGIN: 'admin',
scopes.ADMIN_USER: 'admin',
})
SCOPE_MAX_USER_ROLES.update(
{scopes.READ_USER: "read", scopes.DIRECT_LOGIN: "admin", scopes.ADMIN_USER: "admin"}
)
def repository_read_grant(namespace, repository):
return _RepositoryNeed(namespace, repository, 'read')
return _RepositoryNeed(namespace, repository, "read")
def repository_write_grant(namespace, repository):
return _RepositoryNeed(namespace, repository, 'write')
return _RepositoryNeed(namespace, repository, "write")
def repository_admin_grant(namespace, repository):
return _RepositoryNeed(namespace, repository, 'admin')
return _RepositoryNeed(namespace, repository, "admin")
class QuayDeferredPermissionUser(Identity):
def __init__(self, uuid, auth_type, auth_scopes, user=None):
super(QuayDeferredPermissionUser, self).__init__(uuid, auth_type)
def __init__(self, uuid, auth_type, auth_scopes, user=None):
super(QuayDeferredPermissionUser, self).__init__(uuid, auth_type)
self._namespace_wide_loaded = set()
self._repositories_loaded = set()
self._personal_loaded = False
self._namespace_wide_loaded = set()
self._repositories_loaded = set()
self._personal_loaded = False
self._scope_set = auth_scopes
self._user_object = user
self._scope_set = auth_scopes
self._user_object = user
@staticmethod
def for_id(uuid, auth_scopes=None):
auth_scopes = auth_scopes if auth_scopes is not None else {scopes.DIRECT_LOGIN}
return QuayDeferredPermissionUser(uuid, 'user_uuid', auth_scopes)
@staticmethod
def for_id(uuid, auth_scopes=None):
auth_scopes = auth_scopes if auth_scopes is not None else {scopes.DIRECT_LOGIN}
return QuayDeferredPermissionUser(uuid, "user_uuid", auth_scopes)
@staticmethod
def for_user(user, auth_scopes=None):
auth_scopes = auth_scopes if auth_scopes is not None else {scopes.DIRECT_LOGIN}
return QuayDeferredPermissionUser(user.uuid, 'user_uuid', auth_scopes, user=user)
@staticmethod
def for_user(user, auth_scopes=None):
auth_scopes = auth_scopes if auth_scopes is not None else {scopes.DIRECT_LOGIN}
return QuayDeferredPermissionUser(
user.uuid, "user_uuid", auth_scopes, user=user
)
def _translate_role_for_scopes(self, cardinality, max_roles, role):
if self._scope_set is None:
return role
def _translate_role_for_scopes(self, cardinality, max_roles, role):
if self._scope_set is None:
return role
max_for_scopes = max({cardinality.index(max_roles[scope]) for scope in self._scope_set})
max_for_scopes = max(
{cardinality.index(max_roles[scope]) for scope in self._scope_set}
)
if max_for_scopes < cardinality.index(role):
logger.debug('Translated permission %s -> %s', role, cardinality[max_for_scopes])
return cardinality[max_for_scopes]
else:
return role
if max_for_scopes < cardinality.index(role):
logger.debug(
"Translated permission %s -> %s", role, cardinality[max_for_scopes]
)
return cardinality[max_for_scopes]
else:
return role
def _team_role_for_scopes(self, role):
return self._translate_role_for_scopes(TEAM_ROLES, SCOPE_MAX_TEAM_ROLES, role)
def _team_role_for_scopes(self, role):
return self._translate_role_for_scopes(TEAM_ROLES, SCOPE_MAX_TEAM_ROLES, role)
def _repo_role_for_scopes(self, role):
return self._translate_role_for_scopes(REPO_ROLES, SCOPE_MAX_REPO_ROLES, role)
def _repo_role_for_scopes(self, role):
return self._translate_role_for_scopes(REPO_ROLES, SCOPE_MAX_REPO_ROLES, role)
def _user_role_for_scopes(self, role):
return self._translate_role_for_scopes(USER_ROLES, SCOPE_MAX_USER_ROLES, role)
def _user_role_for_scopes(self, role):
return self._translate_role_for_scopes(USER_ROLES, SCOPE_MAX_USER_ROLES, role)
def _populate_user_provides(self, user_object):
""" Populates the provides that naturally apply to a user, such as being the admin of
def _populate_user_provides(self, user_object):
""" Populates the provides that naturally apply to a user, such as being the admin of
their own namespace.
"""
# Add the user specific permissions, only for non-oauth permission
user_grant = _UserNeed(user_object.username, self._user_role_for_scopes('admin'))
logger.debug('User permission: {0}'.format(user_grant))
self.provides.add(user_grant)
# Add the user specific permissions, only for non-oauth permission
user_grant = _UserNeed(
user_object.username, self._user_role_for_scopes("admin")
)
logger.debug("User permission: {0}".format(user_grant))
self.provides.add(user_grant)
# Every user is the admin of their own 'org'
user_namespace = _OrganizationNeed(user_object.username, self._team_role_for_scopes('admin'))
logger.debug('User namespace permission: {0}'.format(user_namespace))
self.provides.add(user_namespace)
# Every user is the admin of their own 'org'
user_namespace = _OrganizationNeed(
user_object.username, self._team_role_for_scopes("admin")
)
logger.debug("User namespace permission: {0}".format(user_namespace))
self.provides.add(user_namespace)
# Org repo roles can differ for scopes
user_repos = _OrganizationRepoNeed(user_object.username, self._repo_role_for_scopes('admin'))
logger.debug('User namespace repo permission: {0}'.format(user_repos))
self.provides.add(user_repos)
# Org repo roles can differ for scopes
user_repos = _OrganizationRepoNeed(
user_object.username, self._repo_role_for_scopes("admin")
)
logger.debug("User namespace repo permission: {0}".format(user_repos))
self.provides.add(user_repos)
if ((scopes.SUPERUSER in self._scope_set or scopes.DIRECT_LOGIN in self._scope_set) and
superusers.is_superuser(user_object.username)):
logger.debug('Adding superuser to user: %s', user_object.username)
self.provides.add(_SuperUserNeed())
if (
scopes.SUPERUSER in self._scope_set
or scopes.DIRECT_LOGIN in self._scope_set
) and superusers.is_superuser(user_object.username):
logger.debug("Adding superuser to user: %s", user_object.username)
self.provides.add(_SuperUserNeed())
def _populate_namespace_wide_provides(self, user_object, namespace_filter):
""" Populates the namespace-wide provides for a particular user under a particular namespace.
def _populate_namespace_wide_provides(self, user_object, namespace_filter):
""" Populates the namespace-wide provides for a particular user under a particular namespace.
This method does *not* add any provides for specific repositories.
"""
for team in model.permission.get_org_wide_permissions(user_object, org_filter=namespace_filter):
team_org_grant = _OrganizationNeed(team.organization.username,
self._team_role_for_scopes(team.role.name))
logger.debug('Organization team added permission: {0}'.format(team_org_grant))
self.provides.add(team_org_grant)
for team in model.permission.get_org_wide_permissions(
user_object, org_filter=namespace_filter
):
team_org_grant = _OrganizationNeed(
team.organization.username, self._team_role_for_scopes(team.role.name)
)
logger.debug(
"Organization team added permission: {0}".format(team_org_grant)
)
self.provides.add(team_org_grant)
team_repo_role = TEAM_ORGWIDE_REPO_ROLES[team.role.name]
org_repo_grant = _OrganizationRepoNeed(team.organization.username,
self._repo_role_for_scopes(team_repo_role))
logger.debug('Organization team added repo permission: {0}'.format(org_repo_grant))
self.provides.add(org_repo_grant)
team_repo_role = TEAM_ORGWIDE_REPO_ROLES[team.role.name]
org_repo_grant = _OrganizationRepoNeed(
team.organization.username, self._repo_role_for_scopes(team_repo_role)
)
logger.debug(
"Organization team added repo permission: {0}".format(org_repo_grant)
)
self.provides.add(org_repo_grant)
team_grant = _TeamNeed(team.organization.username, team.name,
self._team_role_for_scopes(team.role.name))
logger.debug('Team added permission: {0}'.format(team_grant))
self.provides.add(team_grant)
team_grant = _TeamNeed(
team.organization.username,
team.name,
self._team_role_for_scopes(team.role.name),
)
logger.debug("Team added permission: {0}".format(team_grant))
self.provides.add(team_grant)
def _populate_repository_provides(self, user_object, namespace_filter, repository_name):
""" Populates the repository-specific provides for a particular user and repository. """
def _populate_repository_provides(
self, user_object, namespace_filter, repository_name
):
""" Populates the repository-specific provides for a particular user and repository. """
if namespace_filter and repository_name:
permissions = model.permission.get_user_repository_permissions(user_object, namespace_filter,
repository_name)
else:
permissions = model.permission.get_all_user_repository_permissions(user_object)
if namespace_filter and repository_name:
permissions = model.permission.get_user_repository_permissions(
user_object, namespace_filter, repository_name
)
else:
permissions = model.permission.get_all_user_repository_permissions(
user_object
)
for perm in permissions:
repo_grant = _RepositoryNeed(perm.repository.namespace_user.username, perm.repository.name,
self._repo_role_for_scopes(perm.role.name))
logger.debug('User added permission: {0}'.format(repo_grant))
self.provides.add(repo_grant)
for perm in permissions:
repo_grant = _RepositoryNeed(
perm.repository.namespace_user.username,
perm.repository.name,
self._repo_role_for_scopes(perm.role.name),
)
logger.debug("User added permission: {0}".format(repo_grant))
self.provides.add(repo_grant)
def can(self, permission):
logger.debug('Loading user permissions after deferring for: %s', self.id)
user_object = self._user_object or model.user.get_user_by_uuid(self.id)
if user_object is None:
return super(QuayDeferredPermissionUser, self).can(permission)
def can(self, permission):
logger.debug("Loading user permissions after deferring for: %s", self.id)
user_object = self._user_object or model.user.get_user_by_uuid(self.id)
if user_object is None:
return super(QuayDeferredPermissionUser, self).can(permission)
# Add the user-specific provides.
if not self._personal_loaded:
self._populate_user_provides(user_object)
self._personal_loaded = True
# Add the user-specific provides.
if not self._personal_loaded:
self._populate_user_provides(user_object)
self._personal_loaded = True
# If we now have permission, no need to load any more permissions.
if super(QuayDeferredPermissionUser, self).can(permission):
return super(QuayDeferredPermissionUser, self).can(permission)
# If we now have permission, no need to load any more permissions.
if super(QuayDeferredPermissionUser, self).can(permission):
return super(QuayDeferredPermissionUser, self).can(permission)
# Check for namespace and/or repository permissions.
perm_namespace = permission.namespace
perm_repo_name = permission.repo_name
perm_repository = None
# Check for namespace and/or repository permissions.
perm_namespace = permission.namespace
perm_repo_name = permission.repo_name
perm_repository = None
if perm_namespace and perm_repo_name:
perm_repository = '%s/%s' % (perm_namespace, perm_repo_name)
if perm_namespace and perm_repo_name:
perm_repository = "%s/%s" % (perm_namespace, perm_repo_name)
if not perm_namespace and not perm_repo_name:
# Nothing more to load, so just check directly.
return super(QuayDeferredPermissionUser, self).can(permission)
if not perm_namespace and not perm_repo_name:
# Nothing more to load, so just check directly.
return super(QuayDeferredPermissionUser, self).can(permission)
# Lazy-load the repository-specific permissions.
if perm_repository and perm_repository not in self._repositories_loaded:
self._populate_repository_provides(user_object, perm_namespace, perm_repo_name)
self._repositories_loaded.add(perm_repository)
# Lazy-load the repository-specific permissions.
if perm_repository and perm_repository not in self._repositories_loaded:
self._populate_repository_provides(
user_object, perm_namespace, perm_repo_name
)
self._repositories_loaded.add(perm_repository)
# If we now have permission, no need to load any more permissions.
if super(QuayDeferredPermissionUser, self).can(permission):
return super(QuayDeferredPermissionUser, self).can(permission)
# Lazy-load the namespace-wide-only permissions.
if perm_namespace and perm_namespace not in self._namespace_wide_loaded:
self._populate_namespace_wide_provides(user_object, perm_namespace)
self._namespace_wide_loaded.add(perm_namespace)
# If we now have permission, no need to load any more permissions.
if super(QuayDeferredPermissionUser, self).can(permission):
return super(QuayDeferredPermissionUser, self).can(permission)
# Lazy-load the namespace-wide-only permissions.
if perm_namespace and perm_namespace not in self._namespace_wide_loaded:
self._populate_namespace_wide_provides(user_object, perm_namespace)
self._namespace_wide_loaded.add(perm_namespace)
return super(QuayDeferredPermissionUser, self).can(permission)
class QuayPermission(Permission):
""" Base for all permissions in Quay. """
namespace = None
repo_name = None
""" Base for all permissions in Quay. """
namespace = None
repo_name = None
class ModifyRepositoryPermission(QuayPermission):
def __init__(self, namespace, name):
admin_need = _RepositoryNeed(namespace, name, 'admin')
write_need = _RepositoryNeed(namespace, name, 'write')
org_admin_need = _OrganizationRepoNeed(namespace, 'admin')
org_write_need = _OrganizationRepoNeed(namespace, 'write')
def __init__(self, namespace, name):
admin_need = _RepositoryNeed(namespace, name, "admin")
write_need = _RepositoryNeed(namespace, name, "write")
org_admin_need = _OrganizationRepoNeed(namespace, "admin")
org_write_need = _OrganizationRepoNeed(namespace, "write")
self.namespace = namespace
self.repo_name = name
self.namespace = namespace
self.repo_name = name
super(ModifyRepositoryPermission, self).__init__(admin_need, write_need, org_admin_need,
org_write_need)
super(ModifyRepositoryPermission, self).__init__(
admin_need, write_need, org_admin_need, org_write_need
)
class ReadRepositoryPermission(QuayPermission):
def __init__(self, namespace, name):
admin_need = _RepositoryNeed(namespace, name, 'admin')
write_need = _RepositoryNeed(namespace, name, 'write')
read_need = _RepositoryNeed(namespace, name, 'read')
org_admin_need = _OrganizationRepoNeed(namespace, 'admin')
org_write_need = _OrganizationRepoNeed(namespace, 'write')
org_read_need = _OrganizationRepoNeed(namespace, 'read')
def __init__(self, namespace, name):
admin_need = _RepositoryNeed(namespace, name, "admin")
write_need = _RepositoryNeed(namespace, name, "write")
read_need = _RepositoryNeed(namespace, name, "read")
org_admin_need = _OrganizationRepoNeed(namespace, "admin")
org_write_need = _OrganizationRepoNeed(namespace, "write")
org_read_need = _OrganizationRepoNeed(namespace, "read")
self.namespace = namespace
self.repo_name = name
self.namespace = namespace
self.repo_name = name
super(ReadRepositoryPermission, self).__init__(admin_need, write_need, read_need,
org_admin_need, org_read_need, org_write_need)
super(ReadRepositoryPermission, self).__init__(
admin_need,
write_need,
read_need,
org_admin_need,
org_read_need,
org_write_need,
)
class AdministerRepositoryPermission(QuayPermission):
def __init__(self, namespace, name):
admin_need = _RepositoryNeed(namespace, name, 'admin')
org_admin_need = _OrganizationRepoNeed(namespace, 'admin')
def __init__(self, namespace, name):
admin_need = _RepositoryNeed(namespace, name, "admin")
org_admin_need = _OrganizationRepoNeed(namespace, "admin")
self.namespace = namespace
self.repo_name = name
self.namespace = namespace
self.repo_name = name
super(AdministerRepositoryPermission, self).__init__(admin_need,
org_admin_need)
super(AdministerRepositoryPermission, self).__init__(admin_need, org_admin_need)
class CreateRepositoryPermission(QuayPermission):
def __init__(self, namespace):
admin_org = _OrganizationNeed(namespace, 'admin')
create_repo_org = _OrganizationNeed(namespace, 'creator')
def __init__(self, namespace):
admin_org = _OrganizationNeed(namespace, "admin")
create_repo_org = _OrganizationNeed(namespace, "creator")
self.namespace = namespace
self.namespace = namespace
super(CreateRepositoryPermission, self).__init__(admin_org, create_repo_org)
super(CreateRepositoryPermission, self).__init__(admin_org,
create_repo_org)
class SuperUserPermission(QuayPermission):
def __init__(self):
need = _SuperUserNeed()
super(SuperUserPermission, self).__init__(need)
def __init__(self):
need = _SuperUserNeed()
super(SuperUserPermission, self).__init__(need)
class UserAdminPermission(QuayPermission):
def __init__(self, username):
user_admin = _UserNeed(username, 'admin')
super(UserAdminPermission, self).__init__(user_admin)
def __init__(self, username):
user_admin = _UserNeed(username, "admin")
super(UserAdminPermission, self).__init__(user_admin)
class UserReadPermission(QuayPermission):
def __init__(self, username):
user_admin = _UserNeed(username, 'admin')
user_read = _UserNeed(username, 'read')
super(UserReadPermission, self).__init__(user_read, user_admin)
def __init__(self, username):
user_admin = _UserNeed(username, "admin")
user_read = _UserNeed(username, "read")
super(UserReadPermission, self).__init__(user_read, user_admin)
class AdministerOrganizationPermission(QuayPermission):
def __init__(self, org_name):
admin_org = _OrganizationNeed(org_name, 'admin')
def __init__(self, org_name):
admin_org = _OrganizationNeed(org_name, "admin")
self.namespace = org_name
self.namespace = org_name
super(AdministerOrganizationPermission, self).__init__(admin_org)
super(AdministerOrganizationPermission, self).__init__(admin_org)
class OrganizationMemberPermission(QuayPermission):
def __init__(self, org_name):
admin_org = _OrganizationNeed(org_name, 'admin')
repo_creator_org = _OrganizationNeed(org_name, 'creator')
org_member = _OrganizationNeed(org_name, 'member')
def __init__(self, org_name):
admin_org = _OrganizationNeed(org_name, "admin")
repo_creator_org = _OrganizationNeed(org_name, "creator")
org_member = _OrganizationNeed(org_name, "member")
self.namespace = org_name
self.namespace = org_name
super(OrganizationMemberPermission, self).__init__(admin_org, org_member,
repo_creator_org)
super(OrganizationMemberPermission, self).__init__(
admin_org, org_member, repo_creator_org
)
class ViewTeamPermission(QuayPermission):
def __init__(self, org_name, team_name):
team_admin = _TeamNeed(org_name, team_name, 'admin')
team_creator = _TeamNeed(org_name, team_name, 'creator')
team_member = _TeamNeed(org_name, team_name, 'member')
admin_org = _OrganizationNeed(org_name, 'admin')
def __init__(self, org_name, team_name):
team_admin = _TeamNeed(org_name, team_name, "admin")
team_creator = _TeamNeed(org_name, team_name, "creator")
team_member = _TeamNeed(org_name, team_name, "member")
admin_org = _OrganizationNeed(org_name, "admin")
self.namespace = org_name
self.namespace = org_name
super(ViewTeamPermission, self).__init__(team_admin, team_creator,
team_member, admin_org)
super(ViewTeamPermission, self).__init__(
team_admin, team_creator, team_member, admin_org
)
class AlwaysFailPermission(QuayPermission):
def can(self):
return False
def can(self):
return False
@identity_loaded.connect_via(app)
def on_identity_loaded(sender, identity):
logger.debug('Identity loaded: %s', identity)
# We have verified an identity, load in all of the permissions
logger.debug("Identity loaded: %s" % identity)
# We have verified an identity, load in all of the permissions
if isinstance(identity, QuayDeferredPermissionUser):
logger.debug('Deferring permissions for user with uuid: %s', identity.id)
if isinstance(identity, QuayDeferredPermissionUser):
logger.debug("Deferring permissions for user with uuid: %s", identity.id)
elif identity.auth_type == 'user_uuid':
logger.debug('Switching username permission to deferred object with uuid: %s', identity.id)
switch_to_deferred = QuayDeferredPermissionUser.for_id(identity.id)
identity_changed.send(app, identity=switch_to_deferred)
elif identity.auth_type == "user_uuid":
logger.debug(
"Switching username permission to deferred object with uuid: %s",
identity.id,
)
switch_to_deferred = QuayDeferredPermissionUser.for_id(identity.id)
identity_changed.send(app, identity=switch_to_deferred)
elif identity.auth_type == 'token':
logger.debug('Loading permissions for token: %s', identity.id)
token_data = model.token.load_token_data(identity.id)
elif identity.auth_type == "token":
logger.debug("Loading permissions for token: %s", identity.id)
token_data = model.token.load_token_data(identity.id)
repo_grant = _RepositoryNeed(token_data.repository.namespace_user.username,
token_data.repository.name,
token_data.role.name)
logger.debug('Delegate token added permission: %s', repo_grant)
identity.provides.add(repo_grant)
repo_grant = _RepositoryNeed(
token_data.repository.namespace_user.username,
token_data.repository.name,
token_data.role.name,
)
logger.debug("Delegate token added permission: %s", repo_grant)
identity.provides.add(repo_grant)
elif identity.auth_type == 'signed_grant' or identity.auth_type == 'signed_jwt':
logger.debug('Loaded %s identity for: %s', identity.auth_type, identity.id)
elif identity.auth_type == "signed_grant" or identity.auth_type == "signed_jwt":
logger.debug("Loaded %s identity for: %s", identity.auth_type, identity.id)
else:
logger.error('Unknown identity auth type: %s', identity.auth_type)
else:
logger.error("Unknown identity auth type: %s", identity.auth_type)


@@ -9,156 +9,166 @@ from flask_principal import identity_changed, Identity
from app import app, get_app_url, instance_keys, metric_queue
from auth.auth_context import set_authenticated_context
from auth.auth_context_type import SignedAuthContext
from auth.permissions import repository_read_grant, repository_write_grant, repository_admin_grant
from auth.permissions import (
repository_read_grant,
repository_write_grant,
repository_admin_grant,
)
from util.http import abort
from util.names import parse_namespace_repository
from util.security.registry_jwt import (ANONYMOUS_SUB, decode_bearer_header,
InvalidBearerTokenException)
from util.security.registry_jwt import (
ANONYMOUS_SUB,
decode_bearer_header,
InvalidBearerTokenException,
)
logger = logging.getLogger(__name__)
ACCESS_SCHEMA = {
'type': 'array',
'description': 'List of access granted to the subject',
'items': {
'type': 'object',
'required': [
'type',
'name',
'actions',
],
'properties': {
'type': {
'type': 'string',
'description': 'We only allow repository permissions',
'enum': [
'repository',
],
},
'name': {
'type': 'string',
'description': 'The name of the repository for which we are receiving access'
},
'actions': {
'type': 'array',
'description': 'List of specific verbs which can be performed against repository',
'items': {
'type': 'string',
'enum': [
'push',
'pull',
'*',
],
"type": "array",
"description": "List of access granted to the subject",
"items": {
"type": "object",
"required": ["type", "name", "actions"],
"properties": {
"type": {
"type": "string",
"description": "We only allow repository permissions",
"enum": ["repository"],
},
"name": {
"type": "string",
"description": "The name of the repository for which we are receiving access",
},
"actions": {
"type": "array",
"description": "List of specific verbs which can be performed against repository",
"items": {"type": "string", "enum": ["push", "pull", "*"]},
},
},
},
},
},
}
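ACCESS_SCHEMA is enforced below via jsonschema's validate() inside identity_from_bearer_token. A grant that passes it looks like this (the repository name is a placeholder):

from jsonschema import validate  # the same call the module uses below

access = [
    {
        "type": "repository",
        "name": "someorg/somerepo",
        "actions": ["pull"],
    }
]
validate(access, ACCESS_SCHEMA)  # raises ValidationError if malformed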
class InvalidJWTException(Exception):
pass
pass
def get_auth_headers(repository=None, scopes=None):
""" Returns a dictionary of headers for auth responses. """
headers = {}
realm_auth_path = url_for('v2.generate_registry_jwt')
authenticate = 'Bearer realm="{0}{1}",service="{2}"'.format(get_app_url(),
realm_auth_path,
app.config['SERVER_HOSTNAME'])
if repository:
scopes_string = "repository:{0}".format(repository)
if scopes:
scopes_string += ':' + ','.join(scopes)
""" Returns a dictionary of headers for auth responses. """
headers = {}
realm_auth_path = url_for("v2.generate_registry_jwt")
authenticate = 'Bearer realm="{0}{1}",service="{2}"'.format(
get_app_url(), realm_auth_path, app.config["SERVER_HOSTNAME"]
)
if repository:
scopes_string = "repository:{0}".format(repository)
if scopes:
scopes_string += ":" + ",".join(scopes)
authenticate += ',scope="{0}"'.format(scopes_string)
authenticate += ',scope="{0}"'.format(scopes_string)
headers['WWW-Authenticate'] = authenticate
headers['Docker-Distribution-API-Version'] = 'registry/2.0'
return headers
headers["WWW-Authenticate"] = authenticate
headers["Docker-Distribution-API-Version"] = "registry/2.0"
return headers
def identity_from_bearer_token(bearer_header):
""" Process a bearer header and return the loaded identity, or raise InvalidJWTException if an
""" Process a bearer header and return the loaded identity, or raise InvalidJWTException if an
identity could not be loaded. Expects tokens and grants in the format of the Docker registry
v2 auth spec: https://docs.docker.com/registry/spec/auth/token/
"""
logger.debug('Validating auth header: %s', bearer_header)
logger.debug("Validating auth header: %s", bearer_header)
try:
payload = decode_bearer_header(bearer_header, instance_keys, app.config,
metric_queue=metric_queue)
except InvalidBearerTokenException as bte:
logger.exception('Invalid bearer token: %s', bte)
raise InvalidJWTException(bte)
loaded_identity = Identity(payload['sub'], 'signed_jwt')
# Process the grants from the payload
if 'access' in payload:
try:
validate(payload['access'], ACCESS_SCHEMA)
except ValidationError:
logger.exception('We should not be minting invalid credentials')
raise InvalidJWTException('Token contained invalid or malformed access grants')
payload = decode_bearer_header(
bearer_header, instance_keys, app.config, metric_queue=metric_queue
)
except InvalidBearerTokenException as bte:
logger.exception("Invalid bearer token: %s", bte)
raise InvalidJWTException(bte)
lib_namespace = app.config['LIBRARY_NAMESPACE']
for grant in payload['access']:
namespace, repo_name = parse_namespace_repository(grant['name'], lib_namespace)
loaded_identity = Identity(payload["sub"], "signed_jwt")
if '*' in grant['actions']:
loaded_identity.provides.add(repository_admin_grant(namespace, repo_name))
elif 'push' in grant['actions']:
loaded_identity.provides.add(repository_write_grant(namespace, repo_name))
elif 'pull' in grant['actions']:
loaded_identity.provides.add(repository_read_grant(namespace, repo_name))
# Process the grants from the payload
if "access" in payload:
try:
validate(payload["access"], ACCESS_SCHEMA)
except ValidationError:
logger.exception("We should not be minting invalid credentials")
raise InvalidJWTException(
"Token contained invalid or malformed access grants"
)
default_context = {
'kind': 'anonymous'
}
lib_namespace = app.config["LIBRARY_NAMESPACE"]
for grant in payload["access"]:
namespace, repo_name = parse_namespace_repository(
grant["name"], lib_namespace
)
if payload['sub'] != ANONYMOUS_SUB:
default_context = {
'kind': 'user',
'user': payload['sub'],
}
if "*" in grant["actions"]:
loaded_identity.provides.add(
repository_admin_grant(namespace, repo_name)
)
elif "push" in grant["actions"]:
loaded_identity.provides.add(
repository_write_grant(namespace, repo_name)
)
elif "pull" in grant["actions"]:
loaded_identity.provides.add(
repository_read_grant(namespace, repo_name)
)
return loaded_identity, payload.get('context', default_context)
default_context = {"kind": "anonymous"}
if payload["sub"] != ANONYMOUS_SUB:
default_context = {"kind": "user", "user": payload["sub"]}
return loaded_identity, payload.get("context", default_context)
def process_registry_jwt_auth(scopes=None):
""" Processes the registry JWT auth token found in the authorization header. If none found,
""" Processes the registry JWT auth token found in the authorization header. If none found,
no error is returned. If an invalid token is found, raises a 401.
"""
def inner(func):
@wraps(func)
def wrapper(*args, **kwargs):
logger.debug('Called with params: %s, %s', args, kwargs)
auth = request.headers.get('authorization', '').strip()
if auth:
try:
extracted_identity, context_dict = identity_from_bearer_token(auth)
identity_changed.send(app, identity=extracted_identity)
logger.debug('Identity changed to %s', extracted_identity.id)
auth_context = SignedAuthContext.build_from_signed_dict(context_dict)
if auth_context is not None:
logger.debug('Auth context set to %s', auth_context.signed_data)
set_authenticated_context(auth_context)
def inner(func):
@wraps(func)
def wrapper(*args, **kwargs):
logger.debug("Called with params: %s, %s", args, kwargs)
auth = request.headers.get("authorization", "").strip()
if auth:
try:
extracted_identity, context_dict = identity_from_bearer_token(auth)
identity_changed.send(app, identity=extracted_identity)
logger.debug("Identity changed to %s", extracted_identity.id)
except InvalidJWTException as ije:
repository = None
if 'namespace_name' in kwargs and 'repo_name' in kwargs:
repository = kwargs['namespace_name'] + '/' + kwargs['repo_name']
auth_context = SignedAuthContext.build_from_signed_dict(
context_dict
)
if auth_context is not None:
logger.debug("Auth context set to %s", auth_context.signed_data)
set_authenticated_context(auth_context)
abort(401, message=ije.message, headers=get_auth_headers(repository=repository,
scopes=scopes))
else:
logger.debug('No auth header.')
except InvalidJWTException as ije:
repository = None
if "namespace_name" in kwargs and "repo_name" in kwargs:
repository = (
kwargs["namespace_name"] + "/" + kwargs["repo_name"]
)
return func(*args, **kwargs)
return wrapper
return inner
abort(
401,
message=ije.message,
headers=get_auth_headers(repository=repository, scopes=scopes),
)
else:
logger.debug("No auth header.")
return func(*args, **kwargs)
return wrapper
return inner
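get_auth_headers produces the challenge a registry client answers by fetching a token from the realm endpoint, per the Docker v2 token spec linked above. For a pull-scoped challenge the output is roughly the following (hostname and realm path are examples, not taken from a live config):

headers = get_auth_headers(repository="someorg/somerepo", scopes=["pull"])
# headers["WWW-Authenticate"] ==
#   'Bearer realm="https://quay.example.com/v2/auth",service="quay.example.com",'
#   'scope="repository:someorg/somerepo:pull"'
# headers["Docker-Distribution-API-Version"] == "registry/2.0"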


@@ -2,145 +2,194 @@ from collections import namedtuple
import features
import re
Scope = namedtuple('scope', ['scope', 'icon', 'dangerous', 'title', 'description'])
Scope = namedtuple("scope", ["scope", "icon", "dangerous", "title", "description"])
READ_REPO = Scope(scope='repo:read',
icon='fa-hdd-o',
dangerous=False,
title='View all visible repositories',
description=('This application will be able to view and pull all repositories '
'visible to the granting user or robot account'))
READ_REPO = Scope(
scope="repo:read",
icon="fa-hdd-o",
dangerous=False,
title="View all visible repositories",
description=(
"This application will be able to view and pull all repositories "
"visible to the granting user or robot account"
),
)
WRITE_REPO = Scope(scope='repo:write',
icon='fa-hdd-o',
dangerous=False,
title='Read/Write to any accessible repositories',
description=('This application will be able to view, push and pull to all '
'repositories to which the granting user or robot account has '
'write access'))
WRITE_REPO = Scope(
scope="repo:write",
icon="fa-hdd-o",
dangerous=False,
title="Read/Write to any accessible repositories",
description=(
"This application will be able to view, push and pull to all "
"repositories to which the granting user or robot account has "
"write access"
),
)
ADMIN_REPO = Scope(scope='repo:admin',
icon='fa-hdd-o',
dangerous=False,
title='Administer Repositories',
description=('This application will have administrator access to all '
'repositories to which the granting user or robot account has '
'access'))
ADMIN_REPO = Scope(
scope="repo:admin",
icon="fa-hdd-o",
dangerous=False,
title="Administer Repositories",
description=(
"This application will have administrator access to all "
"repositories to which the granting user or robot account has "
"access"
),
)
CREATE_REPO = Scope(scope='repo:create',
icon='fa-plus',
dangerous=False,
title='Create Repositories',
description=('This application will be able to create repositories in any '
'namespace in which the granting user or robot account is allowed '
'to create repositories'))
CREATE_REPO = Scope(
scope="repo:create",
icon="fa-plus",
dangerous=False,
title="Create Repositories",
description=(
"This application will be able to create repositories in to any "
"namespaces that the granting user or robot account is allowed "
"to create repositories"
),
)
READ_USER = Scope(scope= 'user:read',
icon='fa-user',
dangerous=False,
title='Read User Information',
description=('This application will be able to read user information such as '
'username and email address.'))
READ_USER = Scope(
scope="user:read",
icon="fa-user",
dangerous=False,
title="Read User Information",
description=(
"This application will be able to read user information such as "
"username and email address."
),
)
ADMIN_USER = Scope(scope= 'user:admin',
icon='fa-gear',
dangerous=True,
title='Administer User',
description=('This application will be able to administer your account '
'including creating robots and granting them permissions '
'to your repositories. You should have absolute trust in the '
'requesting application before granting this permission.'))
ADMIN_USER = Scope(
scope="user:admin",
icon="fa-gear",
dangerous=True,
title="Administer User",
description=(
"This application will be able to administer your account "
"including creating robots and granting them permissions "
"to your repositories. You should have absolute trust in the "
"requesting application before granting this permission."
),
)
ORG_ADMIN = Scope(scope='org:admin',
icon='fa-gear',
dangerous=True,
title='Administer Organization',
description=('This application will be able to administer your organizations '
'including creating robots, creating teams, adjusting team '
'membership, and changing billing settings. You should have '
'absolute trust in the requesting application before granting this '
'permission.'))
ORG_ADMIN = Scope(
scope="org:admin",
icon="fa-gear",
dangerous=True,
title="Administer Organization",
description=(
"This application will be able to administer your organizations "
"including creating robots, creating teams, adjusting team "
"membership, and changing billing settings. You should have "
"absolute trust in the requesting application before granting this "
"permission."
),
)
DIRECT_LOGIN = Scope(scope='direct_user_login',
icon='fa-exclamation-triangle',
dangerous=True,
title='Full Access',
description=('This scope should not be available to OAuth applications. '
'Never approve a request for this scope!'))
DIRECT_LOGIN = Scope(
scope="direct_user_login",
icon="fa-exclamation-triangle",
dangerous=True,
title="Full Access",
description=(
"This scope should not be available to OAuth applications. "
"Never approve a request for this scope!"
),
)
SUPERUSER = Scope(
scope="super:user",
icon="fa-street-view",
dangerous=True,
title="Super User Access",
description=(
"This application will be able to administer your installation "
"including managing users, managing organizations and other "
"features found in the superuser panel. You should have "
"absolute trust in the requesting application before granting this "
"permission."
),
)
ALL_SCOPES = {
scope.scope: scope
for scope in (
READ_REPO,
WRITE_REPO,
ADMIN_REPO,
CREATE_REPO,
READ_USER,
ORG_ADMIN,
SUPERUSER,
ADMIN_USER,
)
}
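# Note: DIRECT_LOGIN is deliberately absent from ALL_SCOPES; as its
# description states, it should never be grantable to OAuth applications.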
IMPLIED_SCOPES = {
ADMIN_REPO: {ADMIN_REPO, WRITE_REPO, READ_REPO},
WRITE_REPO: {WRITE_REPO, READ_REPO},
READ_REPO: {READ_REPO},
CREATE_REPO: {CREATE_REPO},
READ_USER: {READ_USER},
ORG_ADMIN: {ORG_ADMIN},
SUPERUSER: {SUPERUSER},
ADMIN_USER: {ADMIN_USER},
None: set(),
}
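# Illustrative expansions (added commentary, not part of the original module):
#
#   IMPLIED_SCOPES[ADMIN_REPO]  # -> {ADMIN_REPO, WRITE_REPO, READ_REPO}
#   IMPLIED_SCOPES[WRITE_REPO]  # -> {WRITE_REPO, READ_REPO}
#   IMPLIED_SCOPES[READ_USER]   # -> {READ_USER}
#
# Only the repository scopes imply one another; every other scope implies
# just itself, and None maps to the empty set so missing scopes expand safely.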
def app_scopes(app_config):
scopes_from_config = dict(ALL_SCOPES)
if not app_config.get("FEATURE_SUPER_USERS", False):
del scopes_from_config[SUPERUSER.scope]
return scopes_from_config
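# Illustrative usage (a sketch, assuming a plain dict config):
#
#   app_scopes({"FEATURE_SUPER_USERS": False})  # omits the "super:user" entry
#   app_scopes({"FEATURE_SUPER_USERS": True})   # full ALL_SCOPES mapping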
def scopes_from_scope_string(scopes):
if not scopes:
scopes = ""
# Note: The scopes string should be space separated according to the spec:
# https://tools.ietf.org/html/rfc6749#section-3.3
# However, we also support commas for backwards compatibility with existing callers of our code.
scope_set = {ALL_SCOPES.get(scope, None) for scope in re.split(" |,", scopes)}
return scope_set if None not in scope_set else set()
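# Illustrative behavior (added commentary): both delimiters parse, but one
# unrecognized scope invalidates the whole string.
#
#   scopes_from_scope_string("repo:read repo:write")  # -> {READ_REPO, WRITE_REPO}
#   scopes_from_scope_string("repo:read,repo:write")  # -> {READ_REPO, WRITE_REPO}
#   scopes_from_scope_string("repo:read not:valid")   # -> set()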
def validate_scope_string(scopes):
decoded = scopes_from_scope_string(scopes)
return len(decoded) > 0
def is_subset_string(full_string, expected_string):
""" Returns true if the scopes found in expected_string are also found
""" Returns true if the scopes found in expected_string are also found
in full_string.
"""
full_scopes = scopes_from_scope_string(full_string)
if not full_scopes:
return False
full_implied_scopes = set.union(*[IMPLIED_SCOPES[scope] for scope in full_scopes])
expected_scopes = scopes_from_scope_string(expected_string)
return expected_scopes.issubset(full_implied_scopes)
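# Illustrative behavior: the full string is expanded through IMPLIED_SCOPES
# before the subset check, so broader grants satisfy narrower expectations.
#
#   is_subset_string("repo:admin", "repo:read repo:write")  # -> True
#   is_subset_string("repo:read", "repo:write")             # -> False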
def get_scope_information(scopes_string):
scopes = scopes_from_scope_string(scopes_string)
scope_info = []
for scope in scopes:
scope_info.append(
{
"title": scope.title,
"scope": scope.scope,
"description": scope.description,
"icon": scope.icon,
"dangerous": scope.dangerous,
}
)
return scope_info
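# Illustrative return value for a single recognized scope (fields taken from
# the Scope definitions above):
#
#   get_scope_information("repo:admin")
#   # -> [{"title": "Administer Repositories", "scope": "repo:admin",
#   #      "description": "This application will have administrator ...",
#   #      "icon": "fa-hdd-o", "dangerous": False}]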

View file

@@ -8,48 +8,49 @@ from auth.validateresult import AuthKind, ValidateResult
logger = logging.getLogger(__name__)
# The prefix for all signatures of signed grants.
SIGNATURE_PREFIX = "sigv2="
def generate_signed_token(grants, user_context):
""" Generates a signed session token with the given grants and user context. """
ser = SecureCookieSessionInterface().get_signing_serializer(app)
data_to_sign = {
'grants': grants,
'user_context': user_context,
}
""" Generates a signed session token with the given grants and user context. """
ser = SecureCookieSessionInterface().get_signing_serializer(app)
data_to_sign = {"grants": grants, "user_context": user_context}
encrypted = ser.dumps(data_to_sign)
return "{0}{1}".format(SIGNATURE_PREFIX, encrypted)
def validate_signed_grant(auth_header):
""" Validates a signed grant as found inside an auth header and returns whether it points to
""" Validates a signed grant as found inside an auth header and returns whether it points to
a valid grant.
"""
if not auth_header:
return ValidateResult(AuthKind.signed_grant, missing=True)
# Try to parse the token from the header.
normalized = [part.strip() for part in auth_header.split(" ") if part]
if normalized[0].lower() != "token" or len(normalized) != 2:
logger.debug("Not a token: %s", auth_header)
return ValidateResult(AuthKind.signed_grant, missing=True)
# Check that it starts with the expected prefix.
if not normalized[1].startswith(SIGNATURE_PREFIX):
logger.debug("Not a signed grant token: %s", auth_header)
return ValidateResult(AuthKind.signed_grant, missing=True)
# Decrypt the grant.
encrypted = normalized[1][len(SIGNATURE_PREFIX) :]
ser = SecureCookieSessionInterface().get_signing_serializer(app)
try:
token_data = ser.loads(
encrypted, max_age=app.config["SIGNED_GRANT_EXPIRATION_SEC"]
)
except BadSignature:
logger.warning("Signed grant could not be validated: %s", encrypted)
return ValidateResult(
AuthKind.signed_grant, error_message="Signed grant could not be validated"
)
logger.debug("Successfully validated signed grant with data: %s", token_data)
return ValidateResult(AuthKind.signed_grant, signed_data=token_data)
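# Illustrative round trip (a sketch; the grant and context payloads here are
# hypothetical):
#
#   header = "token " + generate_signed_token({"repo": "somens/somerepo"}, "someuser")
#   validate_signed_grant(header).signed_data
#   # -> {"grants": {"repo": "somens/somerepo"}, "user_context": "someuser"}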

View file

@@ -1,51 +1,65 @@
import pytest
from auth.auth_context_type import (
SignedAuthContext,
ValidatedAuthContext,
ContextEntityKind,
)
from data import model, database
from test.fixtures import *
def get_oauth_token(_):
return database.OAuthAccessToken.get()
@pytest.mark.parametrize(
"kind, entity_reference, loader",
[
(ContextEntityKind.anonymous, None, None),
(
ContextEntityKind.appspecifictoken,
"%s%s" % ("a" * 60, "b" * 60),
model.appspecifictoken.access_valid_token,
),
(ContextEntityKind.oauthtoken, None, get_oauth_token),
(ContextEntityKind.robot, "devtable+dtrobot", model.user.lookup_robot),
(ContextEntityKind.user, "devtable", model.user.get_user),
],
)
@pytest.mark.parametrize("v1_dict_format", [(True), (False)])
def test_signed_auth_context(
kind, entity_reference, loader, v1_dict_format, initialized_db
):
if kind == ContextEntityKind.anonymous:
validated = ValidatedAuthContext()
assert validated.is_anonymous
else:
ref = loader(entity_reference)
validated = ValidatedAuthContext(**{kind.value: ref})
assert not validated.is_anonymous
assert validated.entity_kind == kind
assert validated.unique_key
signed = SignedAuthContext.build_from_signed_dict(
validated.to_signed_dict(), v1_dict_format=v1_dict_format
)
if not v1_dict_format:
# Under legacy V1 format, we don't track the app specific token, merely its associated user.
assert signed.entity_kind == kind
assert signed.description == validated.description
assert signed.credential_username == validated.credential_username
assert (
signed.analytics_id_and_public_metadata()
== validated.analytics_id_and_public_metadata()
)
assert signed.unique_key == validated.unique_key
assert signed.is_anonymous == validated.is_anonymous
assert signed.authed_user == validated.authed_user
assert signed.has_nonrobot_user == validated.has_nonrobot_user
assert signed.to_signed_dict() == validated.to_signed_dict()

View file

@@ -5,8 +5,11 @@ import pytest
from base64 import b64encode
from auth.basic import validate_basic_auth
from auth.credentials import (
ACCESS_TOKEN_USERNAME,
OAUTH_TOKEN_USERNAME,
APP_SPECIFIC_TOKEN_USERNAME,
)
from auth.validateresult import AuthKind, ValidateResult
from data import model
@@ -14,85 +17,120 @@ from test.fixtures import *
def _token(username, password):
assert isinstance(username, basestring)
assert isinstance(password, basestring)
return "basic " + b64encode("%s:%s" % (username, password))
@pytest.mark.parametrize(
"token, expected_result",
[
("", ValidateResult(AuthKind.basic, missing=True)),
("someinvalidtoken", ValidateResult(AuthKind.basic, missing=True)),
("somefoobartoken", ValidateResult(AuthKind.basic, missing=True)),
("basic ", ValidateResult(AuthKind.basic, missing=True)),
("basic some token", ValidateResult(AuthKind.basic, missing=True)),
("basic sometoken", ValidateResult(AuthKind.basic, missing=True)),
(
_token(APP_SPECIFIC_TOKEN_USERNAME, "invalid"),
ValidateResult(AuthKind.basic, error_message="Invalid token"),
),
(
_token(ACCESS_TOKEN_USERNAME, "invalid"),
ValidateResult(AuthKind.basic, error_message="Invalid access token"),
),
(
_token(OAUTH_TOKEN_USERNAME, "invalid"),
ValidateResult(
AuthKind.basic,
error_message="OAuth access token could not be validated",
),
),
(
_token("devtable", "invalid"),
ValidateResult(
AuthKind.basic, error_message="Invalid Username or Password"
),
),
(
_token("devtable+somebot", "invalid"),
ValidateResult(
AuthKind.basic,
error_message="Could not find robot with username: devtable+somebot",
),
),
(
_token("disabled", "password"),
ValidateResult(
AuthKind.basic,
error_message="This user has been disabled. Please contact your administrator.",
),
),
],
)
def test_validate_basic_auth_token(token, expected_result, app):
result = validate_basic_auth(token)
assert result == expected_result
def test_valid_user(app):
token = _token("devtable", "password")
result = validate_basic_auth(token)
assert result == ValidateResult(
AuthKind.basic, user=model.user.get_user("devtable")
)
def test_valid_robot(app):
robot, password = model.user.create_robot(
"somerobot", model.user.get_user("devtable")
)
token = _token(robot.username, password)
result = validate_basic_auth(token)
assert result == ValidateResult(AuthKind.basic, robot=robot)
def test_valid_token(app):
access_token = model.token.create_delegate_token("devtable", "simple", "sometoken")
token = _token(ACCESS_TOKEN_USERNAME, access_token.get_code())
result = validate_basic_auth(token)
assert result == ValidateResult(AuthKind.basic, token=access_token)
def test_valid_oauth(app):
user = model.user.get_user("devtable")
app = model.oauth.list_applications_for_org(
model.user.get_user_or_org("buynlarge")
)[0]
oauth_token, code = model.oauth.create_access_token_for_testing(
user, app.client_id, "repo:read"
)
token = _token(OAUTH_TOKEN_USERNAME, code)
result = validate_basic_auth(token)
assert result == ValidateResult(AuthKind.basic, oauthtoken=oauth_token)
def test_valid_app_specific_token(app):
user = model.user.get_user("devtable")
app_specific_token = model.appspecifictoken.create_token(user, "some token")
full_token = model.appspecifictoken.get_full_token_string(app_specific_token)
token = _token(APP_SPECIFIC_TOKEN_USERNAME, full_token)
result = validate_basic_auth(token)
assert result == ValidateResult(AuthKind.basic, appspecifictoken=app_specific_token)
def test_invalid_unicode(app):
token = "\xebOH"
header = "basic " + b64encode(token)
result = validate_basic_auth(header)
assert result == ValidateResult(AuthKind.basic, missing=True)
def test_invalid_unicode_2(app):
token = "“4JPCOLIVMAY32Q3XGVPHC4CBF8SKII5FWNYMASOFDIVSXTC5I5NBU”"
header = "basic " + b64encode("devtable+somerobot:%s" % token)
result = validate_basic_auth(header)
assert result == ValidateResult(
AuthKind.basic,
error_message="Could not find robot with username: devtable+somerobot and supplied password.",
)

View file

@@ -9,58 +9,58 @@ from test.fixtures import *
def test_anonymous_cookie(app):
assert validate_session_cookie().missing
def test_invalidformatted_cookie(app):
# "Login" with a non-UUID reference.
someuser = model.user.get_user('devtable')
login_user(LoginWrappedDBUser('somenonuuid', someuser))
# "Login" with a non-UUID reference.
someuser = model.user.get_user("devtable")
login_user(LoginWrappedDBUser("somenonuuid", someuser))
# Ensure we get an invalid session cookie format error.
result = validate_session_cookie()
assert result.authed_user is None
assert result.context.identity is None
assert not result.has_nonrobot_user
assert result.error_message == "Invalid session cookie format"
def test_disabled_user(app):
# "Login" with a disabled user.
someuser = model.user.get_user('disabled')
login_user(LoginWrappedDBUser(someuser.uuid, someuser))
# "Login" with a disabled user.
someuser = model.user.get_user("disabled")
login_user(LoginWrappedDBUser(someuser.uuid, someuser))
# Ensure we get an invalid session cookie format error.
result = validate_session_cookie()
assert result.authed_user is None
assert result.context.identity is None
assert not result.has_nonrobot_user
assert result.error_message == "User account is disabled"
def test_valid_user(app):
# Login with a valid user.
someuser = model.user.get_user("devtable")
login_user(LoginWrappedDBUser(someuser.uuid, someuser))
result = validate_session_cookie()
assert result.authed_user == someuser
assert result.context.identity is not None
assert result.has_nonrobot_user
assert result.error_message is None
def test_valid_organization(app):
# "Login" with a valid organization.
someorg = model.user.get_namespace_user('buynlarge')
someorg.uuid = str(uuid.uuid4())
someorg.verified = True
someorg.save()
# "Login" with a valid organization.
someorg = model.user.get_namespace_user("buynlarge")
someorg.uuid = str(uuid.uuid4())
someorg.verified = True
someorg.save()
login_user(LoginWrappedDBUser(someorg.uuid, someorg))
result = validate_session_cookie()
assert result.authed_user is None
assert result.context.identity is None
assert not result.has_nonrobot_user
assert result.error_message == "Cannot login to organization"

View file

@@ -1,147 +1,184 @@
# -*- coding: utf-8 -*-
from auth.credentials import validate_credentials, CredentialKind
from auth.credential_consts import (
ACCESS_TOKEN_USERNAME,
OAUTH_TOKEN_USERNAME,
APP_SPECIFIC_TOKEN_USERNAME,
)
from auth.validateresult import AuthKind, ValidateResult
from data import model
from test.fixtures import *
def test_valid_user(app):
result, kind = validate_credentials("devtable", "password")
assert kind == CredentialKind.user
assert result == ValidateResult(
AuthKind.credentials, user=model.user.get_user("devtable")
)
def test_valid_robot(app):
robot, password = model.user.create_robot(
"somerobot", model.user.get_user("devtable")
)
result, kind = validate_credentials(robot.username, password)
assert kind == CredentialKind.robot
assert result == ValidateResult(AuthKind.credentials, robot=robot)
def test_valid_robot_for_disabled_user(app):
user = model.user.get_user("devtable")
user.enabled = False
user.save()
robot, password = model.user.create_robot("somerobot", user)
result, kind = validate_credentials(robot.username, password)
assert kind == CredentialKind.robot
err = "This user has been disabled. Please contact your administrator."
assert result == ValidateResult(AuthKind.credentials, error_message=err)
def test_valid_token(app):
access_token = model.token.create_delegate_token("devtable", "simple", "sometoken")
result, kind = validate_credentials(ACCESS_TOKEN_USERNAME, access_token.get_code())
assert kind == CredentialKind.token
assert result == ValidateResult(AuthKind.credentials, token=access_token)
def test_valid_oauth(app):
user = model.user.get_user("devtable")
app = model.oauth.list_applications_for_org(
model.user.get_user_or_org("buynlarge")
)[0]
oauth_token, code = model.oauth.create_access_token_for_testing(
user, app.client_id, "repo:read"
)
result, kind = validate_credentials(OAUTH_TOKEN_USERNAME, code)
assert kind == CredentialKind.oauth_token
assert result == ValidateResult(AuthKind.oauth, oauthtoken=oauth_token)
def test_invalid_user_bad_password(app):
result, kind = validate_credentials("devtable", "somepassword")
assert kind == CredentialKind.user
assert result == ValidateResult(
AuthKind.credentials, error_message="Invalid Username or Password"
)
def test_valid_app_specific_token(app):
user = model.user.get_user("devtable")
app_specific_token = model.appspecifictoken.create_token(user, "some token")
full_token = model.appspecifictoken.get_full_token_string(app_specific_token)
result, kind = validate_credentials(APP_SPECIFIC_TOKEN_USERNAME, full_token)
assert kind == CredentialKind.app_specific_token
assert result == ValidateResult(
AuthKind.credentials, appspecifictoken=app_specific_token
)
def test_valid_app_specific_token_for_disabled_user(app):
user = model.user.get_user("devtable")
user.enabled = False
user.save()
app_specific_token = model.appspecifictoken.create_token(user, "some token")
full_token = model.appspecifictoken.get_full_token_string(app_specific_token)
result, kind = validate_credentials(APP_SPECIFIC_TOKEN_USERNAME, full_token)
assert kind == CredentialKind.app_specific_token
err = "This user has been disabled. Please contact your administrator."
assert result == ValidateResult(AuthKind.credentials, error_message=err)
def test_invalid_app_specific_token(app):
result, kind = validate_credentials(APP_SPECIFIC_TOKEN_USERNAME, "somecode")
assert kind == CredentialKind.app_specific_token
assert result == ValidateResult(AuthKind.credentials, error_message="Invalid token")
def test_invalid_app_specific_token_code(app):
user = model.user.get_user("devtable")
app_specific_token = model.appspecifictoken.create_token(user, "some token")
full_token = app_specific_token.token_name + "something"
result, kind = validate_credentials(APP_SPECIFIC_TOKEN_USERNAME, full_token)
assert kind == CredentialKind.app_specific_token
assert result == ValidateResult(AuthKind.credentials, error_message="Invalid token")
def test_unicode(app):
result, kind = validate_credentials("someusername", "some₪code")
assert kind == CredentialKind.user
assert not result.auth_valid
assert result == ValidateResult(
AuthKind.credentials, error_message="Invalid Username or Password"
)
def test_unicode_robot(app):
robot, _ = model.user.create_robot("somerobot", model.user.get_user("devtable"))
result, kind = validate_credentials(robot.username, "some₪code")
assert kind == CredentialKind.robot
assert not result.auth_valid
msg = (
"Could not find robot with username: devtable+somerobot and supplied password."
)
assert result == ValidateResult(AuthKind.credentials, error_message=msg)
def test_invalid_user(app):
result, kind = validate_credentials("someinvaliduser", "password")
assert kind == CredentialKind.user
assert not result.authed_user
assert not result.auth_valid
def test_invalid_user_password(app):
result, kind = validate_credentials("devtable", "somepassword")
assert kind == CredentialKind.user
assert not result.authed_user
assert not result.auth_valid
def test_invalid_robot(app):
result, kind = validate_credentials("devtable+doesnotexist", "password")
assert kind == CredentialKind.robot
assert not result.authed_user
assert not result.auth_valid
def test_invalid_robot_token(app):
robot, _ = model.user.create_robot("somerobot", model.user.get_user("devtable"))
result, kind = validate_credentials(robot.username, "invalidpassword")
assert kind == CredentialKind.robot
assert not result.authed_user
assert not result.auth_valid
def test_invalid_unicode_robot(app):
token = "“4JPCOLIVMAY32Q3XGVPHC4CBF8SKII5FWNYMASOFDIVSXTC5I5NBU”"
result, kind = validate_credentials("devtable+somerobot", token)
assert kind == CredentialKind.robot
assert not result.auth_valid
msg = "Could not find robot with username: devtable+somerobot"
assert result == ValidateResult(AuthKind.credentials, error_message=msg)
def test_invalid_unicode_robot_2(app):
user = model.user.get_user("devtable")
robot, password = model.user.create_robot("somerobot", user)
token = "“4JPCOLIVMAY32Q3XGVPHC4CBF8SKII5FWNYMASOFDIVSXTC5I5NBU”"
result, kind = validate_credentials("devtable+somerobot", token)
assert kind == CredentialKind.robot
assert not result.auth_valid
msg = (
"Could not find robot with username: devtable+somerobot and supplied password."
)
assert result == ValidateResult(AuthKind.credentials, error_message=msg)

View file

@@ -7,99 +7,102 @@ from werkzeug.exceptions import HTTPException
from app import LoginWrappedDBUser
from auth.auth_context import get_authenticated_user
from auth.decorators import (
extract_namespace_repo_from_session,
require_session_login,
process_auth_or_cookie,
)
from data import model
from test.fixtures import *
def test_extract_namespace_repo_from_session_missing(app):
def emptyfunc():
pass
session.clear()
with pytest.raises(HTTPException):
extract_namespace_repo_from_session(emptyfunc)()
def test_extract_namespace_repo_from_session_present(app):
encountered = []
def somefunc(namespace, repository):
encountered.append(namespace)
encountered.append(repository)
# Add the namespace and repository to the session.
session.clear()
session["namespace"] = "foo"
session["repository"] = "bar"
# Call the decorated method.
extract_namespace_repo_from_session(somefunc)()
assert encountered[0] == "foo"
assert encountered[1] == "bar"
def test_require_session_login_missing(app):
def emptyfunc():
pass
with pytest.raises(HTTPException):
require_session_login(emptyfunc)()
def test_require_session_login_valid_user(app):
def emptyfunc():
pass
# Login as a valid user.
someuser = model.user.get_user("devtable")
login_user(LoginWrappedDBUser(someuser.uuid, someuser))
# Call the function.
require_session_login(emptyfunc)()
# Ensure the authenticated user was updated.
assert get_authenticated_user() == someuser
def test_require_session_login_invalid_user(app):
def emptyfunc():
pass
# "Login" as a disabled user.
someuser = model.user.get_user('disabled')
login_user(LoginWrappedDBUser(someuser.uuid, someuser))
# "Login" as a disabled user.
someuser = model.user.get_user("disabled")
login_user(LoginWrappedDBUser(someuser.uuid, someuser))
# Call the function.
with pytest.raises(HTTPException):
require_session_login(emptyfunc)()
# Ensure the authenticated user was not updated.
assert get_authenticated_user() is None
def test_process_auth_or_cookie_invalid_user(app):
def emptyfunc():
pass
# Call the function.
process_auth_or_cookie(emptyfunc)()
# Ensure the authenticated user was not updated.
assert get_authenticated_user() is None
def test_process_auth_or_cookie_valid_user(app):
def emptyfunc():
pass
# Login as a valid user.
someuser = model.user.get_user("devtable")
login_user(LoginWrappedDBUser(someuser.uuid, someuser))
# Call the function.
process_auth_or_cookie(emptyfunc)()
# Ensure the authenticated user was updated.
assert get_authenticated_user() == someuser

View file

@@ -6,50 +6,63 @@ from data import model
from test.fixtures import *
@pytest.mark.parametrize(
"header, expected_result",
[
("", ValidateResult(AuthKind.oauth, missing=True)),
("somerandomtoken", ValidateResult(AuthKind.oauth, missing=True)),
("bearer some random token", ValidateResult(AuthKind.oauth, missing=True)),
(
"bearer invalidtoken",
ValidateResult(
AuthKind.oauth,
error_message="OAuth access token could not be validated",
),
),
],
)
def test_bearer(header, expected_result, app):
assert validate_bearer_auth(header) == expected_result
def test_valid_oauth(app):
user = model.user.get_user("devtable")
app = model.oauth.list_applications_for_org(
model.user.get_user_or_org("buynlarge")
)[0]
token_string = "%s%s" % ("a" * 20, "b" * 20)
oauth_token, _ = model.oauth.create_access_token_for_testing(
user, app.client_id, "repo:read", access_token=token_string
)
result = validate_bearer_auth("bearer " + token_string)
assert result.context.oauthtoken == oauth_token
assert result.authed_user == user
assert result.auth_valid
def test_disabled_user_oauth(app):
user = model.user.get_user("disabled")
token_string = "%s%s" % ("a" * 20, "b" * 20)
oauth_token, _ = model.oauth.create_access_token_for_testing(
user, "deadbeef", "repo:admin", access_token=token_string
)
result = validate_bearer_auth("bearer " + token_string)
assert result.context.oauthtoken is None
assert result.authed_user is None
assert not result.auth_valid
assert result.error_message == "Granter of the oauth access token is disabled"
def test_expired_token(app):
user = model.user.get_user("devtable")
token_string = "%s%s" % ("a" * 20, "b" * 20)
oauth_token, _ = model.oauth.create_access_token_for_testing(
user, "deadbeef", "repo:admin", access_token=token_string, expires_in=-1000
)
result = validate_bearer_auth("bearer " + token_string)
assert result.context.oauthtoken is None
assert result.authed_user is None
assert not result.auth_valid
assert result.error_message == "OAuth access token has expired"

View file

@@ -6,32 +6,33 @@ from data import model
from test.fixtures import *
SUPER_USERNAME = "devtable"
UNSUPER_USERNAME = "freshuser"
@pytest.fixture()
def superuser(initialized_db):
return model.user.get_user(SUPER_USERNAME)
@pytest.fixture()
def normie(initialized_db):
return model.user.get_user(UNSUPER_USERNAME)
def test_superuser_matrix(superuser, normie):
test_cases = [
(superuser, {scopes.SUPERUSER}, True),
(superuser, {scopes.DIRECT_LOGIN}, True),
(superuser, {scopes.READ_USER, scopes.SUPERUSER}, True),
(superuser, {scopes.READ_USER}, False),
(normie, {scopes.SUPERUSER}, False),
(normie, {scopes.DIRECT_LOGIN}, False),
(normie, {scopes.READ_USER, scopes.SUPERUSER}, False),
(normie, {scopes.READ_USER}, False),
]
for user_obj, scope_set, expected in test_cases:
perm_user = QuayDeferredPermissionUser.for_user(user_obj, scope_set)
has_su = perm_user.can(SuperUserPermission())
assert has_su == expected

View file

@@ -14,190 +14,226 @@ from initdb import setup_database_for_testing, finished_database_for_testing
from util.morecollections import AttrDict
from util.security.registry_jwt import ANONYMOUS_SUB, build_context_and_subject
TEST_AUDIENCE = app.config["SERVER_HOSTNAME"]
TEST_USER = AttrDict({"username": "joeuser", "uuid": "foobar", "enabled": True})
MAX_SIGNED_S = 3660
TOKEN_VALIDITY_LIFETIME_S = 60 * 60 # 1 hour
ANONYMOUS_SUB = "(anonymous)"
SERVICE_NAME = "quay"
# This import has to come below any references to "app".
from test.fixtures import *
def _access(typ="repository", name="somens/somerepo", actions=None):
actions = [] if actions is None else actions
return [{"type": typ, "name": name, "actions": actions}]
def _delete_field(token_data, field_name):
token_data.pop(field_name)
return token_data
def _token_data(
access=[],
context=None,
audience=TEST_AUDIENCE,
user=TEST_USER,
iat=None,
exp=None,
nbf=None,
iss=None,
subject=None,
):
if subject is None:
_, subject = build_context_and_subject(ValidatedAuthContext(user=user))
return {
"iss": iss or instance_keys.service_name,
"aud": audience,
"nbf": nbf if nbf is not None else int(time.time()),
"iat": iat if iat is not None else int(time.time()),
"exp": exp if exp is not None else int(time.time() + TOKEN_VALIDITY_LIFETIME_S),
"sub": subject,
"access": access,
"context": context,
}
def _token(token_data, key_id=None, private_key=None, skip_header=False, alg=None):
key_id = key_id or instance_keys.local_key_id
private_key = private_key or instance_keys.local_private_key
if alg == "none":
private_key = None
if alg == "none":
private_key = None
token_headers = {"kid": key_id}
if skip_header:
token_headers = {}
token_data = jwt.encode(
token_data, private_key, alg or "RS256", headers=token_headers
)
return "Bearer {0}".format(token_data)
def _parse_token(token):
return identity_from_bearer_token(token)[0]
def test_accepted_token(initialized_db):
token = _token(_token_data())
identity = _parse_token(token)
assert identity.id == TEST_USER.username, "should be %s, but was %s" % (
TEST_USER.username,
identity.id,
)
assert len(identity.provides) == 0
anon_token = _token(_token_data(user=None))
anon_identity = _parse_token(anon_token)
assert anon_identity.id == ANONYMOUS_SUB, "should be %s, but was %s" % (
ANONYMOUS_SUB,
anon_identity.id,
)
assert len(identity.provides) == 0
@pytest.mark.parametrize(
"access",
[
(_access(actions=["pull", "push"])),
(_access(actions=["pull", "*"])),
(_access(actions=["*", "push"])),
(_access(actions=["*"])),
(_access(actions=["pull", "*", "push"])),
],
)
def test_token_with_access(access, initialized_db):
token = _token(_token_data(access=access))
identity = _parse_token(token)
assert identity.id == TEST_USER.username, "should be %s, but was %s" % (
TEST_USER.username,
identity.id,
)
assert len(identity.provides) == 1
role = list(identity.provides)[0][3]
if "*" in access[0]["actions"]:
assert role == "admin"
elif "push" in access[0]["actions"]:
assert role == "write"
elif "pull" in access[0]["actions"]:
assert role == "read"
@pytest.mark.parametrize(
"token",
[
pytest.param(
_token(
_token_data(
access=[
{
"toipe": "repository",
"namesies": "somens/somerepo",
"akshuns": ["pull", "push", "*"],
}
]
)
),
id="bad access",
),
pytest.param(_token(_token_data(audience="someotherapp")), id="bad aud"),
pytest.param(_token(_delete_field(_token_data(), "aud")), id="no aud"),
pytest.param(_token(_token_data(nbf=int(time.time()) + 600)), id="future nbf"),
pytest.param(_token(_delete_field(_token_data(), "nbf")), id="no nbf"),
pytest.param(_token(_token_data(iat=int(time.time()) + 600)), id="future iat"),
pytest.param(_token(_delete_field(_token_data(), "iat")), id="no iat"),
pytest.param(
_token(_token_data(exp=int(time.time()) + MAX_SIGNED_S * 2)),
id="exp too long",
),
pytest.param(_token(_token_data(exp=int(time.time()) - 60)), id="expired"),
pytest.param(_token(_delete_field(_token_data(), "exp")), id="no exp"),
pytest.param(_token(_delete_field(_token_data(), "sub")), id="no sub"),
pytest.param(_token(_token_data(iss="badissuer")), id="bad iss"),
pytest.param(_token(_delete_field(_token_data(), "iss")), id="no iss"),
pytest.param(_token(_token_data(), skip_header=True), id="no header"),
pytest.param(_token(_token_data(), key_id="someunknownkey"), id="bad key"),
pytest.param(_token(_token_data(), key_id="kid7"), id="bad key :: kid7"),
pytest.param(
_token(_token_data(), alg="none", private_key=None), id="none alg"
),
pytest.param("some random token", id="random token"),
pytest.param("Bearer: sometokenhere", id="extra bearer"),
pytest.param("\nBearer: dGVzdA", id="leading newline"),
],
)
def test_invalid_jwt(token, initialized_db):
with pytest.raises(InvalidJWTException):
_parse_token(token)
def test_mixing_keys_e2e(initialized_db):
token_data = _token_data()
# Create a new key for testing.
p, key = model.service_keys.generate_service_key(
instance_keys.service_name, None, kid="newkey", name="newkey", metadata={}
)
private_key = p.exportKey("PEM")
# Test first with the new valid, but unapproved key.
unapproved_key_token = _token(token_data, key_id="newkey", private_key=private_key)
with pytest.raises(InvalidJWTException):
_parse_token(unapproved_key_token)
# Approve the key and try again.
admin_user = model.user.get_user("devtable")
model.service_keys.approve_service_key(
key.kid, ServiceKeyApprovalType.SUPERUSER, approver=admin_user
)
valid_token = _token(token_data, key_id="newkey", private_key=private_key)
identity = _parse_token(valid_token)
assert identity.id == TEST_USER.username
assert len(identity.provides) == 0
# Try using a different private key with the existing key ID.
bad_private_token = _token(
token_data, key_id="newkey", private_key=instance_keys.local_private_key
)
with pytest.raises(InvalidJWTException):
_parse_token(bad_private_token)
# Try using a different key ID with the existing private key.
kid_mismatch_token = _token(
token_data, key_id=instance_keys.local_key_id, private_key=private_key
)
with pytest.raises(InvalidJWTException):
_parse_token(kid_mismatch_token)
# Delete the new key.
key.delete_instance(recursive=True)
# Ensure it still works (via the cache).
deleted_key_token = _token(token_data, key_id="newkey", private_key=private_key)
identity = _parse_token(deleted_key_token)
assert identity.id == TEST_USER.username
assert len(identity.provides) == 0
# Break the cache.
instance_keys.clear_cache()
# Ensure the key no longer works.
with pytest.raises(InvalidJWTException):
_parse_token(deleted_key_token)
@pytest.mark.parametrize("token", [u"someunicodetoken✡", u"\xc9\xad\xbd"])
def test_unicode_token(token):
with pytest.raises(InvalidJWTException):
_parse_token(token)

View file

@@ -1,50 +1,55 @@
import pytest
from auth.scopes import (
scopes_from_scope_string,
validate_scope_string,
ALL_SCOPES,
is_subset_string,
)
@pytest.mark.parametrize(
"scopes_string, expected",
[
# Valid single scopes.
("repo:read", ["repo:read"]),
("repo:admin", ["repo:admin"]),
# Invalid scopes.
("not:valid", []),
("repo:admins", []),
# Valid scope strings.
("repo:read repo:admin", ["repo:read", "repo:admin"]),
("repo:read,repo:admin", ["repo:read", "repo:admin"]),
("repo:read,repo:admin repo:write", ["repo:read", "repo:admin", "repo:write"]),
# Partially invalid scopes.
("repo:read,not:valid", []),
("repo:read repo:admins", []),
# Invalid scope strings.
("repo:read|repo:admin", []),
# Mixture of delimiters.
("repo:read, repo:admin", []),
],
)
def test_parsing(scopes_string, expected):
expected_scope_set = {ALL_SCOPES[scope_name] for scope_name in expected}
parsed_scope_set = scopes_from_scope_string(scopes_string)
assert parsed_scope_set == expected_scope_set
assert validate_scope_string(scopes_string) == bool(expected)
expected_scope_set = {ALL_SCOPES[scope_name] for scope_name in expected}
parsed_scope_set = scopes_from_scope_string(scopes_string)
assert parsed_scope_set == expected_scope_set
assert validate_scope_string(scopes_string) == bool(expected)
@pytest.mark.parametrize('superset, subset, result', [
('repo:read', 'repo:read', True),
('repo:read repo:admin', 'repo:read', True),
('repo:read,repo:admin', 'repo:read', True),
('repo:read,repo:admin', 'repo:admin', True),
('repo:read,repo:admin', 'repo:admin repo:read', True),
('', 'repo:read', False),
('unknown:tag', 'repo:read', False),
('repo:read unknown:tag', 'repo:read', False),
('repo:read,unknown:tag', 'repo:read', False),])
@pytest.mark.parametrize(
"superset, subset, result",
[
("repo:read", "repo:read", True),
("repo:read repo:admin", "repo:read", True),
("repo:read,repo:admin", "repo:read", True),
("repo:read,repo:admin", "repo:admin", True),
("repo:read,repo:admin", "repo:admin repo:read", True),
("", "repo:read", False),
("unknown:tag", "repo:read", False),
("repo:read unknown:tag", "repo:read", False),
("repo:read,unknown:tag", "repo:read", False),
],
)
def test_subset_string(superset, subset, result):
assert is_subset_string(superset, subset) == result
assert is_subset_string(superset, subset) == result
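
Both helpers reduce to set algebra over the parsed scopes. A sketch of the subset check consistent with the cases above (illustrative only, not necessarily Quay's exact implementation):

def is_subset_string_sketch(superset_string, subset_string):
    # An invalid or empty superset parses to the empty set, so any requested
    # subset is rejected, matching the False rows in the table above.
    superset = scopes_from_scope_string(superset_string)
    subset = scopes_from_scope_string(subset_string)
    return bool(superset) and subset.issubset(superset)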


@ -1,32 +1,47 @@
import pytest
from auth.signedgrant import validate_signed_grant, generate_signed_token, SIGNATURE_PREFIX
from auth.signedgrant import (
validate_signed_grant,
generate_signed_token,
SIGNATURE_PREFIX,
)
from auth.validateresult import AuthKind, ValidateResult
@pytest.mark.parametrize('header, expected_result', [
pytest.param('', ValidateResult(AuthKind.signed_grant, missing=True), id='Missing'),
pytest.param('somerandomtoken', ValidateResult(AuthKind.signed_grant, missing=True),
id='Invalid header'),
pytest.param('token somerandomtoken', ValidateResult(AuthKind.signed_grant, missing=True),
id='Random Token'),
pytest.param('token ' + SIGNATURE_PREFIX + 'foo',
ValidateResult(AuthKind.signed_grant,
error_message='Signed grant could not be validated'),
id='Invalid token'),
])
@pytest.mark.parametrize(
"header, expected_result",
[
pytest.param(
"", ValidateResult(AuthKind.signed_grant, missing=True), id="Missing"
),
pytest.param(
"somerandomtoken",
ValidateResult(AuthKind.signed_grant, missing=True),
id="Invalid header",
),
pytest.param(
"token somerandomtoken",
ValidateResult(AuthKind.signed_grant, missing=True),
id="Random Token",
),
pytest.param(
"token " + SIGNATURE_PREFIX + "foo",
ValidateResult(
AuthKind.signed_grant,
error_message="Signed grant could not be validated",
),
id="Invalid token",
),
],
)
def test_token(header, expected_result):
assert validate_signed_grant(header) == expected_result
assert validate_signed_grant(header) == expected_result
def test_valid_grant():
header = 'token ' + generate_signed_token({'a': 'b'}, {'c': 'd'})
expected = ValidateResult(AuthKind.signed_grant, signed_data={
'grants': {
'a': 'b',
},
'user_context': {
'c': 'd'
},
})
assert validate_signed_grant(header) == expected
header = "token " + generate_signed_token({"a": "b"}, {"c": "d"})
expected = ValidateResult(
AuthKind.signed_grant,
signed_data={"grants": {"a": "b"}, "user_context": {"c": "d"}},
)
assert validate_signed_grant(header) == expected
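
As a usage sketch of the API above: a freshly generated grant validates, while a corrupted signature surfaces as a non-valid result rather than an exception (assuming auth_valid behaves as defined on ValidateResult):

header = "token " + generate_signed_token({"a": "b"}, {"c": "d"})
assert validate_signed_grant(header).auth_valid

tampered = header[:-2] + "xx"  # corrupt the tail of the signature
assert not validate_signed_grant(tampered).auth_valid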


@ -6,58 +6,68 @@ from data import model
from data.database import AppSpecificAuthToken
from test.fixtures import *
def get_user():
return model.user.get_user('devtable')
return model.user.get_user("devtable")
def get_app_specific_token():
return AppSpecificAuthToken.get()
return AppSpecificAuthToken.get()
def get_robot():
robot, _ = model.user.create_robot('somebot', get_user())
return robot
robot, _ = model.user.create_robot("somebot", get_user())
return robot
def get_token():
return model.token.create_delegate_token('devtable', 'simple', 'sometoken')
return model.token.create_delegate_token("devtable", "simple", "sometoken")
def get_oauthtoken():
user = model.user.get_user('devtable')
return list(model.oauth.list_access_tokens_for_user(user))[0]
user = model.user.get_user("devtable")
return list(model.oauth.list_access_tokens_for_user(user))[0]
def get_signeddata():
return {'grants': {'a': 'b'}, 'user_context': {'c': 'd'}}
return {"grants": {"a": "b"}, "user_context": {"c": "d"}}
@pytest.mark.parametrize('get_entity,entity_kind', [
(get_user, 'user'),
(get_robot, 'robot'),
(get_token, 'token'),
(get_oauthtoken, 'oauthtoken'),
(get_signeddata, 'signed_data'),
(get_app_specific_token, 'appspecifictoken'),
])
@pytest.mark.parametrize(
"get_entity,entity_kind",
[
(get_user, "user"),
(get_robot, "robot"),
(get_token, "token"),
(get_oauthtoken, "oauthtoken"),
(get_signeddata, "signed_data"),
(get_app_specific_token, "appspecifictoken"),
],
)
def test_apply_context(get_entity, entity_kind, app):
assert get_authenticated_context() is None
assert get_authenticated_context() is None
entity = get_entity()
args = {}
args[entity_kind] = entity
entity = get_entity()
args = {}
args[entity_kind] = entity
result = ValidateResult(AuthKind.basic, **args)
result.apply_to_context()
result = ValidateResult(AuthKind.basic, **args)
result.apply_to_context()
expected_user = entity if entity_kind == 'user' or entity_kind == 'robot' else None
if entity_kind == 'oauthtoken':
expected_user = entity.authorized_user
expected_user = entity if entity_kind == "user" or entity_kind == "robot" else None
if entity_kind == "oauthtoken":
expected_user = entity.authorized_user
if entity_kind == 'appspecifictoken':
expected_user = entity.user
if entity_kind == "appspecifictoken":
expected_user = entity.user
expected_token = entity if entity_kind == 'token' else None
expected_oauth = entity if entity_kind == 'oauthtoken' else None
expected_appspecifictoken = entity if entity_kind == 'appspecifictoken' else None
expected_grant = entity if entity_kind == 'signed_data' else None
expected_token = entity if entity_kind == "token" else None
expected_oauth = entity if entity_kind == "oauthtoken" else None
expected_appspecifictoken = entity if entity_kind == "appspecifictoken" else None
expected_grant = entity if entity_kind == "signed_data" else None
assert get_authenticated_context().authed_user == expected_user
assert get_authenticated_context().token == expected_token
assert get_authenticated_context().oauthtoken == expected_oauth
assert get_authenticated_context().appspecifictoken == expected_appspecifictoken
assert get_authenticated_context().signed_data == expected_grant
assert get_authenticated_context().authed_user == expected_user
assert get_authenticated_context().token == expected_token
assert get_authenticated_context().oauthtoken == expected_oauth
assert get_authenticated_context().appspecifictoken == expected_appspecifictoken
assert get_authenticated_context().signed_data == expected_grant


@ -3,54 +3,76 @@ from auth.auth_context_type import ValidatedAuthContext, ContextEntityKind
class AuthKind(Enum):
cookie = 'cookie'
basic = 'basic'
oauth = 'oauth'
signed_grant = 'signed_grant'
credentials = 'credentials'
cookie = "cookie"
basic = "basic"
oauth = "oauth"
signed_grant = "signed_grant"
credentials = "credentials"
class ValidateResult(object):
""" A result of validating auth in one form or another. """
def __init__(self, kind, missing=False, user=None, token=None, oauthtoken=None,
robot=None, appspecifictoken=None, signed_data=None, error_message=None):
self.kind = kind
self.missing = missing
self.error_message = error_message
self.context = ValidatedAuthContext(user=user, token=token, oauthtoken=oauthtoken, robot=robot,
appspecifictoken=appspecifictoken, signed_data=signed_data)
""" A result of validating auth in one form or another. """
def tuple(self):
return (self.kind, self.missing, self.error_message, self.context.tuple())
def __init__(
self,
kind,
missing=False,
user=None,
token=None,
oauthtoken=None,
robot=None,
appspecifictoken=None,
signed_data=None,
error_message=None,
):
self.kind = kind
self.missing = missing
self.error_message = error_message
self.context = ValidatedAuthContext(
user=user,
token=token,
oauthtoken=oauthtoken,
robot=robot,
appspecifictoken=appspecifictoken,
signed_data=signed_data,
)
def __eq__(self, other):
return self.tuple() == other.tuple()
def tuple(self):
return (self.kind, self.missing, self.error_message, self.context.tuple())
def apply_to_context(self):
""" Applies this auth result to the auth context and Flask-Principal. """
self.context.apply_to_request_context()
def __eq__(self, other):
return self.tuple() == other.tuple()
def with_kind(self, kind):
""" Returns a copy of this result, but with the kind replaced. """
result = ValidateResult(kind, missing=self.missing, error_message=self.error_message)
result.context = self.context
return result
def apply_to_context(self):
""" Applies this auth result to the auth context and Flask-Principal. """
self.context.apply_to_request_context()
def __repr__(self):
return 'ValidateResult: %s (missing: %s, error: %s)' % (self.kind, self.missing,
self.error_message)
def with_kind(self, kind):
""" Returns a copy of this result, but with the kind replaced. """
result = ValidateResult(
kind, missing=self.missing, error_message=self.error_message
)
result.context = self.context
return result
@property
def authed_user(self):
""" Returns the authenticated user, whether directly, or via an OAuth token. """
return self.context.authed_user
def __repr__(self):
return "ValidateResult: %s (missing: %s, error: %s)" % (
self.kind,
self.missing,
self.error_message,
)
@property
def has_nonrobot_user(self):
""" Returns whether a user (not a robot) was authenticated successfully. """
return self.context.has_nonrobot_user
@property
def authed_user(self):
""" Returns the authenticated user, whether directly, or via an OAuth token. """
return self.context.authed_user
@property
def auth_valid(self):
""" Returns whether authentication successfully occurred. """
return self.context.entity_kind != ContextEntityKind.anonymous
@property
def has_nonrobot_user(self):
""" Returns whether a user (not a robot) was authenticated successfully. """
return self.context.has_nonrobot_user
@property
def auth_valid(self):
""" Returns whether authentication successfully occurred. """
return self.context.entity_kind != ContextEntityKind.anonymous
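
In use, an authenticator constructs a result and applies it to the request context; a brief sketch, where some_user is a hypothetical user object:

result = ValidateResult(AuthKind.basic, user=some_user)
if result.auth_valid:
    result.apply_to_context()  # populates the auth context and Flask-Principal
assert result.authed_user is some_user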


@ -6,110 +6,133 @@ from requests.exceptions import RequestException
logger = logging.getLogger(__name__)
class Avatar(object):
def __init__(self, app=None):
self.app = app
self.state = self._init_app(app)
def __init__(self, app=None):
self.app = app
self.state = self._init_app(app)
def _init_app(self, app):
return AVATAR_CLASSES[app.config.get('AVATAR_KIND', 'Gravatar')](
app.config['PREFERRED_URL_SCHEME'], app.config['AVATAR_COLORS'], app.config['HTTPCLIENT'])
def _init_app(self, app):
return AVATAR_CLASSES[app.config.get("AVATAR_KIND", "Gravatar")](
app.config["PREFERRED_URL_SCHEME"],
app.config["AVATAR_COLORS"],
app.config["HTTPCLIENT"],
)
def __getattr__(self, name):
return getattr(self.state, name, None)
def __getattr__(self, name):
return getattr(self.state, name, None)
class BaseAvatar(object):
""" Base class for all avatar implementations. """
def __init__(self, preferred_url_scheme, colors, http_client):
self.preferred_url_scheme = preferred_url_scheme
self.colors = colors
self.http_client = http_client
""" Base class for all avatar implementations. """
def get_mail_html(self, name, email_or_id, size=16, kind='user'):
""" Returns the full HTML and CSS for viewing the avatar of the given name and email address,
def __init__(self, preferred_url_scheme, colors, http_client):
self.preferred_url_scheme = preferred_url_scheme
self.colors = colors
self.http_client = http_client
def get_mail_html(self, name, email_or_id, size=16, kind="user"):
""" Returns the full HTML and CSS for viewing the avatar of the given name and email address,
with an optional size.
"""
data = self.get_data(name, email_or_id, kind)
url = self._get_url(data['hash'], size) if kind != 'team' else None
font_size = size - 6
data = self.get_data(name, email_or_id, kind)
url = self._get_url(data["hash"], size) if kind != "team" else None
font_size = size - 6
if url is not None:
# Try to load the gravatar. If we get a non-404 response, then we use it in place of
# the CSS avatar.
try:
response = self.http_client.get(url, timeout=5)
if response.status_code == 200:
return """<img src="%s" width="%s" height="%s" alt="%s"
style="vertical-align: middle;">""" % (url, size, size, kind)
except RequestException:
logger.exception('Could not retrieve avatar for user %s', name)
if url is not None:
# Try to load the gravatar. If we get a non-404 response, then we use it in place of
# the CSS avatar.
try:
response = self.http_client.get(url, timeout=5)
if response.status_code == 200:
return """<img src="%s" width="%s" height="%s" alt="%s"
style="vertical-align: middle;">""" % (
url,
size,
size,
kind,
)
except RequestException:
logger.exception("Could not retrieve avatar for user %s", name)
radius = '50%' if kind == 'team' else '0%'
letter = '&Omega;' if kind == 'team' and data['name'] == 'owners' else data['name'].upper()[0]
radius = "50%" if kind == "team" else "0%"
letter = (
"&Omega;"
if kind == "team" and data["name"] == "owners"
else data["name"].upper()[0]
)
return """
return """
<span style="width: %spx; height: %spx; background-color: %s; font-size: %spx;
line-height: %spx; margin-left: 2px; margin-right: 2px; display: inline-block;
vertical-align: middle; text-align: center; color: white; border-radius: %s">
%s
</span>
""" % (size, size, data['color'], font_size, size, radius, letter)
""" % (
size,
size,
data["color"],
font_size,
size,
radius,
letter,
)
def get_data_for_user(self, user):
return self.get_data(user.username, user.email, 'robot' if user.robot else 'user')
def get_data_for_user(self, user):
return self.get_data(
user.username, user.email, "robot" if user.robot else "user"
)
def get_data_for_team(self, team):
return self.get_data(team.name, team.name, 'team')
def get_data_for_team(self, team):
return self.get_data(team.name, team.name, "team")
def get_data_for_org(self, org):
return self.get_data(org.username, org.email, 'org')
def get_data_for_org(self, org):
return self.get_data(org.username, org.email, "org")
def get_data_for_external_user(self, external_user):
return self.get_data(external_user.username, external_user.email, 'user')
def get_data_for_external_user(self, external_user):
return self.get_data(external_user.username, external_user.email, "user")
def get_data(self, name, email_or_id, kind='user'):
""" Computes and returns the full data block for the avatar:
def get_data(self, name, email_or_id, kind="user"):
""" Computes and returns the full data block for the avatar:
{
'name': name,
'hash': The gravatar hash, if any.
'color': The color for the avatar
}
"""
colors = self.colors
colors = self.colors
# Note: email_or_id may be None if gotten from external auth when email is disabled,
# so use the username in that case.
username_email_or_id = email_or_id or name
hash_value = hashlib.md5(username_email_or_id.strip().lower()).hexdigest()
# Note: email_or_id may be None if gotten from external auth when email is disabled,
# so use the username in that case.
username_email_or_id = email_or_id or name
hash_value = hashlib.md5(username_email_or_id.strip().lower()).hexdigest()
byte_count = int(math.ceil(math.log(len(colors), 16)))
byte_data = hash_value[0:byte_count]
hash_color = colors[int(byte_data, 16) % len(colors)]
byte_count = int(math.ceil(math.log(len(colors), 16)))
byte_data = hash_value[0:byte_count]
hash_color = colors[int(byte_data, 16) % len(colors)]
return {
'name': name,
'hash': hash_value,
'color': hash_color,
'kind': kind
}
return {"name": name, "hash": hash_value, "color": hash_color, "kind": kind}
def _get_url(self, hash_value, size):
""" Returns the URL for displaying the overlay avatar. """
return None
def _get_url(self, hash_value, size):
""" Returns the URL for displaying the overlay avatar. """
return None
class GravatarAvatar(BaseAvatar):
""" Avatar system that uses gravatar for generating avatars. """
def _get_url(self, hash_value, size=16):
return '%s://www.gravatar.com/avatar/%s?d=404&size=%s' % (self.preferred_url_scheme,
hash_value, size)
""" Avatar system that uses gravatar for generating avatars. """
def _get_url(self, hash_value, size=16):
return "%s://www.gravatar.com/avatar/%s?d=404&size=%s" % (
self.preferred_url_scheme,
hash_value,
size,
)
class LocalAvatar(BaseAvatar):
""" Avatar system that uses the local system for generating avatars. """
pass
""" Avatar system that uses the local system for generating avatars. """
AVATAR_CLASSES = {
'gravatar': GravatarAvatar,
'local': LocalAvatar
}
pass
AVATAR_CLASSES = {"gravatar": GravatarAvatar, "local": LocalAvatar}

boot.py

@ -16,7 +16,7 @@ from data.model.release import set_region_release
from data.model.service_keys import get_service_key
from util.config.database import sync_database_with_config
from util.generatepresharedkey import generate_key
from _init import CONF_DIR
from _init import CONF_DIR
logger = logging.getLogger(__name__)
@ -24,108 +24,117 @@ logger = logging.getLogger(__name__)
@lru_cache(maxsize=1)
def get_audience():
audience = app.config.get('JWTPROXY_AUDIENCE')
audience = app.config.get("JWTPROXY_AUDIENCE")
if audience:
return audience
if audience:
return audience
scheme = app.config.get('PREFERRED_URL_SCHEME')
hostname = app.config.get('SERVER_HOSTNAME')
scheme = app.config.get("PREFERRED_URL_SCHEME")
hostname = app.config.get("SERVER_HOSTNAME")
# hostname includes port, use that
if ':' in hostname:
return urlunparse((scheme, hostname, '', '', '', ''))
# hostname includes port, use that
if ":" in hostname:
return urlunparse((scheme, hostname, "", "", "", ""))
# no port, guess based on scheme
if scheme == 'https':
port = '443'
else:
port = '80'
# no port, guess based on scheme
if scheme == "https":
port = "443"
else:
port = "80"
return urlunparse((scheme, hostname + ':' + port, '', '', '', ''))
return urlunparse((scheme, hostname + ":" + port, "", "", "", ""))
def _verify_service_key():
try:
with open(app.config['INSTANCE_SERVICE_KEY_KID_LOCATION']) as f:
quay_key_id = f.read()
try:
get_service_key(quay_key_id, approved_only=False)
assert os.path.exists(app.config['INSTANCE_SERVICE_KEY_LOCATION'])
return quay_key_id
except ServiceKeyDoesNotExist:
logger.exception('Could not find non-expired existing service key %s; creating a new one',
quay_key_id)
return None
with open(app.config["INSTANCE_SERVICE_KEY_KID_LOCATION"]) as f:
quay_key_id = f.read()
# Found a valid service key, so exiting.
except IOError:
logger.exception('Could not load existing service key; creating a new one')
return None
try:
get_service_key(quay_key_id, approved_only=False)
assert os.path.exists(app.config["INSTANCE_SERVICE_KEY_LOCATION"])
return quay_key_id
except ServiceKeyDoesNotExist:
logger.exception(
"Could not find non-expired existing service key %s; creating a new one",
quay_key_id,
)
return None
# Found a valid service key, so exiting.
except IOError:
logger.exception("Could not load existing service key; creating a new one")
return None
def setup_jwt_proxy():
"""
"""
Creates a service key for quay to use in the jwtproxy and generates the JWT proxy configuration.
"""
if os.path.exists(os.path.join(CONF_DIR, 'jwtproxy_conf.yaml')):
# Proxy is already set up. Make sure the service key is still valid.
quay_key_id = _verify_service_key()
if quay_key_id is not None:
return
if os.path.exists(os.path.join(CONF_DIR, "jwtproxy_conf.yaml")):
# Proxy is already set up. Make sure the service key is still valid.
quay_key_id = _verify_service_key()
if quay_key_id is not None:
return
# Ensure we have an existing key if in read-only mode.
if app.config.get('REGISTRY_STATE', 'normal') == 'readonly':
quay_key_id = _verify_service_key()
if quay_key_id is None:
raise Exception('No valid service key found for read-only registry.')
else:
# Generate the key for this Quay instance to use.
minutes_until_expiration = app.config.get('INSTANCE_SERVICE_KEY_EXPIRATION', 120)
expiration = datetime.now() + timedelta(minutes=minutes_until_expiration)
quay_key, quay_key_id = generate_key(app.config['INSTANCE_SERVICE_KEY_SERVICE'],
get_audience(), expiration_date=expiration)
# Ensure we have an existing key if in read-only mode.
if app.config.get("REGISTRY_STATE", "normal") == "readonly":
quay_key_id = _verify_service_key()
if quay_key_id is None:
raise Exception("No valid service key found for read-only registry.")
else:
# Generate the key for this Quay instance to use.
minutes_until_expiration = app.config.get(
"INSTANCE_SERVICE_KEY_EXPIRATION", 120
)
expiration = datetime.now() + timedelta(minutes=minutes_until_expiration)
quay_key, quay_key_id = generate_key(
app.config["INSTANCE_SERVICE_KEY_SERVICE"],
get_audience(),
expiration_date=expiration,
)
with open(app.config['INSTANCE_SERVICE_KEY_KID_LOCATION'], mode='w') as f:
f.truncate(0)
f.write(quay_key_id)
with open(app.config["INSTANCE_SERVICE_KEY_KID_LOCATION"], mode="w") as f:
f.truncate(0)
f.write(quay_key_id)
with open(app.config['INSTANCE_SERVICE_KEY_LOCATION'], mode='w') as f:
f.truncate(0)
f.write(quay_key.exportKey())
with open(app.config["INSTANCE_SERVICE_KEY_LOCATION"], mode="w") as f:
f.truncate(0)
f.write(quay_key.exportKey())
# Generate the JWT proxy configuration.
audience = get_audience()
registry = audience + '/keys'
security_issuer = app.config.get('SECURITY_SCANNER_ISSUER_NAME', 'security_scanner')
# Generate the JWT proxy configuration.
audience = get_audience()
registry = audience + "/keys"
security_issuer = app.config.get("SECURITY_SCANNER_ISSUER_NAME", "security_scanner")
with open(os.path.join(CONF_DIR, 'jwtproxy_conf.yaml.jnj')) as f:
template = Template(f.read())
rendered = template.render(
conf_dir=CONF_DIR,
audience=audience,
registry=registry,
key_id=quay_key_id,
security_issuer=security_issuer,
service_key_location=app.config['INSTANCE_SERVICE_KEY_LOCATION'],
)
with open(os.path.join(CONF_DIR, "jwtproxy_conf.yaml.jnj")) as f:
template = Template(f.read())
rendered = template.render(
conf_dir=CONF_DIR,
audience=audience,
registry=registry,
key_id=quay_key_id,
security_issuer=security_issuer,
service_key_location=app.config["INSTANCE_SERVICE_KEY_LOCATION"],
)
with open(os.path.join(CONF_DIR, 'jwtproxy_conf.yaml'), 'w') as f:
f.write(rendered)
with open(os.path.join(CONF_DIR, "jwtproxy_conf.yaml"), "w") as f:
f.write(rendered)
def main():
if not app.config.get('SETUP_COMPLETE', False):
raise Exception('Your configuration bundle is either not mounted or setup has not been completed')
if not app.config.get("SETUP_COMPLETE", False):
raise Exception(
"Your configuration bundle is either not mounted or setup has not been completed"
)
sync_database_with_config(app.config)
setup_jwt_proxy()
sync_database_with_config(app.config)
setup_jwt_proxy()
# Record deploy
if release.REGION and release.GIT_HEAD:
set_region_release(release.SERVICE, release.REGION, release.GIT_HEAD)
# Record deploy
if release.REGION and release.GIT_HEAD:
set_region_release(release.SERVICE, release.REGION, release.GIT_HEAD)
if __name__ == '__main__':
main()
if __name__ == "__main__":
main()
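
get_audience in miniature: an explicit port in SERVER_HOSTNAME wins, otherwise the port is inferred from the scheme (a sketch; the Python 2 import path is assumed):

from urlparse import urlunparse  # urllib.parse on Python 3

def audience_for(scheme, hostname):
    if ":" in hostname:
        return urlunparse((scheme, hostname, "", "", "", ""))
    port = "443" if scheme == "https" else "80"
    return urlunparse((scheme, hostname + ":" + port, "", "", "", ""))

assert audience_for("https", "quay.example.com") == "https://quay.example.com:443"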


@ -5,38 +5,39 @@ from trollius import get_event_loop, coroutine
def wrap_with_threadpool(obj, worker_threads=1):
"""
"""
Wraps a class in an async executor so that it can be safely used in an event loop like trollius.
"""
async_executor = ThreadPoolExecutor(worker_threads)
return AsyncWrapper(obj, executor=async_executor), async_executor
async_executor = ThreadPoolExecutor(worker_threads)
return AsyncWrapper(obj, executor=async_executor), async_executor
class AsyncWrapper(object):
""" Wrapper class which will transform a syncronous library to one that can be used with
""" Wrapper class which will transform a syncronous library to one that can be used with
trollius coroutines.
"""
def __init__(self, delegate, loop=None, executor=None):
self._loop = loop if loop is not None else get_event_loop()
self._delegate = delegate
self._executor = executor
def __getattr__(self, attrib):
delegate_attr = getattr(self._delegate, attrib)
def __init__(self, delegate, loop=None, executor=None):
self._loop = loop if loop is not None else get_event_loop()
self._delegate = delegate
self._executor = executor
if not callable(delegate_attr):
return delegate_attr
def __getattr__(self, attrib):
delegate_attr = getattr(self._delegate, attrib)
def wrapper(*args, **kwargs):
""" Wraps the delegate_attr with primitives that will transform sync calls to ones shelled
if not callable(delegate_attr):
return delegate_attr
def wrapper(*args, **kwargs):
""" Wraps the delegate_attr with primitives that will transform sync calls to ones shelled
out to a thread pool.
"""
callable_delegate_attr = partial(delegate_attr, *args, **kwargs)
return self._loop.run_in_executor(self._executor, callable_delegate_attr)
callable_delegate_attr = partial(delegate_attr, *args, **kwargs)
return self._loop.run_in_executor(self._executor, callable_delegate_attr)
return wrapper
return wrapper
@coroutine
def __call__(self, *args, **kwargs):
callable_delegate_attr = partial(self._delegate, *args, **kwargs)
return self._loop.run_in_executor(self._executor, callable_delegate_attr)
@coroutine
def __call__(self, *args, **kwargs):
callable_delegate_attr = partial(self._delegate, *args, **kwargs)
return self._loop.run_in_executor(self._executor, callable_delegate_attr)
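
A usage sketch for AsyncWrapper under trollius; some_blocking_client and its get_thing method are hypothetical stand-ins for any synchronous library:

from trollius import From, Return, coroutine

wrapped, executor = wrap_with_threadpool(some_blocking_client)

@coroutine
def fetch(key):
    result = yield From(wrapped.get_thing(key))  # runs in the thread pool
    raise Return(result)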


@ -18,80 +18,104 @@ from raven.conf import setup_logging
logger = logging.getLogger(__name__)
BUILD_MANAGERS = {
'enterprise': EnterpriseManager,
'ephemeral': EphemeralBuilderManager,
}
BUILD_MANAGERS = {"enterprise": EnterpriseManager, "ephemeral": EphemeralBuilderManager}
EXTERNALLY_MANAGED = 'external'
EXTERNALLY_MANAGED = "external"
DEFAULT_WEBSOCKET_PORT = 8787
DEFAULT_CONTROLLER_PORT = 8686
LOG_FORMAT = "%(asctime)s [%(process)d] [%(levelname)s] [%(name)s] %(message)s"
def run_build_manager():
if not features.BUILD_SUPPORT:
logger.debug('Building is disabled. Please enable the feature flag')
while True:
time.sleep(1000)
return
if not features.BUILD_SUPPORT:
logger.debug("Building is disabled. Please enable the feature flag")
while True:
time.sleep(1000)
return
if app.config.get('REGISTRY_STATE', 'normal') == 'readonly':
logger.debug('Building is disabled while in read-only mode.')
while True:
time.sleep(1000)
return
if app.config.get("REGISTRY_STATE", "normal") == "readonly":
logger.debug("Building is disabled while in read-only mode.")
while True:
time.sleep(1000)
return
build_manager_config = app.config.get('BUILD_MANAGER')
if build_manager_config is None:
return
build_manager_config = app.config.get("BUILD_MANAGER")
if build_manager_config is None:
return
# If the build system is externally managed, then we just sleep this process.
if build_manager_config[0] == EXTERNALLY_MANAGED:
logger.debug('Builds are externally managed.')
while True:
time.sleep(1000)
return
# If the build system is externally managed, then we just sleep this process.
if build_manager_config[0] == EXTERNALLY_MANAGED:
logger.debug("Builds are externally managed.")
while True:
time.sleep(1000)
return
logger.debug('Asking to start build manager with lifecycle "%s"', build_manager_config[0])
manager_klass = BUILD_MANAGERS.get(build_manager_config[0])
if manager_klass is None:
return
logger.debug(
'Asking to start build manager with lifecycle "%s"', build_manager_config[0]
)
manager_klass = BUILD_MANAGERS.get(build_manager_config[0])
if manager_klass is None:
return
manager_hostname = os.environ.get('BUILDMAN_HOSTNAME',
app.config.get('BUILDMAN_HOSTNAME',
app.config['SERVER_HOSTNAME']))
websocket_port = int(os.environ.get('BUILDMAN_WEBSOCKET_PORT',
app.config.get('BUILDMAN_WEBSOCKET_PORT',
DEFAULT_WEBSOCKET_PORT)))
controller_port = int(os.environ.get('BUILDMAN_CONTROLLER_PORT',
app.config.get('BUILDMAN_CONTROLLER_PORT',
DEFAULT_CONTROLLER_PORT)))
manager_hostname = os.environ.get(
"BUILDMAN_HOSTNAME",
app.config.get("BUILDMAN_HOSTNAME", app.config["SERVER_HOSTNAME"]),
)
websocket_port = int(
os.environ.get(
"BUILDMAN_WEBSOCKET_PORT",
app.config.get("BUILDMAN_WEBSOCKET_PORT", DEFAULT_WEBSOCKET_PORT),
)
)
controller_port = int(
os.environ.get(
"BUILDMAN_CONTROLLER_PORT",
app.config.get("BUILDMAN_CONTROLLER_PORT", DEFAULT_CONTROLLER_PORT),
)
)
logger.debug('Will pass buildman hostname %s to builders for websocket connection',
manager_hostname)
logger.debug(
"Will pass buildman hostname %s to builders for websocket connection",
manager_hostname,
)
logger.debug('Starting build manager with lifecycle "%s"', build_manager_config[0])
ssl_context = None
if os.environ.get('SSL_CONFIG'):
logger.debug('Loading SSL cert and key')
ssl_context = SSLContext()
ssl_context.load_cert_chain(os.path.join(os.environ.get('SSL_CONFIG'), 'ssl.cert'),
os.path.join(os.environ.get('SSL_CONFIG'), 'ssl.key'))
logger.debug('Starting build manager with lifecycle "%s"', build_manager_config[0])
ssl_context = None
if os.environ.get("SSL_CONFIG"):
logger.debug("Loading SSL cert and key")
ssl_context = SSLContext()
ssl_context.load_cert_chain(
os.path.join(os.environ.get("SSL_CONFIG"), "ssl.cert"),
os.path.join(os.environ.get("SSL_CONFIG"), "ssl.key"),
)
server = BuilderServer(app.config['SERVER_HOSTNAME'], dockerfile_build_queue, build_logs,
user_files, manager_klass, build_manager_config[1], manager_hostname)
server.run('0.0.0.0', websocket_port, controller_port, ssl=ssl_context)
server = BuilderServer(
app.config["SERVER_HOSTNAME"],
dockerfile_build_queue,
build_logs,
user_files,
manager_klass,
build_manager_config[1],
manager_hostname,
)
server.run("0.0.0.0", websocket_port, controller_port, ssl=ssl_context)
if __name__ == '__main__':
logging.config.fileConfig(logfile_path(debug=True), disable_existing_loggers=False)
logging.getLogger('peewee').setLevel(logging.WARN)
logging.getLogger('boto').setLevel(logging.WARN)
if app.config.get('EXCEPTION_LOG_TYPE', 'FakeSentry') == 'Sentry':
buildman_name = '%s:buildman' % socket.gethostname()
setup_logging(SentryHandler(app.config.get('SENTRY_DSN', ''), name=buildman_name,
level=logging.ERROR))
if __name__ == "__main__":
logging.config.fileConfig(logfile_path(debug=True), disable_existing_loggers=False)
logging.getLogger("peewee").setLevel(logging.WARN)
logging.getLogger("boto").setLevel(logging.WARN)
run_build_manager()
if app.config.get("EXCEPTION_LOG_TYPE", "FakeSentry") == "Sentry":
buildman_name = "%s:buildman" % socket.gethostname()
setup_logging(
SentryHandler(
app.config.get("SENTRY_DSN", ""),
name=buildman_name,
level=logging.ERROR,
)
)
run_build_manager()
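
For context, a BUILD_MANAGER entry is a (kind, config) pair whose kind selects the manager class above; the config key shown here is hypothetical:

BUILD_MANAGER = ("ephemeral", {"ETCD_HOST": "etcd.example.com"})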


@ -1,13 +1,15 @@
from autobahn.asyncio.wamp import ApplicationSession
class BaseComponent(ApplicationSession):
""" Base class for all registered component sessions in the server. """
def __init__(self, config, **kwargs):
ApplicationSession.__init__(self, config)
self.server = None
self.parent_manager = None
self.build_logs = None
self.user_files = None
def kind(self):
raise NotImplementedError
class BaseComponent(ApplicationSession):
""" Base class for all registered component sessions in the server. """
def __init__(self, config, **kwargs):
ApplicationSession.__init__(self, config)
self.server = None
self.parent_manager = None
self.build_logs = None
self.user_files = None
def kind(self):
raise NotImplementedError

File diff suppressed because it is too large


@ -1,15 +1,16 @@
import re
def extract_current_step(current_status_string):
""" Attempts to extract the current step numeric identifier from the given status string. Returns the step
""" Attempts to extract the current step numeric identifier from the given status string. Returns the step
number or None if none.
"""
# Older format: `Step 12 :`
# Newer format: `Step 4/13 :`
step_increment = re.search(r'Step ([0-9]+)/([0-9]+) :', current_status_string)
if step_increment:
return int(step_increment.group(1))
# Older format: `Step 12 :`
# Newer format: `Step 4/13 :`
step_increment = re.search(r"Step ([0-9]+)/([0-9]+) :", current_status_string)
if step_increment:
return int(step_increment.group(1))
step_increment = re.search(r'Step ([0-9]+) :', current_status_string)
if step_increment:
return int(step_increment.group(1))
step_increment = re.search(r"Step ([0-9]+) :", current_status_string)
if step_increment:
return int(step_increment.group(1))


@ -3,34 +3,62 @@ import pytest
from buildman.component.buildcomponent import BuildComponent
@pytest.mark.parametrize('input,expected_path,expected_file', [
("", "/", "Dockerfile"),
("/", "/", "Dockerfile"),
("/Dockerfile", "/", "Dockerfile"),
("/server.Dockerfile", "/", "server.Dockerfile"),
("/somepath", "/somepath", "Dockerfile"),
("/somepath/", "/somepath", "Dockerfile"),
("/somepath/Dockerfile", "/somepath", "Dockerfile"),
("/somepath/server.Dockerfile", "/somepath", "server.Dockerfile"),
("/somepath/some_other_path", "/somepath/some_other_path", "Dockerfile"),
("/somepath/some_other_path/", "/somepath/some_other_path", "Dockerfile"),
("/somepath/some_other_path/Dockerfile", "/somepath/some_other_path", "Dockerfile"),
("/somepath/some_other_path/server.Dockerfile", "/somepath/some_other_path", "server.Dockerfile"),
])
@pytest.mark.parametrize(
"input,expected_path,expected_file",
[
("", "/", "Dockerfile"),
("/", "/", "Dockerfile"),
("/Dockerfile", "/", "Dockerfile"),
("/server.Dockerfile", "/", "server.Dockerfile"),
("/somepath", "/somepath", "Dockerfile"),
("/somepath/", "/somepath", "Dockerfile"),
("/somepath/Dockerfile", "/somepath", "Dockerfile"),
("/somepath/server.Dockerfile", "/somepath", "server.Dockerfile"),
("/somepath/some_other_path", "/somepath/some_other_path", "Dockerfile"),
("/somepath/some_other_path/", "/somepath/some_other_path", "Dockerfile"),
(
"/somepath/some_other_path/Dockerfile",
"/somepath/some_other_path",
"Dockerfile",
),
(
"/somepath/some_other_path/server.Dockerfile",
"/somepath/some_other_path",
"server.Dockerfile",
),
],
)
def test_path_is_dockerfile(input, expected_path, expected_file):
actual_path, actual_file = BuildComponent.name_and_path(input)
assert actual_path == expected_path
assert actual_file == expected_file
actual_path, actual_file = BuildComponent.name_and_path(input)
assert actual_path == expected_path
assert actual_file == expected_file
@pytest.mark.parametrize('build_config,context,dockerfile_path', [
({}, '', ''),
({'build_subdir': '/builddir/Dockerfile'}, '', '/builddir/Dockerfile'),
({'context': '/builddir'}, '/builddir', ''),
({'context': '/builddir', 'build_subdir': '/builddir/Dockerfile'}, '/builddir', 'Dockerfile'),
({'context': '/some_other_dir/Dockerfile', 'build_subdir': '/builddir/Dockerfile'}, '/builddir', 'Dockerfile'),
({'context': '/', 'build_subdir':'Dockerfile'}, '/', 'Dockerfile')
])
@pytest.mark.parametrize(
"build_config,context,dockerfile_path",
[
({}, "", ""),
({"build_subdir": "/builddir/Dockerfile"}, "", "/builddir/Dockerfile"),
({"context": "/builddir"}, "/builddir", ""),
(
{"context": "/builddir", "build_subdir": "/builddir/Dockerfile"},
"/builddir",
"Dockerfile",
),
(
{
"context": "/some_other_dir/Dockerfile",
"build_subdir": "/builddir/Dockerfile",
},
"/builddir",
"Dockerfile",
),
({"context": "/", "build_subdir": "Dockerfile"}, "/", "Dockerfile"),
],
)
def test_extract_dockerfile_args(build_config, context, dockerfile_path):
actual_context, actual_dockerfile_path = BuildComponent.extract_dockerfile_args(build_config)
assert context == actual_context
assert dockerfile_path == actual_dockerfile_path
actual_context, actual_dockerfile_path = BuildComponent.extract_dockerfile_args(
build_config
)
assert context == actual_context
assert dockerfile_path == actual_dockerfile_path


@ -3,14 +3,17 @@ import pytest
from buildman.component.buildparse import extract_current_step
@pytest.mark.parametrize('input,expected_step', [
("", None),
("Step a :", None),
("Step 1 :", 1),
("Step 1 : ", 1),
("Step 1/2 : ", 1),
("Step 2/17 : ", 2),
("Step 4/13 : ARG somearg=foo", 4),
])
@pytest.mark.parametrize(
"input,expected_step",
[
("", None),
("Step a :", None),
("Step 1 :", 1),
("Step 1 : ", 1),
("Step 1/2 : ", 1),
("Step 2/17 : ", 2),
("Step 4/13 : ARG somearg=foo", 4),
],
)
def test_extract_current_step(input, expected_step):
assert extract_current_step(input) == expected_step
assert extract_current_step(input) == expected_step


@ -1,21 +1,25 @@
from data.database import BUILD_PHASE
class BuildJobResult(object):
""" Build job result enum """
INCOMPLETE = 'incomplete'
COMPLETE = 'complete'
ERROR = 'error'
""" Build job result enum """
INCOMPLETE = "incomplete"
COMPLETE = "complete"
ERROR = "error"
class BuildServerStatus(object):
""" Build server status enum """
STARTING = 'starting'
RUNNING = 'running'
SHUTDOWN = 'shutting_down'
EXCEPTION = 'exception'
""" Build server status enum """
STARTING = "starting"
RUNNING = "running"
SHUTDOWN = "shutting_down"
EXCEPTION = "exception"
RESULT_PHASES = {
BuildJobResult.INCOMPLETE: BUILD_PHASE.INTERNAL_ERROR,
BuildJobResult.COMPLETE: BUILD_PHASE.COMPLETE,
BuildJobResult.ERROR: BUILD_PHASE.ERROR,
BuildJobResult.INCOMPLETE: BUILD_PHASE.INTERNAL_ERROR,
BuildJobResult.COMPLETE: BUILD_PHASE.COMPLETE,
BuildJobResult.ERROR: BUILD_PHASE.ERROR,
}
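
Mapping a finished job to its terminal database phase is then a direct lookup:

final_phase = RESULT_PHASES[BuildJobResult.COMPLETE]  # BUILD_PHASE.COMPLETE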


@ -14,170 +14,196 @@ logger = logging.getLogger(__name__)
class BuildJobLoadException(Exception):
""" Exception raised if a build job could not be instantiated for some reason. """
pass
""" Exception raised if a build job could not be instantiated for some reason. """
pass
class BuildJob(object):
""" Represents a single in-progress build job. """
def __init__(self, job_item):
self.job_item = job_item
""" Represents a single in-progress build job. """
try:
self.job_details = json.loads(job_item.body)
self.build_notifier = BuildJobNotifier(self.build_uuid)
except ValueError:
# json.loads failed, so self.job_details was never set; report the raw body instead.
raise BuildJobLoadException(
'Could not parse build queue item config: %s' % job_item.body
)
def __init__(self, job_item):
self.job_item = job_item
@property
def retries_remaining(self):
return self.job_item.retries_remaining
try:
self.job_details = json.loads(job_item.body)
self.build_notifier = BuildJobNotifier(self.build_uuid)
except ValueError:
# json.loads failed, so self.job_details was never set; report the raw body instead.
raise BuildJobLoadException(
"Could not parse build queue item config: %s" % job_item.body
)
def has_retries_remaining(self):
return self.job_item.retries_remaining > 0
@property
def retries_remaining(self):
return self.job_item.retries_remaining
def send_notification(self, kind, error_message=None, image_id=None, manifest_digests=None):
self.build_notifier.send_notification(kind, error_message, image_id, manifest_digests)
def has_retries_remaining(self):
return self.job_item.retries_remaining > 0
@lru_cache(maxsize=1)
def _load_repo_build(self):
with UseThenDisconnect(app.config):
try:
return model.build.get_repository_build(self.build_uuid)
except model.InvalidRepositoryBuildException:
raise BuildJobLoadException(
'Could not load repository build with ID %s' % self.build_uuid)
def send_notification(
self, kind, error_message=None, image_id=None, manifest_digests=None
):
self.build_notifier.send_notification(
kind, error_message, image_id, manifest_digests
)
@property
def build_uuid(self):
""" Returns the unique UUID for this build job. """
return self.job_details['build_uuid']
@lru_cache(maxsize=1)
def _load_repo_build(self):
with UseThenDisconnect(app.config):
try:
return model.build.get_repository_build(self.build_uuid)
except model.InvalidRepositoryBuildException:
raise BuildJobLoadException(
"Could not load repository build with ID %s" % self.build_uuid
)
@property
def namespace(self):
""" Returns the namespace under which this build is running. """
return self.repo_build.repository.namespace_user.username
@property
def build_uuid(self):
""" Returns the unique UUID for this build job. """
return self.job_details["build_uuid"]
@property
def repo_name(self):
""" Returns the name of the repository under which this build is running. """
return self.repo_build.repository.name
@property
def namespace(self):
""" Returns the namespace under which this build is running. """
return self.repo_build.repository.namespace_user.username
@property
def repo_build(self):
return self._load_repo_build()
@property
def repo_name(self):
""" Returns the name of the repository under which this build is running. """
return self.repo_build.repository.name
def get_build_package_url(self, user_files):
""" Returns the URL of the build package for this build, if any or empty string if none. """
archive_url = self.build_config.get('archive_url', None)
if archive_url:
return archive_url
@property
def repo_build(self):
return self._load_repo_build()
if not self.repo_build.resource_key:
return ''
def get_build_package_url(self, user_files):
""" Returns the URL of the build package for this build, if any or empty string if none. """
archive_url = self.build_config.get("archive_url", None)
if archive_url:
return archive_url
return user_files.get_file_url(self.repo_build.resource_key, '127.0.0.1', requires_cors=False)
if not self.repo_build.resource_key:
return ""
@property
def pull_credentials(self):
""" Returns the pull credentials for this job, or None if none. """
return self.job_details.get('pull_credentials')
return user_files.get_file_url(
self.repo_build.resource_key, "127.0.0.1", requires_cors=False
)
@property
def build_config(self):
try:
return json.loads(self.repo_build.job_config)
except ValueError:
raise BuildJobLoadException(
'Could not parse repository build job config with ID %s' % self.job_details['build_uuid']
)
@property
def pull_credentials(self):
""" Returns the pull credentials for this job, or None if none. """
return self.job_details.get("pull_credentials")
def determine_cached_tag(self, base_image_id=None, cache_comments=None):
""" Returns the tag to pull to prime the cache or None if none. """
cached_tag = self._determine_cached_tag_by_tag()
logger.debug('Determined cached tag %s for %s: %s', cached_tag, base_image_id, cache_comments)
return cached_tag
@property
def build_config(self):
try:
return json.loads(self.repo_build.job_config)
except ValueError:
raise BuildJobLoadException(
"Could not parse repository build job config with ID %s"
% self.job_details["build_uuid"]
)
def _determine_cached_tag_by_tag(self):
""" Determines the cached tag by looking for one of the tags being built, and seeing if it
def determine_cached_tag(self, base_image_id=None, cache_comments=None):
""" Returns the tag to pull to prime the cache or None if none. """
cached_tag = self._determine_cached_tag_by_tag()
logger.debug(
"Determined cached tag %s for %s: %s",
cached_tag,
base_image_id,
cache_comments,
)
return cached_tag
def _determine_cached_tag_by_tag(self):
""" Determines the cached tag by looking for one of the tags being built, and seeing if it
exists in the repository. This is a fallback for when no comment information is available.
"""
with UseThenDisconnect(app.config):
tags = self.build_config.get('docker_tags', ['latest'])
repository = RepositoryReference.for_repo_obj(self.repo_build.repository)
matching_tag = registry_model.find_matching_tag(repository, tags)
if matching_tag is not None:
return matching_tag.name
with UseThenDisconnect(app.config):
tags = self.build_config.get("docker_tags", ["latest"])
repository = RepositoryReference.for_repo_obj(self.repo_build.repository)
matching_tag = registry_model.find_matching_tag(repository, tags)
if matching_tag is not None:
return matching_tag.name
most_recent_tag = registry_model.get_most_recent_tag(repository)
if most_recent_tag is not None:
return most_recent_tag.name
most_recent_tag = registry_model.get_most_recent_tag(repository)
if most_recent_tag is not None:
return most_recent_tag.name
return None
return None
class BuildJobNotifier(object):
""" A class for sending notifications to a job that only relies on the build_uuid """
""" A class for sending notifications to a job that only relies on the build_uuid """
def __init__(self, build_uuid):
self.build_uuid = build_uuid
def __init__(self, build_uuid):
self.build_uuid = build_uuid
@property
def repo_build(self):
return self._load_repo_build()
@property
def repo_build(self):
return self._load_repo_build()
@lru_cache(maxsize=1)
def _load_repo_build(self):
try:
return model.build.get_repository_build(self.build_uuid)
except model.InvalidRepositoryBuildException:
raise BuildJobLoadException(
'Could not load repository build with ID %s' % self.build_uuid)
@lru_cache(maxsize=1)
def _load_repo_build(self):
try:
return model.build.get_repository_build(self.build_uuid)
except model.InvalidRepositoryBuildException:
raise BuildJobLoadException(
"Could not load repository build with ID %s" % self.build_uuid
)
@property
def build_config(self):
try:
return json.loads(self.repo_build.job_config)
except ValueError:
raise BuildJobLoadException(
'Could not parse repository build job config with ID %s' % self.repo_build.uuid
)
@property
def build_config(self):
try:
return json.loads(self.repo_build.job_config)
except ValueError:
raise BuildJobLoadException(
"Could not parse repository build job config with ID %s"
% self.repo_build.uuid
)
def send_notification(self, kind, error_message=None, image_id=None, manifest_digests=None):
with UseThenDisconnect(app.config):
tags = self.build_config.get('docker_tags', ['latest'])
trigger = self.repo_build.trigger
if trigger is not None and trigger.id is not None:
trigger_kind = trigger.service.name
else:
trigger_kind = None
def send_notification(
self, kind, error_message=None, image_id=None, manifest_digests=None
):
with UseThenDisconnect(app.config):
tags = self.build_config.get("docker_tags", ["latest"])
trigger = self.repo_build.trigger
if trigger is not None and trigger.id is not None:
trigger_kind = trigger.service.name
else:
trigger_kind = None
event_data = {
'build_id': self.repo_build.uuid,
'build_name': self.repo_build.display_name,
'docker_tags': tags,
'trigger_id': trigger.uuid if trigger is not None else None,
'trigger_kind': trigger_kind,
'trigger_metadata': self.build_config.get('trigger_metadata', {})
}
event_data = {
"build_id": self.repo_build.uuid,
"build_name": self.repo_build.display_name,
"docker_tags": tags,
"trigger_id": trigger.uuid if trigger is not None else None,
"trigger_kind": trigger_kind,
"trigger_metadata": self.build_config.get("trigger_metadata", {}),
}
if image_id is not None:
event_data['image_id'] = image_id
if image_id is not None:
event_data["image_id"] = image_id
if manifest_digests:
event_data['manifest_digests'] = manifest_digests
if manifest_digests:
event_data["manifest_digests"] = manifest_digests
if error_message is not None:
event_data['error_message'] = error_message
if error_message is not None:
event_data["error_message"] = error_message
# TODO: remove when more endpoints have been converted to using
# interfaces
repo = AttrDict({
'namespace_name': self.repo_build.repository.namespace_user.username,
'name': self.repo_build.repository.name,
})
spawn_notification(repo, kind, event_data,
subpage='build/%s' % self.repo_build.uuid,
pathargs=['build', self.repo_build.uuid])
# TODO: remove when more endpoints have been converted to using
# interfaces
repo = AttrDict(
{
"namespace_name": self.repo_build.repository.namespace_user.username,
"name": self.repo_build.repository.name,
}
)
spawn_notification(
repo,
kind,
event_data,
subpage="build/%s" % self.repo_build.uuid,
pathargs=["build", self.repo_build.uuid],
)
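
A usage sketch for the notifier; the notification kind and digest are illustrative values:

notifier = BuildJobNotifier(build_uuid)  # build_uuid from a queued job
notifier.send_notification("build_success", manifest_digests=["sha256:abc123"])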

View file

@ -13,76 +13,94 @@ logger = logging.getLogger(__name__)
class StatusHandler(object):
""" Context wrapper for writing status to build logs. """
""" Context wrapper for writing status to build logs. """
def __init__(self, build_logs, repository_build_uuid):
self._current_phase = None
self._current_command = None
self._uuid = repository_build_uuid
self._build_logs = AsyncWrapper(build_logs)
self._sync_build_logs = build_logs
self._build_model = AsyncWrapper(model.build)
def __init__(self, build_logs, repository_build_uuid):
self._current_phase = None
self._current_command = None
self._uuid = repository_build_uuid
self._build_logs = AsyncWrapper(build_logs)
self._sync_build_logs = build_logs
self._build_model = AsyncWrapper(model.build)
self._status = {
'total_commands': 0,
'current_command': None,
'push_completion': 0.0,
'pull_completion': 0.0,
}
self._status = {
"total_commands": 0,
"current_command": None,
"push_completion": 0.0,
"pull_completion": 0.0,
}
# Write the initial status.
self.__exit__(None, None, None)
# Write the initial status.
self.__exit__(None, None, None)
@coroutine
def _append_log_message(self, log_message, log_type=None, log_data=None):
log_data = log_data or {}
log_data['datetime'] = str(datetime.datetime.now())
@coroutine
def _append_log_message(self, log_message, log_type=None, log_data=None):
log_data = log_data or {}
log_data["datetime"] = str(datetime.datetime.now())
try:
yield From(self._build_logs.append_log_message(self._uuid, log_message, log_type, log_data))
except RedisError:
logger.exception('Could not save build log for build %s: %s', self._uuid, log_message)
try:
yield From(
self._build_logs.append_log_message(
self._uuid, log_message, log_type, log_data
)
)
except RedisError:
logger.exception(
"Could not save build log for build %s: %s", self._uuid, log_message
)
@coroutine
def append_log(self, log_message, extra_data=None):
if log_message is None:
return
@coroutine
def append_log(self, log_message, extra_data=None):
if log_message is None:
return
yield From(self._append_log_message(log_message, log_data=extra_data))
yield From(self._append_log_message(log_message, log_data=extra_data))
@coroutine
def set_command(self, command, extra_data=None):
if self._current_command == command:
raise Return()
@coroutine
def set_command(self, command, extra_data=None):
if self._current_command == command:
raise Return()
self._current_command = command
yield From(self._append_log_message(command, self._build_logs.COMMAND, extra_data))
self._current_command = command
yield From(
self._append_log_message(command, self._build_logs.COMMAND, extra_data)
)
@coroutine
def set_error(self, error_message, extra_data=None, internal_error=False, requeued=False):
error_phase = BUILD_PHASE.INTERNAL_ERROR if internal_error and requeued else BUILD_PHASE.ERROR
yield From(self.set_phase(error_phase))
@coroutine
def set_error(
self, error_message, extra_data=None, internal_error=False, requeued=False
):
error_phase = (
BUILD_PHASE.INTERNAL_ERROR
if internal_error and requeued
else BUILD_PHASE.ERROR
)
yield From(self.set_phase(error_phase))
extra_data = extra_data or {}
extra_data['internal_error'] = internal_error
yield From(self._append_log_message(error_message, self._build_logs.ERROR, extra_data))
extra_data = extra_data or {}
extra_data["internal_error"] = internal_error
yield From(
self._append_log_message(error_message, self._build_logs.ERROR, extra_data)
)
@coroutine
def set_phase(self, phase, extra_data=None):
if phase == self._current_phase:
raise Return(False)
@coroutine
def set_phase(self, phase, extra_data=None):
if phase == self._current_phase:
raise Return(False)
self._current_phase = phase
yield From(self._append_log_message(phase, self._build_logs.PHASE, extra_data))
self._current_phase = phase
yield From(self._append_log_message(phase, self._build_logs.PHASE, extra_data))
# Update the repository build with the new phase
raise Return(self._build_model.update_phase_then_close(self._uuid, phase))
# Update the repository build with the new phase
raise Return(self._build_model.update_phase_then_close(self._uuid, phase))
def __enter__(self):
return self._status
def __enter__(self):
return self._status
def __exit__(self, exc_type, value, traceback):
try:
self._sync_build_logs.set_status(self._uuid, self._status)
except RedisError:
logger.exception('Could not set status of build %s to %s', self._uuid, self._status)
def __exit__(self, exc_type, value, traceback):
try:
self._sync_build_logs.set_status(self._uuid, self._status)
except RedisError:
logger.exception(
"Could not set status of build %s to %s", self._uuid, self._status
)
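
StatusHandler doubles as a context manager over the mutable status dict, flushing to the log store on exit; in sketch form, with build_logs and build_job supplied by the surrounding worker:

status_handler = StatusHandler(build_logs, build_job.build_uuid)
with status_handler as status:
    status["total_commands"] = 13  # persisted via set_status on __exit__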


@ -1,119 +1,99 @@
class WorkerError(object):
""" Helper class which represents errors raised by a build worker. """
def __init__(self, error_code, base_message=None):
self._error_code = error_code
self._base_message = base_message
""" Helper class which represents errors raised by a build worker. """
self._error_handlers = {
'io.quay.builder.buildpackissue': {
  def __init__(self, error_code, base_message=None):
    self._error_code = error_code
    self._base_message = base_message

    self._error_handlers = {
      'io.quay.builder.buildpackissue': {
        'message': 'Could not load build package',
        'is_internal': True,
      },
      'io.quay.builder.gitfailure': {
        'message': 'Could not clone git repository',
        'show_base_error': True,
      },
      'io.quay.builder.gitcheckout': {
        'message': 'Could not checkout git ref. If you force pushed recently, ' +
                   'the commit may be missing.',
        'show_base_error': True,
      },
      'io.quay.builder.cannotextractbuildpack': {
        'message': 'Could not extract the contents of the build package'
      },
      'io.quay.builder.cannotpullforcache': {
        'message': 'Could not pull cached image',
        'is_internal': True
      },
      'io.quay.builder.dockerfileissue': {
        'message': 'Could not find or parse Dockerfile',
        'show_base_error': True
      },
      'io.quay.builder.cannotpullbaseimage': {
        'message': 'Could not pull base image',
        'show_base_error': True
      },
      'io.quay.builder.internalerror': {
        'message': 'An internal error occurred while building. Please submit a ticket.',
        'is_internal': True
      },
      'io.quay.builder.buildrunerror': {
        'message': 'Could not start the build process',
        'is_internal': True
      },
      'io.quay.builder.builderror': {
        'message': 'A build step failed',
        'show_base_error': True
      },
      'io.quay.builder.tagissue': {
        'message': 'Could not tag built image',
        'is_internal': True
      },
      'io.quay.builder.pushissue': {
        'message': 'Could not push built image',
        'show_base_error': True,
        'is_internal': True
      },
      'io.quay.builder.dockerconnecterror': {
        'message': 'Could not connect to Docker daemon',
        'is_internal': True
      },
      'io.quay.builder.missingorinvalidargument': {
        'message': 'Missing required arguments for builder',
        'is_internal': True
      },
      'io.quay.builder.cachelookupissue': {
        'message': 'Error checking for a cached tag',
        'is_internal': True
      },
      'io.quay.builder.errorduringphasetransition': {
        'message': 'Error during phase transition. If this problem persists ' +
                   'please contact customer support.',
        'is_internal': True
      },
      'io.quay.builder.clientrejectedtransition': {
        'message': 'Build can not be finished due to user cancellation.',
      }
    }

  def is_internal_error(self):
    handler = self._error_handlers.get(self._error_code)
    return handler.get('is_internal', False) if handler else True

  def public_message(self):
    handler = self._error_handlers.get(self._error_code)
    if not handler:
      return 'An unknown error occurred'

    message = handler['message']
    if handler.get('show_base_error', False) and self._base_message:
      message = message + ': ' + self._base_message

    return message

  def extra_data(self):
    if self._base_message:
      return {
        'base_error': self._base_message,
        'error_code': self._error_code
      }

    return {
      'error_code': self._error_code
    }
    def __init__(self, error_code, base_message=None):
        self._error_code = error_code
        self._base_message = base_message

        self._error_handlers = {
            "io.quay.builder.buildpackissue": {
                "message": "Could not load build package",
                "is_internal": True,
            },
            "io.quay.builder.gitfailure": {
                "message": "Could not clone git repository",
                "show_base_error": True,
            },
            "io.quay.builder.gitcheckout": {
                "message": "Could not checkout git ref. If you force pushed recently, "
                + "the commit may be missing.",
                "show_base_error": True,
            },
            "io.quay.builder.cannotextractbuildpack": {
                "message": "Could not extract the contents of the build package"
            },
            "io.quay.builder.cannotpullforcache": {
                "message": "Could not pull cached image",
                "is_internal": True,
            },
            "io.quay.builder.dockerfileissue": {
                "message": "Could not find or parse Dockerfile",
                "show_base_error": True,
            },
            "io.quay.builder.cannotpullbaseimage": {
                "message": "Could not pull base image",
                "show_base_error": True,
            },
            "io.quay.builder.internalerror": {
                "message": "An internal error occurred while building. Please submit a ticket.",
                "is_internal": True,
            },
            "io.quay.builder.buildrunerror": {
                "message": "Could not start the build process",
                "is_internal": True,
            },
            "io.quay.builder.builderror": {
                "message": "A build step failed",
                "show_base_error": True,
            },
            "io.quay.builder.tagissue": {
                "message": "Could not tag built image",
                "is_internal": True,
            },
            "io.quay.builder.pushissue": {
                "message": "Could not push built image",
                "show_base_error": True,
                "is_internal": True,
            },
            "io.quay.builder.dockerconnecterror": {
                "message": "Could not connect to Docker daemon",
                "is_internal": True,
            },
            "io.quay.builder.missingorinvalidargument": {
                "message": "Missing required arguments for builder",
                "is_internal": True,
            },
            "io.quay.builder.cachelookupissue": {
                "message": "Error checking for a cached tag",
                "is_internal": True,
            },
            "io.quay.builder.errorduringphasetransition": {
                "message": "Error during phase transition. If this problem persists "
                + "please contact customer support.",
                "is_internal": True,
            },
            "io.quay.builder.clientrejectedtransition": {
                "message": "Build can not be finished due to user cancellation."
            },
        }

    def is_internal_error(self):
        handler = self._error_handlers.get(self._error_code)
        return handler.get("is_internal", False) if handler else True

    def public_message(self):
        handler = self._error_handlers.get(self._error_code)
        if not handler:
            return "An unknown error occurred"

        message = handler["message"]
        if handler.get("show_base_error", False) and self._base_message:
            message = message + ": " + self._base_message

        return message

    def extra_data(self):
        if self._base_message:
            return {"base_error": self._base_message, "error_code": self._error_code}

        return {"error_code": self._error_code}

View file

@@ -1,71 +1,80 @@
from trollius import coroutine
class BaseManager(object):
""" Base for all worker managers. """
def __init__(self, register_component, unregister_component, job_heartbeat_callback,
job_complete_callback, manager_hostname, heartbeat_period_sec):
self.register_component = register_component
self.unregister_component = unregister_component
self.job_heartbeat_callback = job_heartbeat_callback
self.job_complete_callback = job_complete_callback
self.manager_hostname = manager_hostname
self.heartbeat_period_sec = heartbeat_period_sec
""" Base for all worker managers. """
@coroutine
def job_heartbeat(self, build_job):
""" Method invoked to tell the manager that a job is still running. This method will be called
def __init__(
self,
register_component,
unregister_component,
job_heartbeat_callback,
job_complete_callback,
manager_hostname,
heartbeat_period_sec,
):
self.register_component = register_component
self.unregister_component = unregister_component
self.job_heartbeat_callback = job_heartbeat_callback
self.job_complete_callback = job_complete_callback
self.manager_hostname = manager_hostname
self.heartbeat_period_sec = heartbeat_period_sec
@coroutine
def job_heartbeat(self, build_job):
""" Method invoked to tell the manager that a job is still running. This method will be called
every few minutes. """
self.job_heartbeat_callback(build_job)
self.job_heartbeat_callback(build_job)
def overall_setup_time(self):
""" Returns the number of seconds that the build system should wait before allowing the job
def overall_setup_time(self):
""" Returns the number of seconds that the build system should wait before allowing the job
to be picked up again after called 'schedule'.
"""
raise NotImplementedError
raise NotImplementedError
def shutdown(self):
""" Indicates that the build controller server is in a shutdown state and that no new jobs
def shutdown(self):
""" Indicates that the build controller server is in a shutdown state and that no new jobs
or workers should be performed. Existing workers should be cleaned up once their jobs
have completed
"""
raise NotImplementedError
raise NotImplementedError
@coroutine
def schedule(self, build_job):
""" Schedules a queue item to be built. Returns a 2-tuple with (True, None) if the item was
@coroutine
def schedule(self, build_job):
""" Schedules a queue item to be built. Returns a 2-tuple with (True, None) if the item was
properly scheduled and (False, a retry timeout in seconds) if all workers are busy or an
error occurs.
"""
raise NotImplementedError
raise NotImplementedError
def initialize(self, manager_config):
""" Runs any initialization code for the manager. Called once the server is in a ready state.
def initialize(self, manager_config):
""" Runs any initialization code for the manager. Called once the server is in a ready state.
"""
raise NotImplementedError
raise NotImplementedError
@coroutine
def build_component_ready(self, build_component):
""" Method invoked whenever a build component announces itself as ready.
@coroutine
def build_component_ready(self, build_component):
""" Method invoked whenever a build component announces itself as ready.
"""
raise NotImplementedError
raise NotImplementedError
def build_component_disposed(self, build_component, timed_out):
""" Method invoked whenever a build component has been disposed. The timed_out boolean indicates
def build_component_disposed(self, build_component, timed_out):
""" Method invoked whenever a build component has been disposed. The timed_out boolean indicates
whether the component's heartbeat timed out.
"""
raise NotImplementedError
raise NotImplementedError
@coroutine
def job_completed(self, build_job, job_status, build_component):
""" Method invoked once a job_item has completed, in some manner. The job_status will be
@coroutine
def job_completed(self, build_job, job_status, build_component):
""" Method invoked once a job_item has completed, in some manner. The job_status will be
one of: incomplete, error, complete. Implementations of this method should call coroutine
self.job_complete_callback with a status of Incomplete if they wish for the job to be
automatically requeued.
"""
raise NotImplementedError
raise NotImplementedError
def num_workers(self):
""" Returns the number of active build workers currently registered. This includes those
def num_workers(self):
""" Returns the number of active build workers currently registered. This includes those
that are currently busy and awaiting more work.
"""
raise NotImplementedError
raise NotImplementedError
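For orientation, a hedged sketch of the smallest manager that could satisfy this interface; the class and its bookkeeping are hypothetical, but the schedule() contract (a (scheduled, retry_timeout) 2-tuple via trollius Return) follows the docstrings above:

from trollius import From, Return, coroutine


class InMemoryManager(BaseManager):
    """ Hypothetical manager keeping one in-memory pool of ready components. """

    def initialize(self, manager_config):
        self._ready = set()

    def overall_setup_time(self):
        # Workers are assumed to register quickly, so keep the window short.
        return 60

    @coroutine
    def schedule(self, build_job):
        if not self._ready:
            raise Return(False, 30)  # All busy: ask the queue to retry in 30s.
        component = self._ready.pop()
        yield From(component.start_build(build_job))
        raise Return(True, None)

    @coroutine
    def build_component_ready(self, build_component):
        self._ready.add(build_component)

    def num_workers(self):
        return len(self._ready)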

View file

@@ -5,23 +5,23 @@ from buildman.manager.noop_canceller import NoopCanceller
logger = logging.getLogger(__name__)
CANCELLERS = {'ephemeral': OrchestratorCanceller}
CANCELLERS = {"ephemeral": OrchestratorCanceller}
class BuildCanceller(object):
""" A class to manage cancelling a build """
""" A class to manage cancelling a build """
def __init__(self, app=None):
self.build_manager_config = app.config.get('BUILD_MANAGER')
if app is None or self.build_manager_config is None:
self.handler = NoopCanceller()
else:
self.handler = None
def __init__(self, app=None):
self.build_manager_config = app.config.get("BUILD_MANAGER")
if app is None or self.build_manager_config is None:
self.handler = NoopCanceller()
else:
self.handler = None
def try_cancel_build(self, uuid):
""" A method to kill a running build """
if self.handler is None:
canceller = CANCELLERS.get(self.build_manager_config[0], NoopCanceller)
self.handler = canceller(self.build_manager_config[1])
def try_cancel_build(self, uuid):
""" A method to kill a running build """
if self.handler is None:
canceller = CANCELLERS.get(self.build_manager_config[0], NoopCanceller)
self.handler = canceller(self.build_manager_config[1])
return self.handler.try_cancel_build(uuid)
return self.handler.try_cancel_build(uuid)
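The BUILD_MANAGER value read above is a two-element sequence: the canceller name, then its configuration. A sketch of the dispatch, with an invented configuration value:

# Hypothetical configuration; the inner dict's contents are placeholders.
app.config["BUILD_MANAGER"] = ("ephemeral", {})

canceller = BuildCanceller(app)
# The first try_cancel_build() lazily maps "ephemeral" to OrchestratorCanceller
# through CANCELLERS; any unrecognized name falls back to NoopCanceller.
canceller.try_cancel_build("some-build-uuid")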

View file

@@ -7,86 +7,89 @@ from buildman.manager.basemanager import BaseManager
from trollius import From, Return, coroutine
REGISTRATION_REALM = 'registration'
REGISTRATION_REALM = "registration"
RETRY_TIMEOUT = 5
logger = logging.getLogger(__name__)
class DynamicRegistrationComponent(BaseComponent):
""" Component session that handles dynamic registration of the builder components. """
""" Component session that handles dynamic registration of the builder components. """
def onConnect(self):
self.join(REGISTRATION_REALM)
def onConnect(self):
self.join(REGISTRATION_REALM)
def onJoin(self, details):
logger.debug('Registering registration method')
yield From(self.register(self._worker_register, u'io.quay.buildworker.register'))
def onJoin(self, details):
logger.debug("Registering registration method")
yield From(
self.register(self._worker_register, u"io.quay.buildworker.register")
)
def _worker_register(self):
realm = self.parent_manager.add_build_component()
logger.debug('Registering new build component+worker with realm %s', realm)
return realm
def _worker_register(self):
realm = self.parent_manager.add_build_component()
logger.debug("Registering new build component+worker with realm %s", realm)
return realm
def kind(self):
return 'registration'
def kind(self):
return "registration"
class EnterpriseManager(BaseManager):
""" Build manager implementation for the Enterprise Registry. """
""" Build manager implementation for the Enterprise Registry. """
def __init__(self, *args, **kwargs):
self.ready_components = set()
self.all_components = set()
self.shutting_down = False
def __init__(self, *args, **kwargs):
self.ready_components = set()
self.all_components = set()
self.shutting_down = False
super(EnterpriseManager, self).__init__(*args, **kwargs)
super(EnterpriseManager, self).__init__(*args, **kwargs)
def initialize(self, manager_config):
# Add a component which is used by build workers for dynamic registration. Unlike
# production, build workers in enterprise are long-lived and register dynamically.
self.register_component(REGISTRATION_REALM, DynamicRegistrationComponent)
def initialize(self, manager_config):
# Add a component which is used by build workers for dynamic registration. Unlike
# production, build workers in enterprise are long-lived and register dynamically.
self.register_component(REGISTRATION_REALM, DynamicRegistrationComponent)
def overall_setup_time(self):
# Builders are already registered, so the setup time should be essentially instant. We therefore
# only return a minute here.
return 60
def overall_setup_time(self):
# Builders are already registered, so the setup time should be essentially instant. We therefore
# only return a minute here.
return 60
def add_build_component(self):
""" Adds a new build component for an Enterprise Registry. """
# Generate a new unique realm ID for the build worker.
realm = str(uuid.uuid4())
new_component = self.register_component(realm, BuildComponent, token="")
self.all_components.add(new_component)
return realm
def add_build_component(self):
""" Adds a new build component for an Enterprise Registry. """
# Generate a new unique realm ID for the build worker.
realm = str(uuid.uuid4())
new_component = self.register_component(realm, BuildComponent, token="")
self.all_components.add(new_component)
return realm
@coroutine
def schedule(self, build_job):
""" Schedules a build for an Enterprise Registry. """
if self.shutting_down or not self.ready_components:
raise Return(False, RETRY_TIMEOUT)
@coroutine
def schedule(self, build_job):
""" Schedules a build for an Enterprise Registry. """
if self.shutting_down or not self.ready_components:
raise Return(False, RETRY_TIMEOUT)
component = self.ready_components.pop()
component = self.ready_components.pop()
yield From(component.start_build(build_job))
yield From(component.start_build(build_job))
raise Return(True, None)
raise Return(True, None)
@coroutine
def build_component_ready(self, build_component):
self.ready_components.add(build_component)
@coroutine
def build_component_ready(self, build_component):
self.ready_components.add(build_component)
def shutdown(self):
self.shutting_down = True
def shutdown(self):
self.shutting_down = True
@coroutine
def job_completed(self, build_job, job_status, build_component):
yield From(self.job_complete_callback(build_job, job_status))
@coroutine
def job_completed(self, build_job, job_status, build_component):
yield From(self.job_complete_callback(build_job, job_status))
def build_component_disposed(self, build_component, timed_out):
self.all_components.remove(build_component)
if build_component in self.ready_components:
self.ready_components.remove(build_component)
def build_component_disposed(self, build_component, timed_out):
self.all_components.remove(build_component)
if build_component in self.ready_components:
self.ready_components.remove(build_component)
self.unregister_component(build_component)
self.unregister_component(build_component)
def num_workers(self):
return len(self.all_components)
def num_workers(self):
return len(self.all_components)
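The registration handshake implemented above, sketched end to end (worker-side steps are hypothetical; the WAMP plumbing is elided):

# 1. A long-lived worker joins the fixed 'registration' realm.
# 2. It calls io.quay.buildworker.register; server-side, _worker_register()
#    invokes add_build_component(), which mints a uuid4 realm and registers
#    a BuildComponent to serve it.
# 3. The worker rejoins on the returned realm; build_component_ready() then
#    places the component in ready_components for schedule() to pop.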

View file

@@ -5,33 +5,36 @@ logger = logging.getLogger(__name__)
class EtcdCanceller(object):
""" A class that sends a message to etcd to cancel a build """
""" A class that sends a message to etcd to cancel a build """
def __init__(self, config):
etcd_host = config.get('ETCD_HOST', '127.0.0.1')
etcd_port = config.get('ETCD_PORT', 2379)
etcd_ca_cert = config.get('ETCD_CA_CERT', None)
etcd_auth = config.get('ETCD_CERT_AND_KEY', None)
if etcd_auth is not None:
etcd_auth = tuple(etcd_auth)
def __init__(self, config):
etcd_host = config.get("ETCD_HOST", "127.0.0.1")
etcd_port = config.get("ETCD_PORT", 2379)
etcd_ca_cert = config.get("ETCD_CA_CERT", None)
etcd_auth = config.get("ETCD_CERT_AND_KEY", None)
if etcd_auth is not None:
etcd_auth = tuple(etcd_auth)
etcd_protocol = 'http' if etcd_auth is None else 'https'
logger.debug('Connecting to etcd on %s:%s', etcd_host, etcd_port)
self._cancel_prefix = config.get('ETCD_CANCEL_PREFIX', 'cancel/')
self._etcd_client = etcd.Client(
host=etcd_host,
port=etcd_port,
cert=etcd_auth,
ca_cert=etcd_ca_cert,
protocol=etcd_protocol,
read_timeout=5)
etcd_protocol = "http" if etcd_auth is None else "https"
logger.debug("Connecting to etcd on %s:%s", etcd_host, etcd_port)
self._cancel_prefix = config.get("ETCD_CANCEL_PREFIX", "cancel/")
self._etcd_client = etcd.Client(
host=etcd_host,
port=etcd_port,
cert=etcd_auth,
ca_cert=etcd_ca_cert,
protocol=etcd_protocol,
read_timeout=5,
)
def try_cancel_build(self, build_uuid):
""" Writes etcd message to cancel build_uuid. """
logger.info("Cancelling build %s".format(build_uuid))
try:
self._etcd_client.write("{}{}".format(self._cancel_prefix, build_uuid), build_uuid, ttl=60)
return True
except etcd.EtcdException:
logger.exception("Failed to write to etcd client %s", build_uuid)
return False
def try_cancel_build(self, build_uuid):
""" Writes etcd message to cancel build_uuid. """
logger.info("Cancelling build %s".format(build_uuid))
try:
self._etcd_client.write(
"{}{}".format(self._cancel_prefix, build_uuid), build_uuid, ttl=60
)
return True
except etcd.EtcdException:
logger.exception("Failed to write to etcd client %s", build_uuid)
return False
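The net effect of try_cancel_build, restated: it writes a short-lived marker key under the cancel prefix, which watching workers can act on. With the default prefix and an illustrative UUID, the write above amounts to:

# client is the etcd.Client constructed in __init__.
client.write("cancel/abc-123", "abc-123", ttl=60)  # expires after 60 seconds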

File diff suppressed because it is too large

View file

@@ -1,8 +1,9 @@
class NoopCanceller(object):
""" A class that can not cancel a build """
def __init__(self, config=None):
pass
""" A class that can not cancel a build """
def try_cancel_build(self, uuid):
""" Does nothing and fails to cancel build. """
return False
def __init__(self, config=None):
pass
def try_cancel_build(self, uuid):
""" Does nothing and fails to cancel build. """
return False

View file

@@ -7,20 +7,23 @@ from util import slash_join
logger = logging.getLogger(__name__)
CANCEL_PREFIX = 'cancel/'
CANCEL_PREFIX = "cancel/"
class OrchestratorCanceller(object):
""" An asynchronous way to cancel a build with any Orchestrator. """
def __init__(self, config):
self._orchestrator = orchestrator_from_config(config, canceller_only=True)
""" An asynchronous way to cancel a build with any Orchestrator. """
def try_cancel_build(self, build_uuid):
logger.info('Cancelling build %s', build_uuid)
cancel_key = slash_join(CANCEL_PREFIX, build_uuid)
try:
self._orchestrator.set_key_sync(cancel_key, build_uuid, expiration=60)
return True
except OrchestratorError:
logger.exception('Failed to write cancel action to redis with uuid %s', build_uuid)
return False
def __init__(self, config):
self._orchestrator = orchestrator_from_config(config, canceller_only=True)
def try_cancel_build(self, build_uuid):
logger.info("Cancelling build %s", build_uuid)
cancel_key = slash_join(CANCEL_PREFIX, build_uuid)
try:
self._orchestrator.set_key_sync(cancel_key, build_uuid, expiration=60)
return True
except OrchestratorError:
logger.exception(
"Failed to write cancel action to redis with uuid %s", build_uuid
)
return False

File diff suppressed because it is too large

View file

@@ -2,4 +2,3 @@ import buildtrigger.bitbuckethandler
import buildtrigger.customhandler
import buildtrigger.githubhandler
import buildtrigger.gitlabhandler

View file

@@ -9,359 +9,360 @@ from data import model
from buildtrigger.triggerutil import get_trigger_config, InvalidServiceException
NAMESPACES_SCHEMA = {
'type': 'array',
'items': {
'type': 'object',
'properties': {
'personal': {
'type': 'boolean',
'description': 'True if the namespace is the user\'s personal namespace',
},
'score': {
'type': 'number',
'description': 'Score of the relevance of the namespace',
},
'avatar_url': {
'type': ['string', 'null'],
'description': 'URL of the avatar for this namespace',
},
'url': {
'type': 'string',
'description': 'URL of the website to view the namespace',
},
'id': {
'type': 'string',
'description': 'Trigger-internal ID of the namespace',
},
'title': {
'type': 'string',
'description': 'Human-readable title of the namespace',
},
"type": "array",
"items": {
"type": "object",
"properties": {
"personal": {
"type": "boolean",
"description": "True if the namespace is the user's personal namespace",
},
"score": {
"type": "number",
"description": "Score of the relevance of the namespace",
},
"avatar_url": {
"type": ["string", "null"],
"description": "URL of the avatar for this namespace",
},
"url": {
"type": "string",
"description": "URL of the website to view the namespace",
},
"id": {
"type": "string",
"description": "Trigger-internal ID of the namespace",
},
"title": {
"type": "string",
"description": "Human-readable title of the namespace",
},
},
"required": ["personal", "score", "avatar_url", "id", "title"],
},
'required': ['personal', 'score', 'avatar_url', 'id', 'title'],
},
}
BUILD_SOURCES_SCHEMA = {
'type': 'array',
'items': {
'type': 'object',
'properties': {
'name': {
'type': 'string',
'description': 'The name of the repository, without its namespace',
},
'full_name': {
'type': 'string',
'description': 'The name of the repository, with its namespace',
},
'description': {
'type': 'string',
'description': 'The description of the repository. May be an empty string',
},
'last_updated': {
'type': 'number',
'description': 'The date/time when the repository was last updated, since epoch in UTC',
},
'url': {
'type': 'string',
'description': 'The URL at which to view the repository in the browser',
},
'has_admin_permissions': {
'type': 'boolean',
'description': 'True if the current user has admin permissions on the repository',
},
'private': {
'type': 'boolean',
'description': 'True if the repository is private',
},
"type": "array",
"items": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "The name of the repository, without its namespace",
},
"full_name": {
"type": "string",
"description": "The name of the repository, with its namespace",
},
"description": {
"type": "string",
"description": "The description of the repository. May be an empty string",
},
"last_updated": {
"type": "number",
"description": "The date/time when the repository was last updated, since epoch in UTC",
},
"url": {
"type": "string",
"description": "The URL at which to view the repository in the browser",
},
"has_admin_permissions": {
"type": "boolean",
"description": "True if the current user has admin permissions on the repository",
},
"private": {
"type": "boolean",
"description": "True if the repository is private",
},
},
"required": [
"name",
"full_name",
"description",
"last_updated",
"has_admin_permissions",
"private",
],
},
'required': ['name', 'full_name', 'description', 'last_updated',
'has_admin_permissions', 'private'],
},
}
METADATA_SCHEMA = {
'type': 'object',
'properties': {
'commit': {
'type': 'string',
'description': 'first 7 characters of the SHA-1 identifier for a git commit',
'pattern': '^([A-Fa-f0-9]{7,})$',
},
'git_url': {
'type': 'string',
'description': 'The GIT url to use for the checkout',
},
'ref': {
'type': 'string',
'description': 'git reference for a git commit',
'pattern': r'^refs\/(heads|tags|remotes)\/(.+)$',
},
'default_branch': {
'type': 'string',
'description': 'default branch of the git repository',
},
'commit_info': {
'type': 'object',
'description': 'metadata about a git commit',
'properties': {
'url': {
'type': 'string',
'description': 'URL to view a git commit',
"type": "object",
"properties": {
"commit": {
"type": "string",
"description": "first 7 characters of the SHA-1 identifier for a git commit",
"pattern": "^([A-Fa-f0-9]{7,})$",
},
'message': {
'type': 'string',
'description': 'git commit message',
"git_url": {
"type": "string",
"description": "The GIT url to use for the checkout",
},
'date': {
'type': 'string',
'description': 'timestamp for a git commit'
"ref": {
"type": "string",
"description": "git reference for a git commit",
"pattern": r"^refs\/(heads|tags|remotes)\/(.+)$",
},
'author': {
'type': 'object',
'description': 'metadata about the author of a git commit',
'properties': {
'username': {
'type': 'string',
'description': 'username of the author',
},
'url': {
'type': 'string',
'description': 'URL to view the profile of the author',
},
'avatar_url': {
'type': 'string',
'description': 'URL to view the avatar of the author',
},
},
'required': ['username'],
"default_branch": {
"type": "string",
"description": "default branch of the git repository",
},
'committer': {
'type': 'object',
'description': 'metadata about the committer of a git commit',
'properties': {
'username': {
'type': 'string',
'description': 'username of the committer',
"commit_info": {
"type": "object",
"description": "metadata about a git commit",
"properties": {
"url": {"type": "string", "description": "URL to view a git commit"},
"message": {"type": "string", "description": "git commit message"},
"date": {"type": "string", "description": "timestamp for a git commit"},
"author": {
"type": "object",
"description": "metadata about the author of a git commit",
"properties": {
"username": {
"type": "string",
"description": "username of the author",
},
"url": {
"type": "string",
"description": "URL to view the profile of the author",
},
"avatar_url": {
"type": "string",
"description": "URL to view the avatar of the author",
},
},
"required": ["username"],
},
"committer": {
"type": "object",
"description": "metadata about the committer of a git commit",
"properties": {
"username": {
"type": "string",
"description": "username of the committer",
},
"url": {
"type": "string",
"description": "URL to view the profile of the committer",
},
"avatar_url": {
"type": "string",
"description": "URL to view the avatar of the committer",
},
},
"required": ["username"],
},
},
'url': {
'type': 'string',
'description': 'URL to view the profile of the committer',
},
'avatar_url': {
'type': 'string',
'description': 'URL to view the avatar of the committer',
},
},
'required': ['username'],
"required": ["message"],
},
},
'required': ['message'],
},
},
'required': ['commit', 'git_url'],
"required": ["commit", "git_url"],
}
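A minimal metadata document that satisfies this schema (values invented for illustration); prepare_build() below validates exactly this shape before naming the build and deriving tags:

from jsonschema import validate

metadata = {
    "commit": "abcdef1",  # at least 7 hex characters, per the pattern
    "git_url": "git@bitbucket.org:acme/app.git",
    "ref": "refs/heads/master",
    "default_branch": "master",
    "commit_info": {"message": "Fix build"},  # 'message' is the only required key
}
validate(metadata, METADATA_SCHEMA)  # raises ValidationError on mismatch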
@add_metaclass(ABCMeta)
class BuildTriggerHandler(object):
def __init__(self, trigger, override_config=None):
self.trigger = trigger
self.config = override_config or get_trigger_config(trigger)
def __init__(self, trigger, override_config=None):
self.trigger = trigger
self.config = override_config or get_trigger_config(trigger)
@property
def auth_token(self):
""" Returns the auth token for the trigger. """
# NOTE: This check is for testing.
if isinstance(self.trigger.auth_token, str):
return self.trigger.auth_token
@property
def auth_token(self):
""" Returns the auth token for the trigger. """
# NOTE: This check is for testing.
if isinstance(self.trigger.auth_token, str):
return self.trigger.auth_token
# TODO(remove-unenc): Remove legacy field.
if self.trigger.secure_auth_token is not None:
return self.trigger.secure_auth_token.decrypt()
# TODO(remove-unenc): Remove legacy field.
if self.trigger.secure_auth_token is not None:
return self.trigger.secure_auth_token.decrypt()
if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
return self.trigger.auth_token
if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
return self.trigger.auth_token
return None
return None
@abstractmethod
def load_dockerfile_contents(self):
"""
@abstractmethod
def load_dockerfile_contents(self):
"""
Loads the Dockerfile found for the trigger's config and returns them or None if none could
be found/loaded.
"""
pass
pass
@abstractmethod
def list_build_source_namespaces(self):
"""
@abstractmethod
def list_build_source_namespaces(self):
"""
Take the auth information for the specific trigger type and load the
list of namespaces that can contain build sources.
"""
pass
pass
@abstractmethod
def list_build_sources_for_namespace(self, namespace):
"""
@abstractmethod
def list_build_sources_for_namespace(self, namespace):
"""
Take the auth information for the specific trigger type and load the
list of repositories under the given namespace.
"""
pass
pass
@abstractmethod
def list_build_subdirs(self):
"""
@abstractmethod
def list_build_subdirs(self):
"""
Take the auth information and the specified config so far and list all of
the possible subdirs containing dockerfiles.
"""
pass
pass
@abstractmethod
def handle_trigger_request(self, request):
"""
@abstractmethod
def handle_trigger_request(self, request):
"""
Transform the incoming request data into a set of actions. Returns a PreparedBuild.
"""
pass
pass
@abstractmethod
def is_active(self):
"""
@abstractmethod
def is_active(self):
"""
Returns True if the current build trigger is active. Inactive means further
setup is needed.
"""
pass
pass
@abstractmethod
def activate(self, standard_webhook_url):
"""
@abstractmethod
def activate(self, standard_webhook_url):
"""
Activates the trigger for the service, with the given new configuration.
Returns new public and private config that should be stored if successful.
"""
pass
pass
@abstractmethod
def deactivate(self):
"""
@abstractmethod
def deactivate(self):
"""
Deactivates the trigger for the service, removing any hooks installed in
the remote service. Returns the new config that should be stored if this
trigger is going to be re-activated.
"""
pass
pass
@abstractmethod
def manual_start(self, run_parameters=None):
"""
@abstractmethod
def manual_start(self, run_parameters=None):
"""
Manually creates a repository build for this trigger. Returns a PreparedBuild.
"""
pass
pass
@abstractmethod
def list_field_values(self, field_name, limit=None):
"""
@abstractmethod
def list_field_values(self, field_name, limit=None):
"""
Lists all values for the given custom trigger field. For example, a trigger might have a
field named "branches", and this method would return all branches.
"""
pass
pass
@abstractmethod
def get_repository_url(self):
""" Returns the URL of the current trigger's repository. Note that this operation
@abstractmethod
def get_repository_url(self):
""" Returns the URL of the current trigger's repository. Note that this operation
can be called in a loop, so it should be as fast as possible. """
pass
pass
@classmethod
def filename_is_dockerfile(cls, file_name):
""" Returns whether the file is named Dockerfile or follows the convention <name>.Dockerfile"""
return file_name.endswith(".Dockerfile") or u"Dockerfile" == file_name
@classmethod
def filename_is_dockerfile(cls, file_name):
""" Returns whether the file is named Dockerfile or follows the convention <name>.Dockerfile"""
return file_name.endswith(".Dockerfile") or u"Dockerfile" == file_name
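Worked examples of that naming convention (only the basename matches, hence the os.path.basename calls at the usage sites):

BuildTriggerHandler.filename_is_dockerfile("Dockerfile")         # True
BuildTriggerHandler.filename_is_dockerfile("server.Dockerfile")  # True
BuildTriggerHandler.filename_is_dockerfile("dockerfile")         # False (case matters)
BuildTriggerHandler.filename_is_dockerfile("sub/Dockerfile")     # False (pass the basename)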
@classmethod
def service_name(cls):
"""
@classmethod
def service_name(cls):
"""
Particular service implemented by subclasses.
"""
raise NotImplementedError
raise NotImplementedError
@classmethod
def get_handler(cls, trigger, override_config=None):
for subc in cls.__subclasses__():
if subc.service_name() == trigger.service.name:
return subc(trigger, override_config)
@classmethod
def get_handler(cls, trigger, override_config=None):
for subc in cls.__subclasses__():
if subc.service_name() == trigger.service.name:
return subc(trigger, override_config)
raise InvalidServiceException('Unable to find service: %s' % trigger.service.name)
raise InvalidServiceException(
"Unable to find service: %s" % trigger.service.name
)
def put_config_key(self, key, value):
""" Updates a config key in the trigger, saving it to the DB. """
self.config[key] = value
model.build.update_build_trigger(self.trigger, self.config)
def put_config_key(self, key, value):
""" Updates a config key in the trigger, saving it to the DB. """
self.config[key] = value
model.build.update_build_trigger(self.trigger, self.config)
def set_auth_token(self, auth_token):
""" Sets the auth token for the trigger, saving it to the DB. """
model.build.update_build_trigger(self.trigger, self.config, auth_token=auth_token)
def set_auth_token(self, auth_token):
""" Sets the auth token for the trigger, saving it to the DB. """
model.build.update_build_trigger(
self.trigger, self.config, auth_token=auth_token
)
def get_dockerfile_path(self):
""" Returns the normalized path to the Dockerfile found in the subdirectory
def get_dockerfile_path(self):
""" Returns the normalized path to the Dockerfile found in the subdirectory
in the config. """
dockerfile_path = self.config.get('dockerfile_path') or 'Dockerfile'
if dockerfile_path[0] == '/':
dockerfile_path = dockerfile_path[1:]
return dockerfile_path
dockerfile_path = self.config.get("dockerfile_path") or "Dockerfile"
if dockerfile_path[0] == "/":
dockerfile_path = dockerfile_path[1:]
return dockerfile_path
def prepare_build(self, metadata, is_manual=False):
# Ensure that the metadata meets the scheme.
validate(metadata, METADATA_SCHEMA)
def prepare_build(self, metadata, is_manual=False):
# Ensure that the metadata meets the scheme.
validate(metadata, METADATA_SCHEMA)
config = self.config
ref = metadata.get('ref', None)
commit_sha = metadata['commit']
default_branch = metadata.get('default_branch', None)
prepared = PreparedBuild(self.trigger)
prepared.name_from_sha(commit_sha)
prepared.subdirectory = config.get('dockerfile_path', None)
prepared.context = config.get('context', None)
prepared.is_manual = is_manual
prepared.metadata = metadata
config = self.config
ref = metadata.get("ref", None)
commit_sha = metadata["commit"]
default_branch = metadata.get("default_branch", None)
prepared = PreparedBuild(self.trigger)
prepared.name_from_sha(commit_sha)
prepared.subdirectory = config.get("dockerfile_path", None)
prepared.context = config.get("context", None)
prepared.is_manual = is_manual
prepared.metadata = metadata
if ref is not None:
prepared.tags_from_ref(ref, default_branch)
else:
prepared.tags = [commit_sha[:7]]
if ref is not None:
prepared.tags_from_ref(ref, default_branch)
else:
prepared.tags = [commit_sha[:7]]
return prepared
return prepared
@classmethod
def build_sources_response(cls, sources):
validate(sources, BUILD_SOURCES_SCHEMA)
return sources
@classmethod
def build_sources_response(cls, sources):
validate(sources, BUILD_SOURCES_SCHEMA)
return sources
@classmethod
def build_namespaces_response(cls, namespaces_dict):
namespaces = list(namespaces_dict.values())
validate(namespaces, NAMESPACES_SCHEMA)
return namespaces
@classmethod
def build_namespaces_response(cls, namespaces_dict):
namespaces = list(namespaces_dict.values())
validate(namespaces, NAMESPACES_SCHEMA)
return namespaces
@classmethod
def get_parent_directory_mappings(cls, dockerfile_path, current_paths=None):
""" Returns a map of dockerfile_paths to it's possible contexts. """
if dockerfile_path == "":
return {}
@classmethod
def get_parent_directory_mappings(cls, dockerfile_path, current_paths=None):
""" Returns a map of dockerfile_paths to it's possible contexts. """
if dockerfile_path == "":
return {}
if dockerfile_path[0] != os.path.sep:
dockerfile_path = os.path.sep + dockerfile_path
if dockerfile_path[0] != os.path.sep:
dockerfile_path = os.path.sep + dockerfile_path
dockerfile_path = os.path.normpath(dockerfile_path)
all_paths = set()
path, _ = os.path.split(dockerfile_path)
if path == "":
path = os.path.sep
dockerfile_path = os.path.normpath(dockerfile_path)
all_paths = set()
path, _ = os.path.split(dockerfile_path)
if path == "":
path = os.path.sep
all_paths.add(path)
for i in range(1, len(path.split(os.path.sep))):
path, _ = os.path.split(path)
all_paths.add(path)
all_paths.add(path)
for i in range(1, len(path.split(os.path.sep))):
path, _ = os.path.split(path)
all_paths.add(path)
if current_paths:
return dict({dockerfile_path: list(all_paths)}, **current_paths)
if current_paths:
return dict({dockerfile_path: list(all_paths)}, **current_paths)
return {dockerfile_path: list(all_paths)}
return {dockerfile_path: list(all_paths)}
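A worked example of the mapping (the context list comes from a set, so its order is unspecified):

BuildTriggerHandler.get_parent_directory_mappings("a/b/Dockerfile")
# {'/a/b/Dockerfile': ['/a/b', '/a', '/']}
# The path is made absolute and normalized, and every parent directory up to
# the root is offered as a candidate build context.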

View file

@@ -9,537 +9,551 @@ from jsonschema import validate
from app import app, get_app_url
from buildtrigger.basehandler import BuildTriggerHandler
from buildtrigger.triggerutil import (RepositoryReadException, TriggerActivationException,
TriggerDeactivationException, TriggerStartException,
InvalidPayloadException, TriggerProviderException,
SkipRequestException,
determine_build_ref, raise_if_skipped_build,
find_matching_branches)
from buildtrigger.triggerutil import (
RepositoryReadException,
TriggerActivationException,
TriggerDeactivationException,
TriggerStartException,
InvalidPayloadException,
TriggerProviderException,
SkipRequestException,
determine_build_ref,
raise_if_skipped_build,
find_matching_branches,
)
from util.dict_wrappers import JSONPathDict, SafeDictSetter
from util.security.ssh import generate_ssh_keypair
logger = logging.getLogger(__name__)
_BITBUCKET_COMMIT_URL = 'https://bitbucket.org/%s/commits/%s'
_RAW_AUTHOR_REGEX = re.compile(r'.*<(.+)>')
_BITBUCKET_COMMIT_URL = "https://bitbucket.org/%s/commits/%s"
_RAW_AUTHOR_REGEX = re.compile(r".*<(.+)>")
BITBUCKET_WEBHOOK_PAYLOAD_SCHEMA = {
'type': 'object',
'properties': {
'repository': {
'type': 'object',
'properties': {
'full_name': {
'type': 'string',
},
},
'required': ['full_name'],
}, # /Repository
'push': {
'type': 'object',
'properties': {
'changes': {
'type': 'array',
'items': {
'type': 'object',
'properties': {
'new': {
'type': 'object',
'properties': {
'target': {
'type': 'object',
'properties': {
'hash': {
'type': 'string'
},
'message': {
'type': 'string'
},
'date': {
'type': 'string'
},
'author': {
'type': 'object',
'properties': {
'user': {
'type': 'object',
'properties': {
'display_name': {
'type': 'string',
},
'account_id': {
'type': 'string',
},
'links': {
'type': 'object',
'properties': {
'avatar': {
'type': 'object',
'properties': {
'href': {
'type': 'string',
},
},
'required': ['href'],
},
"type": "object",
"properties": {
"repository": {
"type": "object",
"properties": {"full_name": {"type": "string"}},
"required": ["full_name"],
}, # /Repository
"push": {
"type": "object",
"properties": {
"changes": {
"type": "array",
"items": {
"type": "object",
"properties": {
"new": {
"type": "object",
"properties": {
"target": {
"type": "object",
"properties": {
"hash": {"type": "string"},
"message": {"type": "string"},
"date": {"type": "string"},
"author": {
"type": "object",
"properties": {
"user": {
"type": "object",
"properties": {
"display_name": {
"type": "string"
},
"account_id": {
"type": "string"
},
"links": {
"type": "object",
"properties": {
"avatar": {
"type": "object",
"properties": {
"href": {
"type": "string"
}
},
"required": [
"href"
],
}
},
"required": ["avatar"],
}, # /User
},
} # /Author
},
},
},
"required": ["hash", "message", "date"],
} # /Target
},
'required': ['avatar'],
}, # /User
},
}, # /Author
"required": ["name", "target"],
} # /New
},
},
},
'required': ['hash', 'message', 'date'],
}, # /Target
},
'required': ['name', 'target'],
}, # /New
}, # /Changes item
} # /Changes
},
}, # /Changes item
}, # /Changes
},
'required': ['changes'],
}, # / Push
},
'actor': {
'type': 'object',
'properties': {
'account_id': {
'type': 'string',
},
'display_name': {
'type': 'string',
},
'links': {
'type': 'object',
'properties': {
'avatar': {
'type': 'object',
'properties': {
'href': {
'type': 'string',
},
},
'required': ['href'],
},
},
'required': ['avatar'],
},
"required": ["changes"],
}, # / Push
},
}, # /Actor
'required': ['push', 'repository'],
} # /Root
"actor": {
"type": "object",
"properties": {
"account_id": {"type": "string"},
"display_name": {"type": "string"},
"links": {
"type": "object",
"properties": {
"avatar": {
"type": "object",
"properties": {"href": {"type": "string"}},
"required": ["href"],
}
},
"required": ["avatar"],
},
},
}, # /Actor
"required": ["push", "repository"],
} # /Root
BITBUCKET_COMMIT_INFO_SCHEMA = {
'type': 'object',
'properties': {
'node': {
'type': 'string',
"type": "object",
"properties": {
"node": {"type": "string"},
"message": {"type": "string"},
"timestamp": {"type": "string"},
"raw_author": {"type": "string"},
},
'message': {
'type': 'string',
},
'timestamp': {
'type': 'string',
},
'raw_author': {
'type': 'string',
},
},
'required': ['node', 'message', 'timestamp']
"required": ["node", "message", "timestamp"],
}
def get_transformed_commit_info(bb_commit, ref, default_branch, repository_name, lookup_author):
""" Returns the BitBucket commit information transformed into our own
def get_transformed_commit_info(
bb_commit, ref, default_branch, repository_name, lookup_author
):
""" Returns the BitBucket commit information transformed into our own
payload format.
"""
try:
validate(bb_commit, BITBUCKET_COMMIT_INFO_SCHEMA)
except Exception as exc:
logger.exception('Exception when validating Bitbucket commit information: %s from %s', exc.message, bb_commit)
raise InvalidPayloadException(exc.message)
try:
validate(bb_commit, BITBUCKET_COMMIT_INFO_SCHEMA)
except Exception as exc:
logger.exception(
"Exception when validating Bitbucket commit information: %s from %s",
exc.message,
bb_commit,
)
raise InvalidPayloadException(exc.message)
commit = JSONPathDict(bb_commit)
commit = JSONPathDict(bb_commit)
config = SafeDictSetter()
config['commit'] = commit['node']
config['ref'] = ref
config['default_branch'] = default_branch
config['git_url'] = 'git@bitbucket.org:%s.git' % repository_name
config = SafeDictSetter()
config["commit"] = commit["node"]
config["ref"] = ref
config["default_branch"] = default_branch
config["git_url"] = "git@bitbucket.org:%s.git" % repository_name
config['commit_info.url'] = _BITBUCKET_COMMIT_URL % (repository_name, commit['node'])
config['commit_info.message'] = commit['message']
config['commit_info.date'] = commit['timestamp']
config["commit_info.url"] = _BITBUCKET_COMMIT_URL % (
repository_name,
commit["node"],
)
config["commit_info.message"] = commit["message"]
config["commit_info.date"] = commit["timestamp"]
match = _RAW_AUTHOR_REGEX.match(commit['raw_author'])
if match:
author = lookup_author(match.group(1))
author_info = JSONPathDict(author) if author is not None else None
if author_info:
config['commit_info.author.username'] = author_info['user.display_name']
config['commit_info.author.avatar_url'] = author_info['user.avatar']
match = _RAW_AUTHOR_REGEX.match(commit["raw_author"])
if match:
author = lookup_author(match.group(1))
author_info = JSONPathDict(author) if author is not None else None
if author_info:
config["commit_info.author.username"] = author_info["user.display_name"]
config["commit_info.author.avatar_url"] = author_info["user.avatar"]
return config.dict_value()
return config.dict_value()
def get_transformed_webhook_payload(bb_payload, default_branch=None):
""" Returns the BitBucket webhook JSON payload transformed into our own payload
""" Returns the BitBucket webhook JSON payload transformed into our own payload
format. If the bb_payload is not valid, returns None.
"""
try:
validate(bb_payload, BITBUCKET_WEBHOOK_PAYLOAD_SCHEMA)
except Exception as exc:
logger.exception('Exception when validating Bitbucket webhook payload: %s from %s', exc.message,
bb_payload)
raise InvalidPayloadException(exc.message)
try:
validate(bb_payload, BITBUCKET_WEBHOOK_PAYLOAD_SCHEMA)
except Exception as exc:
logger.exception(
"Exception when validating Bitbucket webhook payload: %s from %s",
exc.message,
bb_payload,
)
raise InvalidPayloadException(exc.message)
payload = JSONPathDict(bb_payload)
change = payload['push.changes[-1].new']
if not change:
raise SkipRequestException
payload = JSONPathDict(bb_payload)
change = payload["push.changes[-1].new"]
if not change:
raise SkipRequestException
is_branch = change['type'] == 'branch'
ref = 'refs/heads/' + change['name'] if is_branch else 'refs/tags/' + change['name']
is_branch = change["type"] == "branch"
ref = "refs/heads/" + change["name"] if is_branch else "refs/tags/" + change["name"]
repository_name = payload['repository.full_name']
target = change['target']
repository_name = payload["repository.full_name"]
target = change["target"]
config = SafeDictSetter()
config['commit'] = target['hash']
config['ref'] = ref
config['default_branch'] = default_branch
config['git_url'] = 'git@bitbucket.org:%s.git' % repository_name
config = SafeDictSetter()
config["commit"] = target["hash"]
config["ref"] = ref
config["default_branch"] = default_branch
config["git_url"] = "git@bitbucket.org:%s.git" % repository_name
config['commit_info.url'] = target['links.html.href'] or ''
config['commit_info.message'] = target['message']
config['commit_info.date'] = target['date']
config["commit_info.url"] = target["links.html.href"] or ""
config["commit_info.message"] = target["message"]
config["commit_info.date"] = target["date"]
config['commit_info.author.username'] = target['author.user.display_name']
config['commit_info.author.avatar_url'] = target['author.user.links.avatar.href']
config["commit_info.author.username"] = target["author.user.display_name"]
config["commit_info.author.avatar_url"] = target["author.user.links.avatar.href"]
config['commit_info.committer.username'] = payload['actor.display_name']
config['commit_info.committer.avatar_url'] = payload['actor.links.avatar.href']
return config.dict_value()
config["commit_info.committer.username"] = payload["actor.display_name"]
config["commit_info.committer.avatar_url"] = payload["actor.links.avatar.href"]
return config.dict_value()
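A minimal push payload that validates against the schema above, with invented field values, and the parts of the transformed result that follow directly from the code:

bb_payload = {
    "repository": {"full_name": "acme/app"},
    "push": {
        "changes": [
            {
                "new": {
                    "type": "branch",
                    "name": "master",
                    "target": {
                        "hash": "abcdef1234567890",
                        "message": "Fix build",
                        "date": "2019-11-20T07:58:47+00:00",
                    },
                }
            }
        ]
    },
    "actor": {"display_name": "alice"},
}

metadata = get_transformed_webhook_payload(bb_payload, default_branch="master")
# metadata["commit"]  == "abcdef1234567890"
# metadata["ref"]     == "refs/heads/master"
# metadata["git_url"] == "git@bitbucket.org:acme/app.git"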
class BitbucketBuildTrigger(BuildTriggerHandler):
"""
"""
BuildTrigger for Bitbucket.
"""
@classmethod
def service_name(cls):
return 'bitbucket'
def _get_client(self):
""" Returns a BitBucket API client for this trigger's config. """
key = app.config.get('BITBUCKET_TRIGGER_CONFIG', {}).get('CONSUMER_KEY', '')
secret = app.config.get('BITBUCKET_TRIGGER_CONFIG', {}).get('CONSUMER_SECRET', '')
@classmethod
def service_name(cls):
return "bitbucket"
trigger_uuid = self.trigger.uuid
callback_url = '%s/oauth1/bitbucket/callback/trigger/%s' % (get_app_url(), trigger_uuid)
def _get_client(self):
""" Returns a BitBucket API client for this trigger's config. """
key = app.config.get("BITBUCKET_TRIGGER_CONFIG", {}).get("CONSUMER_KEY", "")
secret = app.config.get("BITBUCKET_TRIGGER_CONFIG", {}).get(
"CONSUMER_SECRET", ""
)
return BitBucket(key, secret, callback_url, timeout=15)
trigger_uuid = self.trigger.uuid
callback_url = "%s/oauth1/bitbucket/callback/trigger/%s" % (
get_app_url(),
trigger_uuid,
)
def _get_authorized_client(self):
""" Returns an authorized API client. """
base_client = self._get_client()
auth_token = self.auth_token or 'invalid:invalid'
token_parts = auth_token.split(':')
if len(token_parts) != 2:
token_parts = ['invalid', 'invalid']
return BitBucket(key, secret, callback_url, timeout=15)
(access_token, access_token_secret) = token_parts
return base_client.get_authorized_client(access_token, access_token_secret)
def _get_authorized_client(self):
""" Returns an authorized API client. """
base_client = self._get_client()
auth_token = self.auth_token or "invalid:invalid"
token_parts = auth_token.split(":")
if len(token_parts) != 2:
token_parts = ["invalid", "invalid"]
def _get_repository_client(self):
""" Returns an API client for working with this config's BB repository. """
source = self.config['build_source']
(namespace, name) = source.split('/')
bitbucket_client = self._get_authorized_client()
return bitbucket_client.for_namespace(namespace).repositories().get(name)
(access_token, access_token_secret) = token_parts
return base_client.get_authorized_client(access_token, access_token_secret)
def _get_default_branch(self, repository, default_value='master'):
""" Returns the default branch for the repository or the value given. """
(result, data, _) = repository.get_main_branch()
if result:
return data['name']
def _get_repository_client(self):
""" Returns an API client for working with this config's BB repository. """
source = self.config["build_source"]
(namespace, name) = source.split("/")
bitbucket_client = self._get_authorized_client()
return bitbucket_client.for_namespace(namespace).repositories().get(name)
return default_value
def _get_default_branch(self, repository, default_value="master"):
""" Returns the default branch for the repository or the value given. """
(result, data, _) = repository.get_main_branch()
if result:
return data["name"]
def get_oauth_url(self):
""" Returns the OAuth URL to authorize Bitbucket. """
bitbucket_client = self._get_client()
(result, data, err_msg) = bitbucket_client.get_authorization_url()
if not result:
raise TriggerProviderException(err_msg)
return default_value
return data
def get_oauth_url(self):
""" Returns the OAuth URL to authorize Bitbucket. """
bitbucket_client = self._get_client()
(result, data, err_msg) = bitbucket_client.get_authorization_url()
if not result:
raise TriggerProviderException(err_msg)
def exchange_verifier(self, verifier):
""" Exchanges the given verifier token to setup this trigger. """
bitbucket_client = self._get_client()
access_token = self.config.get('access_token', '')
access_token_secret = self.auth_token
return data
# Exchange the verifier for a new access token.
(result, data, _) = bitbucket_client.verify_token(access_token, access_token_secret, verifier)
if not result:
return False
def exchange_verifier(self, verifier):
""" Exchanges the given verifier token to setup this trigger. """
bitbucket_client = self._get_client()
access_token = self.config.get("access_token", "")
access_token_secret = self.auth_token
# Save the updated access token and secret.
self.set_auth_token(data[0] + ':' + data[1])
# Exchange the verifier for a new access token.
(result, data, _) = bitbucket_client.verify_token(
access_token, access_token_secret, verifier
)
if not result:
return False
# Retrieve the current authorized user's information and store the username in the config.
authorized_client = self._get_authorized_client()
(result, data, _) = authorized_client.get_current_user()
if not result:
return False
# Save the updated access token and secret.
self.set_auth_token(data[0] + ":" + data[1])
self.put_config_key('account_id', data['user']['account_id'])
self.put_config_key('nickname', data['user']['nickname'])
return True
# Retrieve the current authorized user's information and store the username in the config.
authorized_client = self._get_authorized_client()
(result, data, _) = authorized_client.get_current_user()
if not result:
return False
def is_active(self):
return 'webhook_id' in self.config
self.put_config_key("account_id", data["user"]["account_id"])
self.put_config_key("nickname", data["user"]["nickname"])
return True
def activate(self, standard_webhook_url):
config = self.config
def is_active(self):
return "webhook_id" in self.config
# Add a deploy key to the repository.
public_key, private_key = generate_ssh_keypair()
config['credentials'] = [
{
'name': 'SSH Public Key',
'value': public_key,
},
]
def activate(self, standard_webhook_url):
config = self.config
repository = self._get_repository_client()
(result, created_deploykey, err_msg) = repository.deploykeys().create(
app.config['REGISTRY_TITLE'] + ' webhook key', public_key)
# Add a deploy key to the repository.
public_key, private_key = generate_ssh_keypair()
config["credentials"] = [{"name": "SSH Public Key", "value": public_key}]
if not result:
msg = 'Unable to add deploy key to repository: %s' % err_msg
raise TriggerActivationException(msg)
repository = self._get_repository_client()
(result, created_deploykey, err_msg) = repository.deploykeys().create(
app.config["REGISTRY_TITLE"] + " webhook key", public_key
)
config['deploy_key_id'] = created_deploykey['pk']
if not result:
msg = "Unable to add deploy key to repository: %s" % err_msg
raise TriggerActivationException(msg)
# Add a webhook callback.
description = 'Webhook for invoking builds on %s' % app.config['REGISTRY_TITLE_SHORT']
webhook_events = ['repo:push']
(result, created_webhook, err_msg) = repository.webhooks().create(
description, standard_webhook_url, webhook_events)
config["deploy_key_id"] = created_deploykey["pk"]
if not result:
msg = 'Unable to add webhook to repository: %s' % err_msg
raise TriggerActivationException(msg)
# Add a webhook callback.
description = (
"Webhook for invoking builds on %s" % app.config["REGISTRY_TITLE_SHORT"]
)
webhook_events = ["repo:push"]
(result, created_webhook, err_msg) = repository.webhooks().create(
description, standard_webhook_url, webhook_events
)
config['webhook_id'] = created_webhook['uuid']
self.config = config
return config, {'private_key': private_key}
if not result:
msg = "Unable to add webhook to repository: %s" % err_msg
raise TriggerActivationException(msg)
def deactivate(self):
config = self.config
config["webhook_id"] = created_webhook["uuid"]
self.config = config
return config, {"private_key": private_key}
webhook_id = config.pop('webhook_id', None)
deploy_key_id = config.pop('deploy_key_id', None)
repository = self._get_repository_client()
def deactivate(self):
config = self.config
# Remove the webhook.
if webhook_id is not None:
(result, _, err_msg) = repository.webhooks().delete(webhook_id)
if not result:
msg = 'Unable to remove webhook from repository: %s' % err_msg
raise TriggerDeactivationException(msg)
webhook_id = config.pop("webhook_id", None)
deploy_key_id = config.pop("deploy_key_id", None)
repository = self._get_repository_client()
# Remove the public key.
if deploy_key_id is not None:
(result, _, err_msg) = repository.deploykeys().delete(deploy_key_id)
if not result:
msg = 'Unable to remove deploy key from repository: %s' % err_msg
raise TriggerDeactivationException(msg)
# Remove the webhook.
if webhook_id is not None:
(result, _, err_msg) = repository.webhooks().delete(webhook_id)
if not result:
msg = "Unable to remove webhook from repository: %s" % err_msg
raise TriggerDeactivationException(msg)
return config
# Remove the public key.
if deploy_key_id is not None:
(result, _, err_msg) = repository.deploykeys().delete(deploy_key_id)
if not result:
msg = "Unable to remove deploy key from repository: %s" % err_msg
raise TriggerDeactivationException(msg)
def list_build_source_namespaces(self):
bitbucket_client = self._get_authorized_client()
(result, data, err_msg) = bitbucket_client.get_visible_repositories()
if not result:
raise RepositoryReadException('Could not read repository list: ' + err_msg)
return config
namespaces = {}
for repo in data:
owner = repo['owner']
def list_build_source_namespaces(self):
bitbucket_client = self._get_authorized_client()
(result, data, err_msg) = bitbucket_client.get_visible_repositories()
if not result:
raise RepositoryReadException("Could not read repository list: " + err_msg)
if owner in namespaces:
namespaces[owner]['score'] = namespaces[owner]['score'] + 1
else:
namespaces[owner] = {
'personal': owner == self.config.get('nickname', self.config.get('username')),
'id': owner,
'title': owner,
'avatar_url': repo['logo'],
'url': 'https://bitbucket.org/%s' % (owner),
'score': 1,
}
namespaces = {}
for repo in data:
owner = repo["owner"]
return BuildTriggerHandler.build_namespaces_response(namespaces)
if owner in namespaces:
namespaces[owner]["score"] = namespaces[owner]["score"] + 1
else:
namespaces[owner] = {
"personal": owner
== self.config.get("nickname", self.config.get("username")),
"id": owner,
"title": owner,
"avatar_url": repo["logo"],
"url": "https://bitbucket.org/%s" % (owner),
"score": 1,
}
def list_build_sources_for_namespace(self, namespace):
def repo_view(repo):
last_modified = dateutil.parser.parse(repo['utc_last_updated'])
return BuildTriggerHandler.build_namespaces_response(namespaces)
return {
'name': repo['slug'],
'full_name': '%s/%s' % (repo['owner'], repo['slug']),
'description': repo['description'] or '',
'last_updated': timegm(last_modified.utctimetuple()),
'url': 'https://bitbucket.org/%s/%s' % (repo['owner'], repo['slug']),
'has_admin_permissions': repo['read_only'] is False,
'private': repo['is_private'],
}
def list_build_sources_for_namespace(self, namespace):
def repo_view(repo):
last_modified = dateutil.parser.parse(repo["utc_last_updated"])
bitbucket_client = self._get_authorized_client()
            return {
                "name": repo["slug"],
                "full_name": "%s/%s" % (repo["owner"], repo["slug"]),
                "description": repo["description"] or "",
                "last_updated": timegm(last_modified.utctimetuple()),
                "url": "https://bitbucket.org/%s/%s" % (repo["owner"], repo["slug"]),
                "has_admin_permissions": repo["read_only"] is False,
                "private": repo["is_private"],
            }

        bitbucket_client = self._get_authorized_client()
        (result, data, err_msg) = bitbucket_client.get_visible_repositories()
        if not result:
            raise RepositoryReadException("Could not read repository list: " + err_msg)

        repos = [repo_view(repo) for repo in data if repo["owner"] == namespace]
        return BuildTriggerHandler.build_sources_response(repos)

    def list_build_subdirs(self):
        config = self.config
        repository = self._get_repository_client()

        # Find the first matching branch.
        repo_branches = self.list_field_values("branch_name") or []
        branches = find_matching_branches(config, repo_branches)
        if not branches:
            branches = [self._get_default_branch(repository)]

        (result, data, err_msg) = repository.get_path_contents("", revision=branches[0])
        if not result:
            raise RepositoryReadException(err_msg)

        files = set([f["path"] for f in data["files"]])
        return [
            "/" + file_path
            for file_path in files
            if self.filename_is_dockerfile(os.path.basename(file_path))
        ]

    def load_dockerfile_contents(self):
        repository = self._get_repository_client()
        path = self.get_dockerfile_path()

        (result, data, err_msg) = repository.get_raw_path_contents(
            path, revision="master"
        )
        if not result:
            return None

        return data

    def list_field_values(self, field_name, limit=None):
        if "build_source" not in self.config:
            return None

        source = self.config["build_source"]
        (namespace, name) = source.split("/")

        bitbucket_client = self._get_authorized_client()
        repository = bitbucket_client.for_namespace(namespace).repositories().get(name)

        if field_name == "refs":
            (result, data, _) = repository.get_branches_and_tags()
            if not result:
                return None

            branches = [b["name"] for b in data["branches"]]
            tags = [t["name"] for t in data["tags"]]

            return [{"kind": "branch", "name": b} for b in branches] + [
                {"kind": "tag", "name": tag} for tag in tags
            ]

        if field_name == "tag_name":
            (result, data, _) = repository.get_tags()
            if not result:
                return None

            tags = list(data.keys())
            if limit:
                tags = tags[0:limit]

            return tags

        if field_name == "branch_name":
            (result, data, _) = repository.get_branches()
            if not result:
                return None

            branches = list(data.keys())
            if limit:
                branches = branches[0:limit]

            return branches

        return None

    def get_repository_url(self):
        source = self.config["build_source"]
        (namespace, name) = source.split("/")
        return "https://bitbucket.org/%s/%s" % (namespace, name)

    def handle_trigger_request(self, request):
        payload = request.get_json()
        if payload is None:
            raise InvalidPayloadException("Missing payload")

        logger.debug("Got BitBucket request: %s", payload)

        repository = self._get_repository_client()
        default_branch = self._get_default_branch(repository)

        metadata = get_transformed_webhook_payload(
            payload, default_branch=default_branch
        )
        prepared = self.prepare_build(metadata)

        # Check if we should skip this build.
        raise_if_skipped_build(prepared, self.config)
        return prepared

    def manual_start(self, run_parameters=None):
        run_parameters = run_parameters or {}
        repository = self._get_repository_client()
        bitbucket_client = self._get_authorized_client()

        def get_branch_sha(branch_name):
            # Lookup the commit SHA for the branch.
            (result, data, _) = repository.get_branch(branch_name)
            if not result:
                raise TriggerStartException("Could not find branch in repository")

            return data["target"]["hash"]

        def get_tag_sha(tag_name):
            # Lookup the commit SHA for the tag.
            (result, data, _) = repository.get_tag(tag_name)
            if not result:
                raise TriggerStartException("Could not find tag in repository")

            return data["target"]["hash"]

        def lookup_author(email_address):
            (result, data, _) = bitbucket_client.accounts().get_profile(email_address)
            return data if result else None

        # Find the branch or tag to build.
        default_branch = self._get_default_branch(repository)
        (commit_sha, ref) = determine_build_ref(
            run_parameters, get_branch_sha, get_tag_sha, default_branch
        )

        # Lookup the commit SHA in BitBucket.
        (result, commit_info, _) = repository.changesets().get(commit_sha)
        if not result:
            raise TriggerStartException("Could not lookup commit SHA")

        # Return a prepared build for the commit.
        repository_name = "%s/%s" % (repository.namespace, repository.repository_name)
        metadata = get_transformed_commit_info(
            commit_info, ref, default_branch, repository_name, lookup_author
        )

        return self.prepare_build(metadata, is_manual=True)
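Every BitBucket client call in this handler follows the same (result, data, err_msg) tuple convention, so each step is a tuple unpack followed by a guard. A minimal sketch of that pattern, with a hypothetical fetch() standing in for a real client call:

def fetch():
    # Hypothetical stand-in for a client call such as repository.get_branches();
    # every call in this handler returns a (result, data, err_msg) triple.
    return (True, {"master": {}}, None)


(result, data, err_msg) = fetch()
if not result:
    # The real handler raises RepositoryReadException or TriggerStartException here.
    raise Exception(err_msg)
print(list(data.keys()))  # ['master']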

View file

@@ -2,22 +2,33 @@ import logging
import json

from jsonschema import validate, ValidationError

from buildtrigger.triggerutil import (
    RepositoryReadException,
    TriggerActivationException,
    TriggerStartException,
    ValidationRequestException,
    InvalidPayloadException,
    SkipRequestException,
    raise_if_skipped_build,
    find_matching_branches,
)
from buildtrigger.basehandler import BuildTriggerHandler
from buildtrigger.bitbuckethandler import (
    BITBUCKET_WEBHOOK_PAYLOAD_SCHEMA as bb_schema,
    get_transformed_webhook_payload as bb_payload,
)
from buildtrigger.githubhandler import (
    GITHUB_WEBHOOK_PAYLOAD_SCHEMA as gh_schema,
    get_transformed_webhook_payload as gh_payload,
)
from buildtrigger.gitlabhandler import (
    GITLAB_WEBHOOK_PAYLOAD_SCHEMA as gl_schema,
    get_transformed_webhook_payload as gl_payload,
)
from util.security.ssh import generate_ssh_keypair

@@ -27,203 +38,191 @@ logger = logging.getLogger(__name__)

# Defines an ordered set of tuples of the schemas and associated transformation functions
# for incoming webhook payloads.
SCHEMA_AND_HANDLERS = [
    (gh_schema, gh_payload),
    (bb_schema, bb_payload),
    (gl_schema, gl_payload),
]


def custom_trigger_payload(metadata, git_url):
    # First try the customhandler schema. If it matches, nothing more to do.
    custom_handler_validation_error = None
    try:
        validate(metadata, CustomBuildTrigger.payload_schema)
    except ValidationError as vex:
        custom_handler_validation_error = vex

    # Otherwise, try the defined schemas, in order, until we find a match.
    for schema, handler in SCHEMA_AND_HANDLERS:
        try:
            validate(metadata, schema)
        except ValidationError:
            continue

        result = handler(metadata)
        result["git_url"] = git_url
        return result

    # If we have reached this point and no other schemas validated, then raise the error for the
    # custom schema.
    if custom_handler_validation_error is not None:
        raise InvalidPayloadException(custom_handler_validation_error.message)

    metadata["git_url"] = git_url
    return metadata


class CustomBuildTrigger(BuildTriggerHandler):
    payload_schema = {
        "type": "object",
        "properties": {
            "commit": {
                "type": "string",
                "description": "first 7 characters of the SHA-1 identifier for a git commit",
                "pattern": "^([A-Fa-f0-9]{7,})$",
            },
            "ref": {
                "type": "string",
                "description": "git reference for a git commit",
                "pattern": "^refs\/(heads|tags|remotes)\/(.+)$",
            },
            "default_branch": {
                "type": "string",
                "description": "default branch of the git repository",
            },
            "commit_info": {
                "type": "object",
                "description": "metadata about a git commit",
                "properties": {
                    "url": {
                        "type": "string",
                        "description": "URL to view a git commit",
                    },
                    "message": {"type": "string", "description": "git commit message"},
                    "date": {
                        "type": "string",
                        "description": "timestamp for a git commit",
                    },
                    "author": {
                        "type": "object",
                        "description": "metadata about the author of a git commit",
                        "properties": {
                            "username": {
                                "type": "string",
                                "description": "username of the author",
                            },
                            "url": {
                                "type": "string",
                                "description": "URL to view the profile of the author",
                            },
                            "avatar_url": {
                                "type": "string",
                                "description": "URL to view the avatar of the author",
                            },
                        },
                        "required": ["username", "url", "avatar_url"],
                    },
                    "committer": {
                        "type": "object",
                        "description": "metadata about the committer of a git commit",
                        "properties": {
                            "username": {
                                "type": "string",
                                "description": "username of the committer",
                            },
                            "url": {
                                "type": "string",
                                "description": "URL to view the profile of the committer",
                            },
                            "avatar_url": {
                                "type": "string",
                                "description": "URL to view the avatar of the committer",
                            },
                        },
                        "required": ["username", "url", "avatar_url"],
                    },
                },
                "required": ["url", "message", "date"],
            },
        },
        "required": ["commit", "ref", "default_branch"],
    }

    @classmethod
    def service_name(cls):
        return "custom-git"

    def is_active(self):
        return self.config.has_key("credentials")

    def _metadata_from_payload(self, payload, git_url):
        # Parse the JSON payload.
        try:
            metadata = json.loads(payload)
        except ValueError as vex:
            raise InvalidPayloadException(vex.message)

        return custom_trigger_payload(metadata, git_url)

    def handle_trigger_request(self, request):
        payload = request.data
        if not payload:
            raise InvalidPayloadException("Missing expected payload")

        logger.debug("Payload %s", payload)

        metadata = self._metadata_from_payload(payload, self.config["build_source"])
        prepared = self.prepare_build(metadata)

        # Check if we should skip this build.
        raise_if_skipped_build(prepared, self.config)

        return prepared

    def manual_start(self, run_parameters=None):
        # commit_sha is the only required parameter
        commit_sha = run_parameters.get("commit_sha")
        if commit_sha is None:
            raise TriggerStartException("missing required parameter")

        config = self.config
        metadata = {"commit": commit_sha, "git_url": config["build_source"]}

        try:
            return self.prepare_build(metadata, is_manual=True)
        except ValidationError as ve:
            raise TriggerStartException(ve.message)

    def activate(self, standard_webhook_url):
        config = self.config
        public_key, private_key = generate_ssh_keypair()
        config["credentials"] = [
            {"name": "SSH Public Key", "value": public_key},
            {"name": "Webhook Endpoint URL", "value": standard_webhook_url},
        ]
        self.config = config
        return config, {"private_key": private_key}

    def deactivate(self):
        config = self.config
        config.pop("credentials", None)
        self.config = config
        return config

    def get_repository_url(self):
        return None

    def list_build_source_namespaces(self):
        raise NotImplementedError

    def list_build_sources_for_namespace(self, namespace):
        raise NotImplementedError

    def list_build_subdirs(self):
        raise NotImplementedError

    def list_field_values(self, field_name, limit=None):
        raise NotImplementedError

    def load_dockerfile_contents(self):
        raise NotImplementedError
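The fallback dispatch in custom_trigger_payload is easy to exercise in isolation: validate against each schema in order and hand off to the first handler whose schema matches. A minimal sketch with toy schemas (the pair list and handlers here are illustrative stand-ins, not the real webhook schemas):

from jsonschema import validate, ValidationError

# Toy stand-ins for the (schema, handler) pairs in SCHEMA_AND_HANDLERS.
toy_pairs = [
    ({"required": ["push"]}, lambda m: {"kind": "bitbucket-like"}),
    ({"required": ["object_kind"]}, lambda m: {"kind": "gitlab-like"}),
]

def dispatch(metadata):
    for schema, handler in toy_pairs:
        try:
            validate(metadata, schema)
        except ValidationError:
            # This schema did not match; try the next one.
            continue
        return handler(metadata)
    return None

print(dispatch({"object_kind": "push"}))  # {'kind': 'gitlab-like'}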

File diff suppressed because it is too large

File diff suppressed because it is too large

View file

@@ -4,156 +4,168 @@ from mock import Mock
from buildtrigger.bitbuckethandler import BitbucketBuildTrigger
from util.morecollections import AttrDict


def get_bitbucket_trigger(dockerfile_path=""):
    trigger_obj = AttrDict(dict(auth_token="foobar", id="sometrigger"))
    trigger = BitbucketBuildTrigger(
        trigger_obj,
        {
            "build_source": "foo/bar",
            "dockerfile_path": dockerfile_path,
            "nickname": "knownuser",
            "account_id": "foo",
        },
    )
    trigger._get_client = get_mock_bitbucket
    return trigger


def get_repo_path_contents(path, revision):
    data = {"files": [{"path": "Dockerfile"}]}
    return (True, data, None)


def get_raw_path_contents(path, revision):
    if path == "Dockerfile":
        return (True, "hello world", None)

    if path == "somesubdir/Dockerfile":
        return (True, "hi universe", None)

    return (False, None, None)


def get_branches_and_tags():
    data = {
        "branches": [{"name": "master"}, {"name": "otherbranch"}],
        "tags": [{"name": "sometag"}, {"name": "someothertag"}],
    }
    return (True, data, None)


def get_branches():
    return (True, {"master": {}, "otherbranch": {}}, None)


def get_tags():
    return (True, {"sometag": {}, "someothertag": {}}, None)


def get_branch(branch_name):
    if branch_name != "master":
        return (False, None, None)

    data = {"target": {"hash": "aaaaaaa"}}

    return (True, data, None)


def get_tag(tag_name):
    if tag_name != "sometag":
        return (False, None, None)

    data = {"target": {"hash": "aaaaaaa"}}

    return (True, data, None)


def get_changeset_mock(commit_sha):
    if commit_sha != "aaaaaaa":
        return (False, None, "Not found")

    data = {
        "node": "aaaaaaa",
        "message": "some message",
        "timestamp": "now",
        "raw_author": "foo@bar.com",
    }

    return (True, data, None)


def get_changesets():
    changesets_mock = Mock()
    changesets_mock.get = Mock(side_effect=get_changeset_mock)
    return changesets_mock


def get_deploykeys():
    deploykeys_mock = Mock()
    deploykeys_mock.create = Mock(return_value=(True, {"pk": "someprivatekey"}, None))
    deploykeys_mock.delete = Mock(return_value=(True, {}, None))
    return deploykeys_mock


def get_webhooks():
    webhooks_mock = Mock()
    webhooks_mock.create = Mock(return_value=(True, {"uuid": "someuuid"}, None))
    webhooks_mock.delete = Mock(return_value=(True, {}, None))
    return webhooks_mock


def get_repo_mock(name):
    if name != "bar":
        return None

    repo_mock = Mock()
    repo_mock.get_main_branch = Mock(return_value=(True, {"name": "master"}, None))
    repo_mock.get_path_contents = Mock(side_effect=get_repo_path_contents)
    repo_mock.get_raw_path_contents = Mock(side_effect=get_raw_path_contents)
    repo_mock.get_branches_and_tags = Mock(side_effect=get_branches_and_tags)
    repo_mock.get_branches = Mock(side_effect=get_branches)
    repo_mock.get_tags = Mock(side_effect=get_tags)
    repo_mock.get_branch = Mock(side_effect=get_branch)
    repo_mock.get_tag = Mock(side_effect=get_tag)
    repo_mock.changesets = Mock(side_effect=get_changesets)
    repo_mock.deploykeys = Mock(side_effect=get_deploykeys)
    repo_mock.webhooks = Mock(side_effect=get_webhooks)

    return repo_mock


def get_repositories_mock():
    repos_mock = Mock()
    repos_mock.get = Mock(side_effect=get_repo_mock)
    return repos_mock


def get_namespace_mock(namespace):
    namespace_mock = Mock()
    namespace_mock.repositories = Mock(side_effect=get_repositories_mock)
    return namespace_mock


def get_repo(namespace, name):
    return {
        "owner": namespace,
        "logo": "avatarurl",
        "slug": name,
        "description": "some %s repo" % (name),
        "utc_last_updated": str(datetime.utcfromtimestamp(0)),
        "read_only": namespace != "knownuser",
        "is_private": name == "somerepo",
    }


def get_visible_repos():
    repos = [
        get_repo("knownuser", "somerepo"),
        get_repo("someorg", "somerepo"),
        get_repo("someorg", "anotherrepo"),
    ]
    return (True, repos, None)


def get_authed_mock(token, secret):
    authed_mock = Mock()
    authed_mock.for_namespace = Mock(side_effect=get_namespace_mock)
    authed_mock.get_visible_repositories = Mock(side_effect=get_visible_repos)
    return authed_mock


def get_mock_bitbucket():
    bitbucket_mock = Mock()
    bitbucket_mock.get_authorized_client = Mock(side_effect=get_authed_mock)
    return bitbucket_mock
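Inside this test package, the mock client above is consumed entirely through get_bitbucket_trigger(), so handler methods can be exercised without any network access. A usage sketch mirroring the assertions in the test files that follow:

trigger = get_bitbucket_trigger()
assert trigger.list_build_subdirs() == ["/Dockerfile"]
# Branch names come back from the mocked get_branches() dict, so compare as sets.
assert set(trigger.list_field_values("branch_name")) == set(["master", "otherbranch"])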

View file

@@ -6,173 +6,178 @@ from github import GithubException
from buildtrigger.githubhandler import GithubBuildTrigger
from util.morecollections import AttrDict


def get_github_trigger(dockerfile_path=""):
    trigger_obj = AttrDict(dict(auth_token="foobar", id="sometrigger"))
    trigger = GithubBuildTrigger(
        trigger_obj, {"build_source": "foo", "dockerfile_path": dockerfile_path}
    )
    trigger._get_client = get_mock_github
    return trigger


def get_mock_github():
    def get_commit_mock(commit_sha):
        if commit_sha == "aaaaaaa":
            commit_mock = Mock()
            commit_mock.sha = commit_sha
            commit_mock.html_url = "http://url/to/commit"
            commit_mock.last_modified = "now"

            commit_mock.commit = Mock()
            commit_mock.commit.message = "some cool message"

            commit_mock.committer = Mock()
            commit_mock.committer.login = "someuser"
            commit_mock.committer.avatar_url = "avatarurl"
            commit_mock.committer.html_url = "htmlurl"

            commit_mock.author = Mock()
            commit_mock.author.login = "someuser"
            commit_mock.author.avatar_url = "avatarurl"
            commit_mock.author.html_url = "htmlurl"
            return commit_mock

        raise GithubException(None, None)

    def get_branch_mock(branch_name):
        if branch_name == "master":
            branch_mock = Mock()
            branch_mock.commit = Mock()
            branch_mock.commit.sha = "aaaaaaa"
            return branch_mock

        raise GithubException(None, None)

    def get_repo_mock(namespace, name):
        repo_mock = Mock()
        repo_mock.owner = Mock()
        repo_mock.owner.login = namespace
        repo_mock.full_name = "%s/%s" % (namespace, name)
        repo_mock.name = name
        repo_mock.description = "some %s repo" % (name)

        if name != "anotherrepo":
            repo_mock.pushed_at = datetime.utcfromtimestamp(0)
        else:
            repo_mock.pushed_at = None

        repo_mock.html_url = "https://bitbucket.org/%s/%s" % (namespace, name)
        repo_mock.private = name == "somerepo"
        repo_mock.permissions = Mock()
        repo_mock.permissions.admin = namespace == "knownuser"
        return repo_mock

    def get_user_repos_mock(type="all", sort="created"):
        return [get_repo_mock("knownuser", "somerepo")]

    def get_org_repos_mock(type="all"):
        return [
            get_repo_mock("someorg", "somerepo"),
            get_repo_mock("someorg", "anotherrepo"),
        ]

    def get_orgs_mock():
        return [get_org_mock("someorg")]

    def get_user_mock(username="knownuser"):
        if username == "knownuser":
            user_mock = Mock()
            user_mock.name = username
            user_mock.plan = Mock()
            user_mock.plan.private_repos = 1
            user_mock.login = username
            user_mock.html_url = "https://bitbucket.org/%s" % (username)
            user_mock.avatar_url = "avatarurl"
            user_mock.get_repos = Mock(side_effect=get_user_repos_mock)
            user_mock.get_orgs = Mock(side_effect=get_orgs_mock)
            return user_mock

        raise GithubException(None, None)

    def get_org_mock(namespace):
        if namespace == "someorg":
            org_mock = Mock()
            org_mock.get_repos = Mock(side_effect=get_org_repos_mock)
            org_mock.login = namespace
            org_mock.html_url = "https://bitbucket.org/%s" % (namespace)
            org_mock.avatar_url = "avatarurl"
            org_mock.name = namespace
            org_mock.plan = Mock()
            org_mock.plan.private_repos = 2
            return org_mock

        raise GithubException(None, None)

    def get_tags_mock():
        sometag = Mock()
        sometag.name = "sometag"
        sometag.commit = get_commit_mock("aaaaaaa")

        someothertag = Mock()
        someothertag.name = "someothertag"
        someothertag.commit = get_commit_mock("aaaaaaa")
        return [sometag, someothertag]

    def get_branches_mock():
        master = Mock()
        master.name = "master"
        master.commit = get_commit_mock("aaaaaaa")

        otherbranch = Mock()
        otherbranch.name = "otherbranch"
        otherbranch.commit = get_commit_mock("aaaaaaa")
        return [master, otherbranch]

    def get_contents_mock(filepath):
        if filepath == "Dockerfile":
            m = Mock()
            m.content = "hello world"
            return m

        if filepath == "somesubdir/Dockerfile":
            m = Mock()
            m.content = "hi universe"
            return m

        raise GithubException(None, None)

    def get_git_tree_mock(commit_sha, recursive=False):
        first_file = Mock()
        first_file.type = "blob"
        first_file.path = "Dockerfile"

        second_file = Mock()
        second_file.type = "other"
        second_file.path = "/some/Dockerfile"

        third_file = Mock()
        third_file.type = "blob"
        third_file.path = "somesubdir/Dockerfile"

        t = Mock()
        if commit_sha == "aaaaaaa":
            t.tree = [first_file, second_file, third_file]
        else:
            t.tree = []

        return t

    repo_mock = Mock()
    repo_mock.default_branch = "master"
    repo_mock.ssh_url = "ssh_url"

    repo_mock.get_branch = Mock(side_effect=get_branch_mock)
    repo_mock.get_tags = Mock(side_effect=get_tags_mock)
    repo_mock.get_branches = Mock(side_effect=get_branches_mock)
    repo_mock.get_commit = Mock(side_effect=get_commit_mock)
    repo_mock.get_contents = Mock(side_effect=get_contents_mock)
    repo_mock.get_git_tree = Mock(side_effect=get_git_tree_mock)

    gh_mock = Mock()
    gh_mock.get_repo = Mock(return_value=repo_mock)
    gh_mock.get_user = Mock(side_effect=get_user_mock)
    gh_mock.get_organization = Mock(side_effect=get_org_mock)
    return gh_mock
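Both mock modules lean on Mock(side_effect=...), which makes calls on the Mock pass through to the given function (returning its result, or raising whatever it raises). A minimal sketch of that mechanism, with a hypothetical lookup function:

from mock import Mock

def get_branch_sha_mock(branch_name):
    # Hypothetical lookup: succeed for "master", fail otherwise.
    if branch_name == "master":
        return "aaaaaaa"
    raise KeyError(branch_name)

repo = Mock()
repo.get_branch_sha = Mock(side_effect=get_branch_sha_mock)
assert repo.get_branch_sha("master") == "aaaaaaa"  # the side_effect's return value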

File diff suppressed because it is too large

View file

@@ -3,53 +3,74 @@ import pytest
from buildtrigger.basehandler import BuildTriggerHandler


@pytest.mark.parametrize(
    "input,output",
    [
        ("Dockerfile", True),
        ("server.Dockerfile", True),
        (u"Dockerfile", True),
        (u"server.Dockerfile", True),
        ("bad file name", False),
        (u"bad file name", False),
    ],
)
def test_path_is_dockerfile(input, output):
    assert BuildTriggerHandler.filename_is_dockerfile(input) == output


@pytest.mark.parametrize(
    "input,output",
    [
        ("", {}),
        ("/a", {"/a": ["/"]}),
        ("a", {"/a": ["/"]}),
        ("/b/a", {"/b/a": ["/b", "/"]}),
        ("b/a", {"/b/a": ["/b", "/"]}),
        ("/c/b/a", {"/c/b/a": ["/c/b", "/c", "/"]}),
        ("/a//b//c", {"/a/b/c": ["/", "/a", "/a/b"]}),
        ("/a", {"/a": ["/"]}),
    ],
)
def test_subdir_path_map_no_previous(input, output):
    actual_mapping = BuildTriggerHandler.get_parent_directory_mappings(input)
    for key in actual_mapping:
        value = actual_mapping[key]
        actual_mapping[key] = value.sort()
    for key in output:
        value = output[key]
        output[key] = value.sort()

    assert actual_mapping == output


@pytest.mark.parametrize(
    "new_path,original_dictionary,output",
    [
        ("/a", {}, {"/a": ["/"]}),
        (
            "b",
            {"/a": ["some_path", "another_path"]},
            {"/a": ["some_path", "another_path"], "/b": ["/"]},
        ),
        (
            "/a/b/c/d",
            {"/e": ["some_path", "another_path"]},
            {
                "/e": ["some_path", "another_path"],
                "/a/b/c/d": ["/", "/a", "/a/b", "/a/b/c"],
            },
        ),
    ],
)
def test_subdir_path_map(new_path, original_dictionary, output):
    actual_mapping = BuildTriggerHandler.get_parent_directory_mappings(
        new_path, original_dictionary
    )
    for key in actual_mapping:
        value = actual_mapping[key]
        actual_mapping[key] = value.sort()
    for key in output:
        value = output[key]
        output[key] = value.sort()

    assert actual_mapping == output
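The parametrized cases above pin down the contract of get_parent_directory_mappings: each normalized path maps to the list of its parent directories, optionally merged into an existing mapping. Reading one case from the table as a standalone assertion:

from buildtrigger.basehandler import BuildTriggerHandler

mapping = BuildTriggerHandler.get_parent_directory_mappings("/c/b/a")
# Order is not significant in the tests above, so compare sorted copies.
assert sorted(mapping["/c/b/a"]) == sorted(["/c/b", "/c", "/"])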

View file

@@ -2,35 +2,44 @@ import json
import pytest

from buildtrigger.test.bitbucketmock import get_bitbucket_trigger
from buildtrigger.triggerutil import (
    SkipRequestException,
    ValidationRequestException,
    InvalidPayloadException,
)
from endpoints.building import PreparedBuild
from util.morecollections import AttrDict


@pytest.fixture
def bitbucket_trigger():
    return get_bitbucket_trigger()


def test_list_build_subdirs(bitbucket_trigger):
    assert bitbucket_trigger.list_build_subdirs() == ["/Dockerfile"]


@pytest.mark.parametrize(
    "dockerfile_path, contents",
    [
        ("/Dockerfile", "hello world"),
        ("somesubdir/Dockerfile", "hi universe"),
        ("unknownpath", None),
    ],
)
def test_load_dockerfile_contents(dockerfile_path, contents):
    trigger = get_bitbucket_trigger(dockerfile_path)
    assert trigger.load_dockerfile_contents() == contents


@pytest.mark.parametrize(
    "payload, expected_error, expected_message",
    [
        ("{}", InvalidPayloadException, "'push' is a required property"),
        # Valid payload:
        (
            """{
  "push": {
    "changes": [{
      "new": {
@@ -51,10 +60,13 @@ def test_load_dockerfile_contents(dockerfile_path, contents):
  "repository": {
    "full_name": "foo/bar"
  }
}""",
            None,
            None,
        ),
        # Skip message:
        (
            """{
  "push": {
    "changes": [{
      "new": {
@@ -75,17 +87,25 @@ def test_load_dockerfile_contents(dockerfile_path, contents):
  "repository": {
    "full_name": "foo/bar"
  }
}""",
            SkipRequestException,
            "",
        ),
    ],
)
def test_handle_trigger_request(
    bitbucket_trigger, payload, expected_error, expected_message
):
    def get_payload():
        return json.loads(payload)

    request = AttrDict(dict(get_json=get_payload))

    if expected_error is not None:
        with pytest.raises(expected_error) as ipe:
            bitbucket_trigger.handle_trigger_request(request)
        assert str(ipe.value) == expected_message
    else:
        assert isinstance(
            bitbucket_trigger.handle_trigger_request(request), PreparedBuild
        )

View file

@@ -1,20 +1,32 @@
import pytest

from buildtrigger.customhandler import CustomBuildTrigger
from buildtrigger.triggerutil import (
    InvalidPayloadException,
    SkipRequestException,
    TriggerStartException,
)
from endpoints.building import PreparedBuild
from util.morecollections import AttrDict


@pytest.mark.parametrize(
    "payload, expected_error, expected_message",
    [
        ("", InvalidPayloadException, "Missing expected payload"),
        ("{}", InvalidPayloadException, "'commit' is a required property"),
        (
            '{"commit": "foo", "ref": "refs/heads/something", "default_branch": "baz"}',
            InvalidPayloadException,
            "u'foo' does not match '^([A-Fa-f0-9]{7,})$'",
        ),
        (
            '{"commit": "11d6fbc", "ref": "refs/heads/something", "default_branch": "baz"}',
            None,
            None,
        ),
        (
            """{
  "commit": "11d6fbc",
  "ref": "refs/heads/something",
  "default_branch": "baz",
@@ -23,29 +35,41 @@ from util.morecollections import AttrDict
    "url": "http://foo.bar",
    "date": "NOW"
  }
}""",
            SkipRequestException,
            "",
        ),
    ],
)
def test_handle_trigger_request(payload, expected_error, expected_message):
    trigger = CustomBuildTrigger(None, {"build_source": "foo"})
    request = AttrDict(dict(data=payload))

    if expected_error is not None:
        with pytest.raises(expected_error) as ipe:
            trigger.handle_trigger_request(request)
        assert str(ipe.value) == expected_message
    else:
        assert isinstance(trigger.handle_trigger_request(request), PreparedBuild)


@pytest.mark.parametrize(
    "run_parameters, expected_error, expected_message",
    [
        ({}, TriggerStartException, "missing required parameter"),
        (
            {"commit_sha": "foo"},
            TriggerStartException,
            "'foo' does not match '^([A-Fa-f0-9]{7,})$'",
        ),
        ({"commit_sha": "11d6fbc"}, None, None),
    ],
)
def test_manual_start(run_parameters, expected_error, expected_message):
    trigger = CustomBuildTrigger(None, {"build_source": "foo"})
    if expected_error is not None:
        with pytest.raises(expected_error) as ipe:
            trigger.manual_start(run_parameters)
        assert str(ipe.value) == expected_message
    else:
        assert isinstance(trigger.manual_start(run_parameters), PreparedBuild)

View file

@@ -9,113 +9,145 @@ from endpoints.building import PreparedBuild
# in this fixture. Each trigger's mock is expected to return the same data for all of these calls.
@pytest.fixture(params=[get_github_trigger(), get_bitbucket_trigger()])
def githost_trigger(request):
    return request.param


@pytest.mark.parametrize(
    "run_parameters, expected_error, expected_message",
    [
        # No branch or tag specified: use the commit of the default branch.
        ({}, None, None),
        # Invalid branch.
        (
            {"refs": {"kind": "branch", "name": "invalid"}},
            TriggerStartException,
            "Could not find branch in repository",
        ),
        # Invalid tag.
        (
            {"refs": {"kind": "tag", "name": "invalid"}},
            TriggerStartException,
            "Could not find tag in repository",
        ),
        # Valid branch.
        ({"refs": {"kind": "branch", "name": "master"}}, None, None),
        # Valid tag.
        ({"refs": {"kind": "tag", "name": "sometag"}}, None, None),
    ],
)
def test_manual_start(
    run_parameters, expected_error, expected_message, githost_trigger
):
    if expected_error is not None:
        with pytest.raises(expected_error) as ipe:
            githost_trigger.manual_start(run_parameters)
        assert str(ipe.value) == expected_message
    else:
        assert isinstance(githost_trigger.manual_start(run_parameters), PreparedBuild)


@pytest.mark.parametrize(
    "name, expected",
    [
        (
            "refs",
            [
                {"kind": "branch", "name": "master"},
                {"kind": "branch", "name": "otherbranch"},
                {"kind": "tag", "name": "sometag"},
                {"kind": "tag", "name": "someothertag"},
            ],
        ),
        ("tag_name", set(["sometag", "someothertag"])),
        ("branch_name", set(["master", "otherbranch"])),
        ("invalid", None),
    ],
)
def test_list_field_values(name, expected, githost_trigger):
    if expected is None:
        assert githost_trigger.list_field_values(name) is None
    elif isinstance(expected, set):
        assert set(githost_trigger.list_field_values(name)) == set(expected)
    else:
        assert githost_trigger.list_field_values(name) == expected


def test_list_build_source_namespaces():
    namespaces_expected = [
        {
            "personal": True,
            "score": 1,
            "avatar_url": "avatarurl",
            "id": "knownuser",
            "title": "knownuser",
            "url": "https://bitbucket.org/knownuser",
        },
        {
            "score": 2,
            "title": "someorg",
            "personal": False,
            "url": "https://bitbucket.org/someorg",
            "avatar_url": "avatarurl",
            "id": "someorg",
        },
    ]

    found = get_bitbucket_trigger().list_build_source_namespaces()
    found.sort()

    namespaces_expected.sort()
    assert found == namespaces_expected


@pytest.mark.parametrize(
    "namespace, expected",
    [
        ("", []),
        ("unknown", []),
        (
            "knownuser",
            [
                {
                    "last_updated": 0,
                    "name": "somerepo",
                    "url": "https://bitbucket.org/knownuser/somerepo",
                    "private": True,
                    "full_name": "knownuser/somerepo",
                    "has_admin_permissions": True,
                    "description": "some somerepo repo",
                }
            ],
        ),
        (
            "someorg",
            [
                {
                    "last_updated": 0,
                    "name": "somerepo",
                    "url": "https://bitbucket.org/someorg/somerepo",
                    "private": True,
                    "full_name": "someorg/somerepo",
                    "has_admin_permissions": False,
                    "description": "some somerepo repo",
                },
                {
                    "last_updated": 0,
                    "name": "anotherrepo",
                    "url": "https://bitbucket.org/someorg/anotherrepo",
                    "private": False,
                    "full_name": "someorg/anotherrepo",
                    "has_admin_permissions": False,
                    "description": "some anotherrepo repo",
                },
            ],
        ),
    ],
)
def test_list_build_sources_for_namespace(namespace, expected, githost_trigger):
    assert githost_trigger.list_build_sources_for_namespace(namespace) == expected


def test_activate_and_deactivate(githost_trigger):
    _, private_key = githost_trigger.activate("http://some/url")
    assert "private_key" in private_key

    githost_trigger.deactivate()
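The round trip asserted above relies on activate() returning a (config, private_config) pair, as CustomBuildTrigger.activate does earlier in this commit: the public half of the credentials is stored in the trigger config and the private half is handed back exactly once. A self-contained sketch of that contract (FakeTrigger and the key strings are hypothetical stand-ins, not the real BuildTriggerHandler):

class FakeTrigger(object):
    # Hypothetical minimal model of the activate()/deactivate() contract.
    def __init__(self):
        self.config = {}

    def activate(self, webhook_url):
        # Placeholder key material; the real handler calls generate_ssh_keypair().
        public_key, private_key = ("ssh-rsa AAAA...", "-----BEGIN KEY-----")
        self.config["credentials"] = [
            {"name": "SSH Public Key", "value": public_key},
            {"name": "Webhook Endpoint URL", "value": webhook_url},
        ]
        return self.config, {"private_key": private_key}

    def deactivate(self):
        self.config.pop("credentials", None)
        return self.config


trigger = FakeTrigger()
_, private = trigger.activate("http://some/url")
assert "private_key" in private
trigger.deactivate()
assert "credentials" not in trigger.config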

View file

@@ -2,24 +2,33 @@ import json
import pytest

from buildtrigger.test.githubmock import get_github_trigger
from buildtrigger.triggerutil import (
    SkipRequestException,
    ValidationRequestException,
    InvalidPayloadException,
)
from endpoints.building import PreparedBuild
from util.morecollections import AttrDict


@pytest.fixture
def github_trigger():
    return get_github_trigger()


@pytest.mark.parametrize(
    "payload, expected_error, expected_message",
    [
        ('{"zen": true}', SkipRequestException, ""),
        ("{}", InvalidPayloadException, "Missing 'repository' on request"),
        (
            '{"repository": "foo"}',
            InvalidPayloadException,
            "Missing 'owner' on repository",
        ),
        # Valid payload:
        (
            """{
  "repository": {
    "owner": {
      "name": "someguy"
@@ -34,10 +43,13 @@ def github_trigger():
    "message": "some message",
    "timestamp": "NOW"
  }
}""",
            None,
            None,
        ),
        # Skip message:
        (
            """{
  "repository": {
    "owner": {
      "name": "someguy"
@@ -52,66 +64,84 @@ def github_trigger():
    "message": "[skip build]",
    "timestamp": "NOW"
  }
}""",
            SkipRequestException,
            "",
        ),
    ],
)
def test_handle_trigger_request(
    github_trigger, payload, expected_error, expected_message
):
    def get_payload():
        return json.loads(payload)

    request = AttrDict(dict(get_json=get_payload))

    if expected_error is not None:
        with pytest.raises(expected_error) as ipe:
            github_trigger.handle_trigger_request(request)
        assert str(ipe.value) == expected_message
    else:
        assert isinstance(github_trigger.handle_trigger_request(request), PreparedBuild)


@pytest.mark.parametrize(
    "dockerfile_path, contents",
    [
        ("/Dockerfile", "hello world"),
        ("somesubdir/Dockerfile", "hi universe"),
        ("unknownpath", None),
    ],
)
def test_load_dockerfile_contents(dockerfile_path, contents):
    trigger = get_github_trigger(dockerfile_path)
    assert trigger.load_dockerfile_contents() == contents


@pytest.mark.parametrize(
    "username, expected_response",
    [
        ("unknownuser", None),
        (
            "knownuser",
            {"html_url": "https://bitbucket.org/knownuser", "avatar_url": "avatarurl"},
        ),
    ],
)
def test_lookup_user(username, expected_response, github_trigger):
    assert github_trigger.lookup_user(username) == expected_response


def test_list_build_subdirs(github_trigger):
    assert github_trigger.list_build_subdirs() == [
        "Dockerfile",
        "somesubdir/Dockerfile",
    ]


def test_list_build_source_namespaces(github_trigger):
    namespaces_expected = [
        {
            "personal": True,
            "score": 1,
            "avatar_url": "avatarurl",
            "id": "knownuser",
            "title": "knownuser",
            "url": "https://bitbucket.org/knownuser",
        },
        {
            "score": 0,
            "title": "someorg",
            "personal": False,
            "url": "",
            "avatar_url": "avatarurl",
            "id": "someorg",
        },
    ]

    found = github_trigger.list_build_source_namespaces()
    found.sort()

    namespaces_expected.sort()
    assert found == namespaces_expected

View file

@@ -4,91 +4,111 @@ import pytest
from mock import Mock

from buildtrigger.test.gitlabmock import get_gitlab_trigger
from buildtrigger.triggerutil import (
    SkipRequestException,
    ValidationRequestException,
    InvalidPayloadException,
    TriggerStartException,
)
from endpoints.building import PreparedBuild
from util.morecollections import AttrDict


@pytest.fixture()
def gitlab_trigger():
    with get_gitlab_trigger() as t:
        yield t


def test_list_build_subdirs(gitlab_trigger):
    assert gitlab_trigger.list_build_subdirs() == ["Dockerfile"]


@pytest.mark.parametrize(
    "dockerfile_path, contents",
    [
        ("/Dockerfile", "hello world"),
        ("somesubdir/Dockerfile", "hi universe"),
        ("unknownpath", None),
    ],
)
def test_load_dockerfile_contents(dockerfile_path, contents):
    with get_gitlab_trigger(dockerfile_path=dockerfile_path) as trigger:
        assert trigger.load_dockerfile_contents() == contents


@pytest.mark.parametrize(
    "email, expected_response",
    [
        ("unknown@email.com", None),
        (
            "knownuser",
            {
                "username": "knownuser",
                "html_url": "https://bitbucket.org/knownuser",
                "avatar_url": "avatarurl",
            },
        ),
    ],
)
def test_lookup_user(email, expected_response, gitlab_trigger):
    assert gitlab_trigger.lookup_user(email) == expected_response


def test_null_permissions():
    with get_gitlab_trigger(add_permissions=False) as trigger:
        sources = trigger.list_build_sources_for_namespace("someorg")
        source = sources[0]
        assert source["has_admin_permissions"]


def test_list_build_sources():
    with get_gitlab_trigger() as trigger:
        sources = trigger.list_build_sources_for_namespace("someorg")
        assert sources == [
            {
                "last_updated": 1380548762,
                "name": u"someproject",
                "url": u"http://example.com/someorg/someproject",
                "private": True,
                "full_name": u"someorg/someproject",
                "has_admin_permissions": False,
                "description": "",
            },
            {
                "last_updated": 1380548762,
                "name": u"anotherproject",
                "url": u"http://example.com/someorg/anotherproject",
                "private": False,
                "full_name": u"someorg/anotherproject",
                "has_admin_permissions": True,
                "description": "",
            },
        ]


def test_null_avatar():
    with get_gitlab_trigger(missing_avatar_url=True) as trigger:
        namespace_data = trigger.list_build_source_namespaces()
        expected = {
            "avatar_url": None,
            "personal": False,
            "title": u"someorg",
            "url": u"http://gitlab.com/groups/someorg",
            "score": 1,
            "id": "2",
        }

        assert namespace_data == [expected]


@pytest.mark.parametrize(
    "payload, expected_error, expected_message",
    [
        ("{}", InvalidPayloadException, ""),
        # Valid payload:
        (
            """{
  "object_kind": "push",
  "ref": "refs/heads/master",
  "checkout_sha": "aaaaaaa",
@@ -103,10 +123,13 @@ def test_null_avatar():
      "timestamp": "now"
    }
  ]
}""",
            None,
            None,
        ),
        # Skip message:
        (
            """{
  "object_kind": "push",
  "ref": "refs/heads/master",
  "checkout_sha": "aaaaaaa",
@@ -121,111 +144,136 @@ def test_null_avatar():
      "timestamp": "now"
    }
  ]
}""",
            SkipRequestException,
            "",
        ),
    ],
)
def test_handle_trigger_request(
    gitlab_trigger, payload, expected_error, expected_message
):
    def get_payload():
        return json.loads(payload)

    request = AttrDict(dict(get_json=get_payload))

    if expected_error is not None:
        with pytest.raises(expected_error) as ipe:
            gitlab_trigger.handle_trigger_request(request)
        assert str(ipe.value) == expected_message
    else:
        assert isinstance(gitlab_trigger.handle_trigger_request(request), PreparedBuild)


@pytest.mark.parametrize(
    "run_parameters, expected_error, expected_message",
    [
        # No branch or tag specified: use the commit of the default branch.
        ({}, None, None),
        # Invalid branch.
        (
            {"refs": {"kind": "branch", "name": "invalid"}},
            TriggerStartException,
            "Could not find branch in repository",
        ),
        # Invalid tag.
        (
            {"refs": {"kind": "tag", "name": "invalid"}},
            TriggerStartException,
            "Could not find tag in repository",
        ),
        # Valid branch.
        ({"refs": {"kind": "branch", "name": "master"}}, None, None),
        # Valid tag.
        ({"refs": {"kind": "tag", "name": "sometag"}}, None, None),
    ],
)
def test_manual_start(run_parameters, expected_error, expected_message, gitlab_trigger):
if expected_error is not None:
with pytest.raises(expected_error) as ipe:
gitlab_trigger.manual_start(run_parameters)
assert str(ipe.value) == expected_message
else:
assert isinstance(gitlab_trigger.manual_start(run_parameters), PreparedBuild)
if expected_error is not None:
with pytest.raises(expected_error) as ipe:
gitlab_trigger.manual_start(run_parameters)
assert str(ipe.value) == expected_message
else:
assert isinstance(gitlab_trigger.manual_start(run_parameters), PreparedBuild)
def test_activate_and_deactivate(gitlab_trigger):
_, private_key = gitlab_trigger.activate('http://some/url')
assert 'private_key' in private_key
_, private_key = gitlab_trigger.activate("http://some/url")
assert "private_key" in private_key
gitlab_trigger.deactivate()
gitlab_trigger.deactivate()
@pytest.mark.parametrize('name, expected', [
('refs', [
{'kind': 'branch', 'name': 'master'},
{'kind': 'branch', 'name': 'otherbranch'},
{'kind': 'tag', 'name': 'sometag'},
{'kind': 'tag', 'name': 'someothertag'},
]),
('tag_name', set(['sometag', 'someothertag'])),
('branch_name', set(['master', 'otherbranch'])),
('invalid', None)
])
@pytest.mark.parametrize(
"name, expected",
[
(
"refs",
[
{"kind": "branch", "name": "master"},
{"kind": "branch", "name": "otherbranch"},
{"kind": "tag", "name": "sometag"},
{"kind": "tag", "name": "someothertag"},
],
),
("tag_name", set(["sometag", "someothertag"])),
("branch_name", set(["master", "otherbranch"])),
("invalid", None),
],
)
def test_list_field_values(name, expected, gitlab_trigger):
if expected is None:
assert gitlab_trigger.list_field_values(name) is None
elif isinstance(expected, set):
assert set(gitlab_trigger.list_field_values(name)) == set(expected)
else:
assert gitlab_trigger.list_field_values(name) == expected
if expected is None:
assert gitlab_trigger.list_field_values(name) is None
elif isinstance(expected, set):
assert set(gitlab_trigger.list_field_values(name)) == set(expected)
else:
assert gitlab_trigger.list_field_values(name) == expected
@pytest.mark.parametrize('namespace, expected', [
('', []),
('unknown', []),
('knownuser', [
{
'last_updated': 1380548762,
'name': u'anotherproject',
'url': u'http://example.com/knownuser/anotherproject',
'private': False,
'full_name': u'knownuser/anotherproject',
'has_admin_permissions': True,
'description': ''
},
]),
('someorg', [
{
'last_updated': 1380548762,
'name': u'someproject',
'url': u'http://example.com/someorg/someproject',
'private': True,
'full_name': u'someorg/someproject',
'has_admin_permissions': False,
'description': ''
},
{
'last_updated': 1380548762,
'name': u'anotherproject',
'url': u'http://example.com/someorg/anotherproject',
'private': False,
'full_name': u'someorg/anotherproject',
'has_admin_permissions': True,
'description': '',
}]),
])
@pytest.mark.parametrize(
"namespace, expected",
[
("", []),
("unknown", []),
(
"knownuser",
[
{
"last_updated": 1380548762,
"name": u"anotherproject",
"url": u"http://example.com/knownuser/anotherproject",
"private": False,
"full_name": u"knownuser/anotherproject",
"has_admin_permissions": True,
"description": "",
}
],
),
(
"someorg",
[
{
"last_updated": 1380548762,
"name": u"someproject",
"url": u"http://example.com/someorg/someproject",
"private": True,
"full_name": u"someorg/someproject",
"has_admin_permissions": False,
"description": "",
},
{
"last_updated": 1380548762,
"name": u"anotherproject",
"url": u"http://example.com/someorg/anotherproject",
"private": False,
"full_name": u"someorg/anotherproject",
"has_admin_permissions": True,
"description": "",
},
],
),
],
)
def test_list_build_sources_for_namespace(namespace, expected, gitlab_trigger):
assert gitlab_trigger.list_build_sources_for_namespace(namespace) == expected
assert gitlab_trigger.list_build_sources_for_namespace(namespace) == expected
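black normalizes every @pytest.mark.parametrize decorator the same way: the call is exploded one argument per line once it would exceed 88 columns, strings become double-quoted, and each multi-line collection gains a trailing comma. A minimal runnable sketch of the resulting style (the data here is hypothetical, not taken from the Quay suite):

import pytest


@pytest.mark.parametrize(
    "namespace, expected_count",
    [
        ("", 0),
        ("someorg", 2),
    ],
)
def test_example_counts(namespace, expected_count):
    # Hypothetical stand-in for a trigger's source listing.
    fake_sources = {"someorg": ["someproject", "anotherproject"]}
    assert len(fake_sources.get(namespace, [])) == expected_count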

File diff suppressed because it is too large

View file

@ -4,22 +4,43 @@ import pytest
from buildtrigger.triggerutil import matches_ref
@pytest.mark.parametrize('ref, filt, matches', [
('ref/heads/master', '.+', True),
('ref/heads/master', 'heads/.+', True),
('ref/heads/master', 'heads/master', True),
('ref/heads/slash/branch', 'heads/slash/branch', True),
('ref/heads/slash/branch', 'heads/.+', True),
('ref/heads/foobar', 'heads/master', False),
('ref/heads/master', 'tags/master', False),
('ref/heads/master', '(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)', True),
('ref/heads/alpha', '(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)', True),
('ref/heads/beta', '(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)', True),
('ref/heads/gamma', '(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)', True),
('ref/heads/delta', '(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)', False),
])
@pytest.mark.parametrize(
"ref, filt, matches",
[
("ref/heads/master", ".+", True),
("ref/heads/master", "heads/.+", True),
("ref/heads/master", "heads/master", True),
("ref/heads/slash/branch", "heads/slash/branch", True),
("ref/heads/slash/branch", "heads/.+", True),
("ref/heads/foobar", "heads/master", False),
("ref/heads/master", "tags/master", False),
(
"ref/heads/master",
"(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)",
True,
),
(
"ref/heads/alpha",
"(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)",
True,
),
(
"ref/heads/beta",
"(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)",
True,
),
(
"ref/heads/gamma",
"(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)",
True,
),
(
"ref/heads/delta",
"(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)",
False,
),
],
)
def test_matches_ref(ref, filt, matches):
assert matches_ref(ref, re.compile(filt)) == matches
assert matches_ref(ref, re.compile(filt)) == matches
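matches_ref() (defined in buildtrigger/triggerutil.py, reformatted below) deliberately compares match lengths rather than relying on regex.match() alone, because match() only anchors at the start of the string. A quick sketch of the difference, using the import already present in this test file:

import re

from buildtrigger.triggerutil import matches_ref

# The full-length comparison rejects prefixes that a bare match() accepts.
assert matches_ref("ref/heads/master", re.compile("heads/master"))
assert not matches_ref("ref/heads/master-two", re.compile("heads/master"))
assert re.compile("heads/master").match("heads/master-two") is not None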

View file

@ -3,128 +3,146 @@ import io
import logging
import re
class TriggerException(Exception):
pass
pass
class TriggerAuthException(TriggerException):
pass
pass
class InvalidPayloadException(TriggerException):
pass
pass
class BuildArchiveException(TriggerException):
pass
pass
class InvalidServiceException(TriggerException):
pass
pass
class TriggerActivationException(TriggerException):
pass
pass
class TriggerDeactivationException(TriggerException):
pass
pass
class TriggerStartException(TriggerException):
pass
pass
class ValidationRequestException(TriggerException):
pass
pass
class SkipRequestException(TriggerException):
pass
pass
class EmptyRepositoryException(TriggerException):
pass
pass
class RepositoryReadException(TriggerException):
pass
pass
class TriggerProviderException(TriggerException):
pass
pass
logger = logging.getLogger(__name__)
def determine_build_ref(run_parameters, get_branch_sha, get_tag_sha, default_branch):
run_parameters = run_parameters or {}
run_parameters = run_parameters or {}
kind = ''
value = ''
kind = ""
value = ""
if 'refs' in run_parameters and run_parameters['refs']:
kind = run_parameters['refs']['kind']
value = run_parameters['refs']['name']
elif 'branch_name' in run_parameters:
kind = 'branch'
value = run_parameters['branch_name']
if "refs" in run_parameters and run_parameters["refs"]:
kind = run_parameters["refs"]["kind"]
value = run_parameters["refs"]["name"]
elif "branch_name" in run_parameters:
kind = "branch"
value = run_parameters["branch_name"]
kind = kind or 'branch'
value = value or default_branch or 'master'
kind = kind or "branch"
value = value or default_branch or "master"
ref = 'refs/tags/' + value if kind == 'tag' else 'refs/heads/' + value
commit_sha = get_tag_sha(value) if kind == 'tag' else get_branch_sha(value)
return (commit_sha, ref)
ref = "refs/tags/" + value if kind == "tag" else "refs/heads/" + value
commit_sha = get_tag_sha(value) if kind == "tag" else get_branch_sha(value)
return (commit_sha, ref)
def find_matching_branches(config, branches):
if 'branchtag_regex' in config:
try:
regex = re.compile(config['branchtag_regex'])
return [branch for branch in branches
if matches_ref('refs/heads/' + branch, regex)]
except:
pass
if "branchtag_regex" in config:
try:
regex = re.compile(config["branchtag_regex"])
return [
branch
for branch in branches
if matches_ref("refs/heads/" + branch, regex)
]
except:
pass
return branches
return branches
def should_skip_commit(metadata):
if 'commit_info' in metadata:
message = metadata['commit_info']['message']
return '[skip build]' in message or '[build skip]' in message
return False
if "commit_info" in metadata:
message = metadata["commit_info"]["message"]
return "[skip build]" in message or "[build skip]" in message
return False
def raise_if_skipped_build(prepared_build, config):
""" Raises a SkipRequestException if the given build should be skipped. """
# Check to ensure we have metadata.
if not prepared_build.metadata:
logger.debug('Skipping request due to missing metadata for prepared build')
raise SkipRequestException()
""" Raises a SkipRequestException if the given build should be skipped. """
# Check to ensure we have metadata.
if not prepared_build.metadata:
logger.debug("Skipping request due to missing metadata for prepared build")
raise SkipRequestException()
# Check the branchtag regex.
if 'branchtag_regex' in config:
try:
regex = re.compile(config['branchtag_regex'])
except:
regex = re.compile('.*')
# Check the branchtag regex.
if "branchtag_regex" in config:
try:
regex = re.compile(config["branchtag_regex"])
except:
regex = re.compile(".*")
if not matches_ref(prepared_build.metadata.get('ref'), regex):
raise SkipRequestException()
if not matches_ref(prepared_build.metadata.get("ref"), regex):
raise SkipRequestException()
# Check the commit message.
if should_skip_commit(prepared_build.metadata):
logger.debug('Skipping request due to commit message request')
raise SkipRequestException()
# Check the commit message.
if should_skip_commit(prepared_build.metadata):
logger.debug("Skipping request due to commit message request")
raise SkipRequestException()
def matches_ref(ref, regex):
match_string = ref.split('/', 1)[1]
if not regex:
return False
match_string = ref.split("/", 1)[1]
if not regex:
return False
m = regex.match(match_string)
if not m:
return False
m = regex.match(match_string)
if not m:
return False
return len(m.group(0)) == len(match_string)
return len(m.group(0)) == len(match_string)
def raise_unsupported():
raise io.UnsupportedOperation
raise io.UnsupportedOperation
def get_trigger_config(trigger):
try:
return json.loads(trigger.config)
except ValueError:
return {}
try:
return json.loads(trigger.config)
except ValueError:
return {}
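A usage sketch for determine_build_ref() as defined above: the SHA lookups are injected as callables, so each trigger backend only supplies branch and tag resolution. The lambdas below are illustrative stand-ins, not real Git lookups:

from buildtrigger.triggerutil import determine_build_ref

get_branch_sha = lambda name: "aaaaaaa" if name == "master" else None
get_tag_sha = lambda name: "bbbbbbb" if name == "sometag" else None

# Empty run parameters fall back to the default branch.
assert determine_build_ref({}, get_branch_sha, get_tag_sha, "master") == (
    "aaaaaaa",
    "refs/heads/master",
)

# An explicit tag ref resolves through get_tag_sha instead.
params = {"refs": {"kind": "tag", "name": "sometag"}}
assert determine_build_ref(params, get_branch_sha, get_tag_sha, "master") == (
    "bbbbbbb",
    "refs/tags/sometag",
)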

View file

@ -1,5 +1,6 @@
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging
@ -10,18 +11,24 @@ from util.workers import get_worker_count
logconfig = logfile_path(debug=True)
bind = '0.0.0.0:5000'
workers = get_worker_count('local', 2, minimum=2, maximum=8)
worker_class = 'gevent'
bind = "0.0.0.0:5000"
workers = get_worker_count("local", 2, minimum=2, maximum=8)
worker_class = "gevent"
daemon = False
pythonpath = '.'
pythonpath = "."
preload_app = True
def post_fork(server, worker):
# Reset the Random library to ensure it won't raise the "PID check failed." error after
# gunicorn forks.
Random.atfork()
# Reset the Random library to ensure it won't raise the "PID check failed." error after
# gunicorn forks.
Random.atfork()
def when_ready(server):
logger = logging.getLogger(__name__)
logger.debug('Starting local gunicorn with %s workers and %s worker class', workers, worker_class)
logger = logging.getLogger(__name__)
logger.debug(
"Starting local gunicorn with %s workers and %s worker class",
workers,
worker_class,
)
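The gunicorn configs in this and the following files all lean on get_worker_count() from util.workers. Its implementation is not part of this diff; the sketch below is only a hypothetical reading of the contract implied by the call sites (a service kind, a CPU multiplier, and a clamp to [minimum, maximum]), not the real function:

import multiprocessing


def get_worker_count_sketch(worker_kind, multiplier, minimum, maximum):
    # Hypothetical: scale with the CPU count, clamped to the given bounds.
    suggested = multiplier * multiprocessing.cpu_count()
    return max(minimum, min(suggested, maximum))


# e.g. on a 4-core host: min(2 * 4, 8) -> 8 workers for the "local" config above.
print(get_worker_count_sketch("local", 2, minimum=2, maximum=8))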

View file

@ -1,5 +1,6 @@
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging
@ -10,19 +11,23 @@ from util.workers import get_worker_count
logconfig = logfile_path(debug=False)
bind = 'unix:/tmp/gunicorn_registry.sock'
workers = get_worker_count('registry', 4, minimum=8, maximum=64)
worker_class = 'gevent'
pythonpath = '.'
bind = "unix:/tmp/gunicorn_registry.sock"
workers = get_worker_count("registry", 4, minimum=8, maximum=64)
worker_class = "gevent"
pythonpath = "."
preload_app = True
def post_fork(server, worker):
# Reset the Random library to ensure it won't raise the "PID check failed." error after
# gunicorn forks.
Random.atfork()
# Reset the Random library to ensure it won't raise the "PID check failed." error after
# gunicorn forks.
Random.atfork()
def when_ready(server):
logger = logging.getLogger(__name__)
logger.debug('Starting registry gunicorn with %s workers and %s worker class', workers,
worker_class)
logger = logging.getLogger(__name__)
logger.debug(
"Starting registry gunicorn with %s workers and %s worker class",
workers,
worker_class,
)

View file

@ -1,5 +1,6 @@
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging
@ -10,19 +11,23 @@ from util.workers import get_worker_count
logconfig = logfile_path(debug=False)
bind = 'unix:/tmp/gunicorn_secscan.sock'
workers = get_worker_count('secscan', 2, minimum=2, maximum=4)
worker_class = 'gevent'
pythonpath = '.'
bind = "unix:/tmp/gunicorn_secscan.sock"
workers = get_worker_count("secscan", 2, minimum=2, maximum=4)
worker_class = "gevent"
pythonpath = "."
preload_app = True
def post_fork(server, worker):
# Reset the Random library to ensure it won't raise the "PID check failed." error after
# gunicorn forks.
Random.atfork()
# Reset the Random library to ensure it won't raise the "PID check failed." error after
# gunicorn forks.
Random.atfork()
def when_ready(server):
logger = logging.getLogger(__name__)
logger.debug('Starting secscan gunicorn with %s workers and %s worker class', workers,
worker_class)
logger = logging.getLogger(__name__)
logger.debug(
"Starting secscan gunicorn with %s workers and %s worker class",
workers,
worker_class,
)

View file

@ -1,5 +1,6 @@
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging
@ -10,18 +11,21 @@ from util.workers import get_worker_count
logconfig = logfile_path(debug=False)
bind = 'unix:/tmp/gunicorn_verbs.sock'
workers = get_worker_count('verbs', 2, minimum=2, maximum=32)
pythonpath = '.'
bind = "unix:/tmp/gunicorn_verbs.sock"
workers = get_worker_count("verbs", 2, minimum=2, maximum=32)
pythonpath = "."
preload_app = True
timeout = 2000 # Because sync workers
def post_fork(server, worker):
# Reset the Random library to ensure it won't raise the "PID check failed." error after
# gunicorn forks.
Random.atfork()
# Reset the Random library to ensure it won't raise the "PID check failed." error after
# gunicorn forks.
Random.atfork()
def when_ready(server):
logger = logging.getLogger(__name__)
logger.debug('Starting verbs gunicorn with %s workers and sync worker class', workers)
logger = logging.getLogger(__name__)
logger.debug(
"Starting verbs gunicorn with %s workers and sync worker class", workers
)

View file

@ -1,5 +1,6 @@
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging
@ -11,18 +12,23 @@ from util.workers import get_worker_count
logconfig = logfile_path(debug=False)
bind = 'unix:/tmp/gunicorn_web.sock'
workers = get_worker_count('web', 2, minimum=2, maximum=32)
worker_class = 'gevent'
pythonpath = '.'
bind = "unix:/tmp/gunicorn_web.sock"
workers = get_worker_count("web", 2, minimum=2, maximum=32)
worker_class = "gevent"
pythonpath = "."
preload_app = True
def post_fork(server, worker):
# Reset the Random library to ensure it won't raise the "PID check failed." error after
# gunicorn forks.
Random.atfork()
# Reset the Random library to ensure it won't raise the "PID check failed." error after
# gunicorn forks.
Random.atfork()
def when_ready(server):
logger = logging.getLogger(__name__)
logger.debug('Starting web gunicorn with %s workers and %s worker class', workers,
worker_class)
logger = logging.getLogger(__name__)
logger.debug(
"Starting web gunicorn with %s workers and %s worker class",
workers,
worker_class,
)

View file

@ -7,120 +7,130 @@ import jinja2
QUAYPATH = os.getenv("QUAYPATH", ".")
QUAYDIR = os.getenv("QUAYDIR", "/")
QUAYCONF_DIR = os.getenv("QUAYCONF", os.path.join(QUAYDIR, QUAYPATH, "conf"))
STATIC_DIR = os.path.join(QUAYDIR, 'static')
STATIC_DIR = os.path.join(QUAYDIR, "static")
SSL_PROTOCOL_DEFAULTS = ['TLSv1', 'TLSv1.1', 'TLSv1.2']
SSL_PROTOCOL_DEFAULTS = ["TLSv1", "TLSv1.1", "TLSv1.2"]
SSL_CIPHER_DEFAULTS = [
'ECDHE-RSA-AES128-GCM-SHA256',
'ECDHE-ECDSA-AES128-GCM-SHA256',
'ECDHE-RSA-AES256-GCM-SHA384',
'ECDHE-ECDSA-AES256-GCM-SHA384',
'DHE-RSA-AES128-GCM-SHA256',
'DHE-DSS-AES128-GCM-SHA256',
'kEDH+AESGCM',
'ECDHE-RSA-AES128-SHA256',
'ECDHE-ECDSA-AES128-SHA256',
'ECDHE-RSA-AES128-SHA',
'ECDHE-ECDSA-AES128-SHA',
'ECDHE-RSA-AES256-SHA384',
'ECDHE-ECDSA-AES256-SHA384',
'ECDHE-RSA-AES256-SHA',
'ECDHE-ECDSA-AES256-SHA',
'DHE-RSA-AES128-SHA256',
'DHE-RSA-AES128-SHA',
'DHE-DSS-AES128-SHA256',
'DHE-RSA-AES256-SHA256',
'DHE-DSS-AES256-SHA',
'DHE-RSA-AES256-SHA',
'AES128-GCM-SHA256',
'AES256-GCM-SHA384',
'AES128-SHA256',
'AES256-SHA256',
'AES128-SHA',
'AES256-SHA',
'AES',
'CAMELLIA',
'!3DES',
'!aNULL',
'!eNULL',
'!EXPORT',
'!DES',
'!RC4',
'!MD5',
'!PSK',
'!aECDH',
'!EDH-DSS-DES-CBC3-SHA',
'!EDH-RSA-DES-CBC3-SHA',
'!KRB5-DES-CBC3-SHA',
"ECDHE-RSA-AES128-GCM-SHA256",
"ECDHE-ECDSA-AES128-GCM-SHA256",
"ECDHE-RSA-AES256-GCM-SHA384",
"ECDHE-ECDSA-AES256-GCM-SHA384",
"DHE-RSA-AES128-GCM-SHA256",
"DHE-DSS-AES128-GCM-SHA256",
"kEDH+AESGCM",
"ECDHE-RSA-AES128-SHA256",
"ECDHE-ECDSA-AES128-SHA256",
"ECDHE-RSA-AES128-SHA",
"ECDHE-ECDSA-AES128-SHA",
"ECDHE-RSA-AES256-SHA384",
"ECDHE-ECDSA-AES256-SHA384",
"ECDHE-RSA-AES256-SHA",
"ECDHE-ECDSA-AES256-SHA",
"DHE-RSA-AES128-SHA256",
"DHE-RSA-AES128-SHA",
"DHE-DSS-AES128-SHA256",
"DHE-RSA-AES256-SHA256",
"DHE-DSS-AES256-SHA",
"DHE-RSA-AES256-SHA",
"AES128-GCM-SHA256",
"AES256-GCM-SHA384",
"AES128-SHA256",
"AES256-SHA256",
"AES128-SHA",
"AES256-SHA",
"AES",
"CAMELLIA",
"!3DES",
"!aNULL",
"!eNULL",
"!EXPORT",
"!DES",
"!RC4",
"!MD5",
"!PSK",
"!aECDH",
"!EDH-DSS-DES-CBC3-SHA",
"!EDH-RSA-DES-CBC3-SHA",
"!KRB5-DES-CBC3-SHA",
]
def write_config(filename, **kwargs):
with open(filename + ".jnj") as f:
template = jinja2.Template(f.read())
rendered = template.render(kwargs)
with open(filename, 'w') as f:
f.write(rendered)
def write_config(filename, **kwargs):
with open(filename + ".jnj") as f:
template = jinja2.Template(f.read())
rendered = template.render(kwargs)
with open(filename, "w") as f:
f.write(rendered)
def generate_nginx_config(config):
"""
"""
Generates nginx config from the app config
"""
config = config or {}
use_https = os.path.exists(os.path.join(QUAYCONF_DIR, 'stack/ssl.key'))
use_old_certs = os.path.exists(os.path.join(QUAYCONF_DIR, 'stack/ssl.old.key'))
v1_only_domain = config.get('V1_ONLY_DOMAIN', None)
enable_rate_limits = config.get('FEATURE_RATE_LIMITS', False)
ssl_protocols = config.get('SSL_PROTOCOLS', SSL_PROTOCOL_DEFAULTS)
ssl_ciphers = config.get('SSL_CIPHERS', SSL_CIPHER_DEFAULTS)
config = config or {}
use_https = os.path.exists(os.path.join(QUAYCONF_DIR, "stack/ssl.key"))
use_old_certs = os.path.exists(os.path.join(QUAYCONF_DIR, "stack/ssl.old.key"))
v1_only_domain = config.get("V1_ONLY_DOMAIN", None)
enable_rate_limits = config.get("FEATURE_RATE_LIMITS", False)
ssl_protocols = config.get("SSL_PROTOCOLS", SSL_PROTOCOL_DEFAULTS)
ssl_ciphers = config.get("SSL_CIPHERS", SSL_CIPHER_DEFAULTS)
write_config(os.path.join(QUAYCONF_DIR, 'nginx/nginx.conf'), use_https=use_https,
use_old_certs=use_old_certs,
enable_rate_limits=enable_rate_limits,
v1_only_domain=v1_only_domain,
ssl_protocols=ssl_protocols,
ssl_ciphers=':'.join(ssl_ciphers))
write_config(
os.path.join(QUAYCONF_DIR, "nginx/nginx.conf"),
use_https=use_https,
use_old_certs=use_old_certs,
enable_rate_limits=enable_rate_limits,
v1_only_domain=v1_only_domain,
ssl_protocols=ssl_protocols,
ssl_ciphers=":".join(ssl_ciphers),
)
def generate_server_config(config):
"""
"""
Generates server config from the app config
"""
config = config or {}
tuf_server = config.get('TUF_SERVER', None)
tuf_host = config.get('TUF_HOST', None)
signing_enabled = config.get('FEATURE_SIGNING', False)
maximum_layer_size = config.get('MAXIMUM_LAYER_SIZE', '20G')
enable_rate_limits = config.get('FEATURE_RATE_LIMITS', False)
config = config or {}
tuf_server = config.get("TUF_SERVER", None)
tuf_host = config.get("TUF_HOST", None)
signing_enabled = config.get("FEATURE_SIGNING", False)
maximum_layer_size = config.get("MAXIMUM_LAYER_SIZE", "20G")
enable_rate_limits = config.get("FEATURE_RATE_LIMITS", False)
write_config(
os.path.join(QUAYCONF_DIR, 'nginx/server-base.conf'), tuf_server=tuf_server, tuf_host=tuf_host,
signing_enabled=signing_enabled, maximum_layer_size=maximum_layer_size,
enable_rate_limits=enable_rate_limits,
static_dir=STATIC_DIR)
write_config(
os.path.join(QUAYCONF_DIR, "nginx/server-base.conf"),
tuf_server=tuf_server,
tuf_host=tuf_host,
signing_enabled=signing_enabled,
maximum_layer_size=maximum_layer_size,
enable_rate_limits=enable_rate_limits,
static_dir=STATIC_DIR,
)
def generate_rate_limiting_config(config):
"""
"""
Generates rate limiting config from the app config
"""
config = config or {}
non_rate_limited_namespaces = config.get('NON_RATE_LIMITED_NAMESPACES') or set()
enable_rate_limits = config.get('FEATURE_RATE_LIMITS', False)
write_config(
os.path.join(QUAYCONF_DIR, 'nginx/rate-limiting.conf'),
non_rate_limited_namespaces=non_rate_limited_namespaces,
enable_rate_limits=enable_rate_limits,
static_dir=STATIC_DIR)
config = config or {}
non_rate_limited_namespaces = config.get("NON_RATE_LIMITED_NAMESPACES") or set()
enable_rate_limits = config.get("FEATURE_RATE_LIMITS", False)
write_config(
os.path.join(QUAYCONF_DIR, "nginx/rate-limiting.conf"),
non_rate_limited_namespaces=non_rate_limited_namespaces,
enable_rate_limits=enable_rate_limits,
static_dir=STATIC_DIR,
)
if __name__ == "__main__":
if os.path.exists(os.path.join(QUAYCONF_DIR, 'stack/config.yaml')):
with open(os.path.join(QUAYCONF_DIR, 'stack/config.yaml'), 'r') as f:
config = yaml.load(f)
else:
config = None
if os.path.exists(os.path.join(QUAYCONF_DIR, "stack/config.yaml")):
with open(os.path.join(QUAYCONF_DIR, "stack/config.yaml"), "r") as f:
config = yaml.load(f)
else:
config = None
generate_rate_limiting_config(config)
generate_server_config(config)
generate_nginx_config(config)
generate_rate_limiting_config(config)
generate_server_config(config)
generate_nginx_config(config)
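write_config() above is a thin Jinja2 wrapper: read "<name>.jnj", render it with keyword arguments, and write the result alongside the template. A self-contained sketch of the same rendering step, using an illustrative template string instead of a file:

import jinja2

template = jinja2.Template("ssl_protocols {{ ssl_protocols | join(' ') }};")
print(template.render(ssl_protocols=["TLSv1.1", "TLSv1.2"]))
# -> ssl_protocols TLSv1.1 TLSv1.2;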

View file

@ -12,136 +12,74 @@ QUAY_OVERRIDE_SERVICES = os.getenv("QUAY_OVERRIDE_SERVICES", [])
def default_services():
return {
"blobuploadcleanupworker": {
"autostart": "true"
},
"buildlogsarchiver": {
"autostart": "true"
},
"builder": {
"autostart": "true"
},
"chunkcleanupworker": {
"autostart": "true"
},
"expiredappspecifictokenworker": {
"autostart": "true"
},
"exportactionlogsworker": {
"autostart": "true"
},
"gcworker": {
"autostart": "true"
},
"globalpromstats": {
"autostart": "true"
},
"labelbackfillworker": {
"autostart": "true"
},
"logrotateworker": {
"autostart": "true"
},
"namespacegcworker": {
"autostart": "true"
},
"notificationworker": {
"autostart": "true"
},
"queuecleanupworker": {
"autostart": "true"
},
"repositoryactioncounter": {
"autostart": "true"
},
"security_notification_worker": {
"autostart": "true"
},
"securityworker": {
"autostart": "true"
},
"storagereplication": {
"autostart": "true"
},
"tagbackfillworker": {
"autostart": "true"
},
"teamsyncworker": {
"autostart": "true"
},
"dnsmasq": {
"autostart": "true"
},
"gunicorn-registry": {
"autostart": "true"
},
"gunicorn-secscan": {
"autostart": "true"
},
"gunicorn-verbs": {
"autostart": "true"
},
"gunicorn-web": {
"autostart": "true"
},
"ip-resolver-update-worker": {
"autostart": "true"
},
"jwtproxy": {
"autostart": "true"
},
"memcache": {
"autostart": "true"
},
"nginx": {
"autostart": "true"
},
"prometheus-aggregator": {
"autostart": "true"
},
"servicekey": {
"autostart": "true"
},
"repomirrorworker": {
"autostart": "false"
return {
"blobuploadcleanupworker": {"autostart": "true"},
"buildlogsarchiver": {"autostart": "true"},
"builder": {"autostart": "true"},
"chunkcleanupworker": {"autostart": "true"},
"expiredappspecifictokenworker": {"autostart": "true"},
"exportactionlogsworker": {"autostart": "true"},
"gcworker": {"autostart": "true"},
"globalpromstats": {"autostart": "true"},
"labelbackfillworker": {"autostart": "true"},
"logrotateworker": {"autostart": "true"},
"namespacegcworker": {"autostart": "true"},
"notificationworker": {"autostart": "true"},
"queuecleanupworker": {"autostart": "true"},
"repositoryactioncounter": {"autostart": "true"},
"security_notification_worker": {"autostart": "true"},
"securityworker": {"autostart": "true"},
"storagereplication": {"autostart": "true"},
"tagbackfillworker": {"autostart": "true"},
"teamsyncworker": {"autostart": "true"},
"dnsmasq": {"autostart": "true"},
"gunicorn-registry": {"autostart": "true"},
"gunicorn-secscan": {"autostart": "true"},
"gunicorn-verbs": {"autostart": "true"},
"gunicorn-web": {"autostart": "true"},
"ip-resolver-update-worker": {"autostart": "true"},
"jwtproxy": {"autostart": "true"},
"memcache": {"autostart": "true"},
"nginx": {"autostart": "true"},
"prometheus-aggregator": {"autostart": "true"},
"servicekey": {"autostart": "true"},
"repomirrorworker": {"autostart": "false"},
}
}
def generate_supervisord_config(filename, config):
with open(filename + ".jnj") as f:
template = jinja2.Template(f.read())
rendered = template.render(config=config)
with open(filename + ".jnj") as f:
template = jinja2.Template(f.read())
rendered = template.render(config=config)
with open(filename, 'w') as f:
f.write(rendered)
with open(filename, "w") as f:
f.write(rendered)
def limit_services(config, enabled_services):
if enabled_services == []:
return
if enabled_services == []:
return
for service in config.keys():
if service in enabled_services:
config[service]["autostart"] = "true"
else:
config[service]["autostart"] = "false"
for service in config.keys():
if service in enabled_services:
config[service]["autostart"] = "true"
else:
config[service]["autostart"] = "false"
def override_services(config, override_services):
if override_services == []:
return
if override_services == []:
return
for service in config.keys():
if service + "=true" in override_services:
config[service]["autostart"] = "true"
elif service + "=false" in override_services:
config[service]["autostart"] = "false"
for service in config.keys():
if service + "=true" in override_services:
config[service]["autostart"] = "true"
elif service + "=false" in override_services:
config[service]["autostart"] = "false"
if __name__ == "__main__":
config = default_services()
limit_services(config, QUAY_SERVICES)
override_services(config, QUAY_OVERRIDE_SERVICES)
generate_supervisord_config(os.path.join(QUAYCONF_DIR, 'supervisord.conf'), config)
config = default_services()
limit_services(config, QUAY_SERVICES)
override_services(config, QUAY_OVERRIDE_SERVICES)
generate_supervisord_config(os.path.join(QUAYCONF_DIR, "supervisord.conf"), config)
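A behavior sketch for limit_services() as defined above: an empty list leaves the defaults alone, otherwise only the named services keep autostart enabled:

def limit_services(config, enabled_services):
    if enabled_services == []:
        return
    for service in config.keys():
        if service in enabled_services:
            config[service]["autostart"] = "true"
        else:
            config[service]["autostart"] = "false"


services = {"nginx": {"autostart": "true"}, "memcache": {"autostart": "true"}}
limit_services(services, ["nginx"])
assert services == {
    "nginx": {"autostart": "true"},
    "memcache": {"autostart": "false"},
}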

View file

@ -6,17 +6,23 @@ import jinja2
from ..supervisord_conf_create import QUAYCONF_DIR, default_services, limit_services
def render_supervisord_conf(config):
with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../supervisord.conf.jnj")) as f:
template = jinja2.Template(f.read())
return template.render(config=config)
with open(
os.path.join(
os.path.dirname(os.path.abspath(__file__)), "../../supervisord.conf.jnj"
)
) as f:
template = jinja2.Template(f.read())
return template.render(config=config)
def test_supervisord_conf_create_defaults():
config = default_services()
limit_services(config, [])
rendered = render_supervisord_conf(config)
config = default_services()
limit_services(config, [])
rendered = render_supervisord_conf(config)
expected = """[supervisord]
expected = """[supervisord]
nodaemon=true
[unix_http_server]
@ -392,14 +398,15 @@ stderr_logfile_maxbytes=0
stdout_events_enabled = true
stderr_events_enabled = true
# EOF NO NEWLINE"""
assert rendered == expected
assert rendered == expected
def test_supervisord_conf_create_all_overrides():
config = default_services()
limit_services(config, "servicekey,prometheus-aggregator")
rendered = render_supervisord_conf(config)
config = default_services()
limit_services(config, "servicekey,prometheus-aggregator")
rendered = render_supervisord_conf(config)
expected = """[supervisord]
expected = """[supervisord]
nodaemon=true
[unix_http_server]
@ -775,4 +782,4 @@ stderr_logfile_maxbytes=0
stdout_events_enabled = true
stderr_events_enabled = true
# EOF NO NEWLINE"""
assert rendered == expected
assert rendered == expected
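Note that test_supervisord_conf_create_all_overrides() passes a comma-separated string, not a list, to limit_services(). This still behaves as a whitelist because Python's `in` operator performs substring matching on strings:

# `in` on a string is a substring test, so a comma-separated string acts
# like a whitelist for any service name that appears inside it.
enabled = "servicekey,prometheus-aggregator"
assert "servicekey" in enabled
assert "nginx" not in enabled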

config.py (943 changed lines)

File diff suppressed because it is too large

View file

@ -7,31 +7,31 @@ import subprocess
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
CONF_DIR = os.getenv("QUAYCONF", os.path.join(ROOT_DIR, "conf/"))
STATIC_DIR = os.path.join(ROOT_DIR, 'static/')
STATIC_LDN_DIR = os.path.join(STATIC_DIR, 'ldn/')
STATIC_FONTS_DIR = os.path.join(STATIC_DIR, 'fonts/')
TEMPLATE_DIR = os.path.join(ROOT_DIR, 'templates/')
IS_KUBERNETES = 'KUBERNETES_SERVICE_HOST' in os.environ
STATIC_DIR = os.path.join(ROOT_DIR, "static/")
STATIC_LDN_DIR = os.path.join(STATIC_DIR, "ldn/")
STATIC_FONTS_DIR = os.path.join(STATIC_DIR, "fonts/")
TEMPLATE_DIR = os.path.join(ROOT_DIR, "templates/")
IS_KUBERNETES = "KUBERNETES_SERVICE_HOST" in os.environ
def _get_version_number_changelog():
try:
with open(os.path.join(ROOT_DIR, 'CHANGELOG.md')) as f:
return re.search(r'(v[0-9]+\.[0-9]+\.[0-9]+)', f.readline()).group(0)
except IOError:
return ''
try:
with open(os.path.join(ROOT_DIR, "CHANGELOG.md")) as f:
return re.search(r"(v[0-9]+\.[0-9]+\.[0-9]+)", f.readline()).group(0)
except IOError:
return ""
def _get_git_sha():
if os.path.exists("GIT_HEAD"):
with open(os.path.join(ROOT_DIR, "GIT_HEAD")) as f:
return f.read()
else:
try:
return subprocess.check_output(["git", "rev-parse", "HEAD"]).strip()[0:8]
except (OSError, subprocess.CalledProcessError):
pass
return "unknown"
if os.path.exists("GIT_HEAD"):
with open(os.path.join(ROOT_DIR, "GIT_HEAD")) as f:
return f.read()
else:
try:
return subprocess.check_output(["git", "rev-parse", "HEAD"]).strip()[0:8]
except (OSError, subprocess.CalledProcessError):
pass
return "unknown"
__version__ = _get_version_number_changelog()

View file

@ -15,28 +15,29 @@ app = Flask(__name__)
logger = logging.getLogger(__name__)
OVERRIDE_CONFIG_DIRECTORY = os.path.join(ROOT_DIR, 'config_app/conf/stack')
INIT_SCRIPTS_LOCATION = '/conf/init/'
OVERRIDE_CONFIG_DIRECTORY = os.path.join(ROOT_DIR, "config_app/conf/stack")
INIT_SCRIPTS_LOCATION = "/conf/init/"
is_testing = 'TEST' in os.environ
is_testing = "TEST" in os.environ
is_kubernetes = IS_KUBERNETES
logger.debug('Configuration is on a kubernetes deployment: %s' % IS_KUBERNETES)
logger.debug("Configuration is on a kubernetes deployment: %s" % IS_KUBERNETES)
config_provider = get_config_provider(OVERRIDE_CONFIG_DIRECTORY, 'config.yaml', 'config.py',
testing=is_testing)
config_provider = get_config_provider(
OVERRIDE_CONFIG_DIRECTORY, "config.yaml", "config.py", testing=is_testing
)
if is_testing:
from test.testconfig import TestConfig
from test.testconfig import TestConfig
logger.debug('Loading test config.')
app.config.from_object(TestConfig())
logger.debug("Loading test config.")
app.config.from_object(TestConfig())
else:
from config import DefaultConfig
from config import DefaultConfig
logger.debug('Loading default config.')
app.config.from_object(DefaultConfig())
app.teardown_request(database.close_db_filter)
logger.debug("Loading default config.")
app.config.from_object(DefaultConfig())
app.teardown_request(database.close_db_filter)
# Load the override config via the provider.
config_provider.update_app_config(app.config)
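app.config.from_object() copies only the uppercase attributes of the given object into the config mapping, which is why both TestConfig and DefaultConfig can be plain classes. A minimal sketch of that behavior:

from flask import Flask


class ConfigSketch(object):
    TESTING = False  # copied: uppercase attribute
    internal_note = "skipped"  # ignored: lowercase attribute


app = Flask(__name__)
app.config.from_object(ConfigSketch())
assert app.config["TESTING"] is False
assert "internal_note" not in app.config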

View file

@ -1,5 +1,6 @@
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging
@ -9,18 +10,24 @@ from config_app.config_util.log import logfile_path
logconfig = logfile_path(debug=True)
bind = '0.0.0.0:5000'
bind = "0.0.0.0:5000"
workers = 1
worker_class = 'gevent'
worker_class = "gevent"
daemon = False
pythonpath = '.'
pythonpath = "."
preload_app = True
def post_fork(server, worker):
# Reset the Random library to ensure it won't raise the "PID check failed." error after
# gunicorn forks.
Random.atfork()
# Reset the Random library to ensure it won't raise the "PID check failed." error after
# gunicorn forks.
Random.atfork()
def when_ready(server):
logger = logging.getLogger(__name__)
logger.debug('Starting local gunicorn with %s workers and %s worker class', workers, worker_class)
logger = logging.getLogger(__name__)
logger.debug(
"Starting local gunicorn with %s workers and %s worker class",
workers,
worker_class,
)

View file

@ -1,5 +1,6 @@
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging
@ -10,17 +11,23 @@ from config_app.config_util.log import logfile_path
logconfig = logfile_path(debug=True)
bind = 'unix:/tmp/gunicorn_web.sock'
bind = "unix:/tmp/gunicorn_web.sock"
workers = 1
worker_class = 'gevent'
pythonpath = '.'
worker_class = "gevent"
pythonpath = "."
preload_app = True
def post_fork(server, worker):
# Reset the Random library to ensure it won't raise the "PID check failed." error after
# gunicorn forks.
Random.atfork()
# Reset the Random library to ensure it won't raise the "PID check failed." error after
# gunicorn forks.
Random.atfork()
def when_ready(server):
logger = logging.getLogger(__name__)
logger.debug('Starting local gunicorn with %s workers and %s worker class', workers, worker_class)
logger = logging.getLogger(__name__)
logger.debug(
"Starting local gunicorn with %s workers and %s worker class",
workers,
worker_class,
)

View file

@ -3,6 +3,6 @@ from config_app.c_app import app as application
# Bind all of the blueprints
import config_web
if __name__ == '__main__':
logging.config.fileConfig(logfile_path(debug=True), disable_existing_loggers=False)
application.run(port=5000, debug=True, threaded=True, host='0.0.0.0')
if __name__ == "__main__":
logging.config.fileConfig(logfile_path(debug=True), disable_existing_loggers=False)
application.run(port=5000, debug=True, threaded=True, host="0.0.0.0")

View file

@ -13,141 +13,153 @@ from config_app.c_app import app, IS_KUBERNETES
from config_app.config_endpoints.exception import InvalidResponse, InvalidRequest
logger = logging.getLogger(__name__)
api_bp = Blueprint('api', __name__)
api_bp = Blueprint("api", __name__)
CROSS_DOMAIN_HEADERS = ['Authorization', 'Content-Type', 'X-Requested-With']
CROSS_DOMAIN_HEADERS = ["Authorization", "Content-Type", "X-Requested-With"]
class ApiExceptionHandlingApi(Api):
pass
pass
@crossdomain(origin='*', headers=CROSS_DOMAIN_HEADERS)
def handle_error(self, error):
return super(ApiExceptionHandlingApi, self).handle_error(error)
@crossdomain(origin="*", headers=CROSS_DOMAIN_HEADERS)
def handle_error(self, error):
return super(ApiExceptionHandlingApi, self).handle_error(error)
api = ApiExceptionHandlingApi()
api.init_app(api_bp)
def log_action(kind, user_or_orgname, metadata=None, repo=None, repo_name=None):
if not metadata:
metadata = {}
if not metadata:
metadata = {}
if repo:
repo_name = repo.name
if repo:
repo_name = repo.name
model.log.log_action(
kind, user_or_orgname, repo_name, user_or_orgname, request.remote_addr, metadata
)
model.log.log_action(kind, user_or_orgname, repo_name, user_or_orgname, request.remote_addr, metadata)
def format_date(date):
""" Output an RFC822 date format. """
if date is None:
return None
return formatdate(timegm(date.utctimetuple()))
""" Output an RFC822 date format. """
if date is None:
return None
return formatdate(timegm(date.utctimetuple()))
def resource(*urls, **kwargs):
def wrapper(api_resource):
if not api_resource:
return None
def wrapper(api_resource):
if not api_resource:
return None
api_resource.registered = True
api.add_resource(api_resource, *urls, **kwargs)
return api_resource
api_resource.registered = True
api.add_resource(api_resource, *urls, **kwargs)
return api_resource
return wrapper
return wrapper
class ApiResource(Resource):
registered = False
method_decorators = []
registered = False
method_decorators = []
def options(self):
return None, 200
def options(self):
return None, 200
def add_method_metadata(name, value):
def modifier(func):
if func is None:
return None
def modifier(func):
if func is None:
return None
if '__api_metadata' not in dir(func):
func.__api_metadata = {}
func.__api_metadata[name] = value
return func
if "__api_metadata" not in dir(func):
func.__api_metadata = {}
func.__api_metadata[name] = value
return func
return modifier
return modifier
def method_metadata(func, name):
if func is None:
return None
if func is None:
return None
if '__api_metadata' in dir(func):
return func.__api_metadata.get(name, None)
return None
if "__api_metadata" in dir(func):
return func.__api_metadata.get(name, None)
return None
def no_cache(f):
@wraps(f)
def add_no_cache(*args, **kwargs):
response = f(*args, **kwargs)
if response is not None:
response.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate'
return response
return add_no_cache
@wraps(f)
def add_no_cache(*args, **kwargs):
response = f(*args, **kwargs)
if response is not None:
response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
return response
return add_no_cache
def define_json_response(schema_name):
def wrapper(func):
@add_method_metadata('response_schema', schema_name)
@wraps(func)
def wrapped(self, *args, **kwargs):
schema = self.schemas[schema_name]
resp = func(self, *args, **kwargs)
def wrapper(func):
@add_method_metadata("response_schema", schema_name)
@wraps(func)
def wrapped(self, *args, **kwargs):
schema = self.schemas[schema_name]
resp = func(self, *args, **kwargs)
if app.config['TESTING']:
try:
validate(resp, schema)
except ValidationError as ex:
raise InvalidResponse(ex.message)
if app.config["TESTING"]:
try:
validate(resp, schema)
except ValidationError as ex:
raise InvalidResponse(ex.message)
return resp
return wrapped
return wrapper
return resp
return wrapped
return wrapper
def validate_json_request(schema_name, optional=False):
def wrapper(func):
@add_method_metadata('request_schema', schema_name)
@wraps(func)
def wrapped(self, *args, **kwargs):
schema = self.schemas[schema_name]
try:
json_data = request.get_json()
if json_data is None:
if not optional:
raise InvalidRequest('Missing JSON body')
else:
validate(json_data, schema)
return func(self, *args, **kwargs)
except ValidationError as ex:
raise InvalidRequest(ex.message)
return wrapped
return wrapper
def wrapper(func):
@add_method_metadata("request_schema", schema_name)
@wraps(func)
def wrapped(self, *args, **kwargs):
schema = self.schemas[schema_name]
try:
json_data = request.get_json()
if json_data is None:
if not optional:
raise InvalidRequest("Missing JSON body")
else:
validate(json_data, schema)
return func(self, *args, **kwargs)
except ValidationError as ex:
raise InvalidRequest(ex.message)
return wrapped
return wrapper
def kubernetes_only(f):
""" Aborts the request with a 400 if the app is not running on kubernetes """
@wraps(f)
def abort_if_not_kube(*args, **kwargs):
if not IS_KUBERNETES:
abort(400)
""" Aborts the request with a 400 if the app is not running on kubernetes """
return f(*args, **kwargs)
return abort_if_not_kube
@wraps(f)
def abort_if_not_kube(*args, **kwargs):
if not IS_KUBERNETES:
abort(400)
nickname = partial(add_method_metadata, 'nickname')
return f(*args, **kwargs)
return abort_if_not_kube
nickname = partial(add_method_metadata, "nickname")
import config_app.config_endpoints.api.discovery
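The decorators above all reduce to one pattern: stash a value on the function's __api_metadata dict at import time and read it back later via method_metadata() (the discovery module does exactly this when building the Swagger spec). A standalone sketch, with a made-up operation id:

from functools import partial


def add_method_metadata(name, value):
    def modifier(func):
        if "__api_metadata" not in dir(func):
            func.__api_metadata = {}
        func.__api_metadata[name] = value
        return func

    return modifier


nickname = partial(add_method_metadata, "nickname")


@nickname("scExampleOperation")
def example_view():
    pass


assert example_view.__api_metadata["nickname"] == "scExampleOperation"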

View file

@ -5,250 +5,253 @@ from collections import OrderedDict
from config_app.c_app import app
from config_app.config_endpoints.api import method_metadata
from config_app.config_endpoints.common import fully_qualified_name, PARAM_REGEX, TYPE_CONVERTER
from config_app.config_endpoints.common import (
fully_qualified_name,
PARAM_REGEX,
TYPE_CONVERTER,
)
logger = logging.getLogger(__name__)
def generate_route_data():
include_internal = True
compact = True
include_internal = True
compact = True
def swagger_parameter(name, description, kind='path', param_type='string', required=True,
enum=None, schema=None):
# https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#parameterObject
parameter_info = {
'name': name,
'in': kind,
'required': required
}
def swagger_parameter(
name,
description,
kind="path",
param_type="string",
required=True,
enum=None,
schema=None,
):
# https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#parameterObject
parameter_info = {"name": name, "in": kind, "required": required}
if schema:
parameter_info['schema'] = {
'$ref': '#/definitions/%s' % schema
}
else:
parameter_info['type'] = param_type
if enum is not None and len(list(enum)) > 0:
parameter_info['enum'] = list(enum)
return parameter_info
paths = {}
models = {}
tags = []
tags_added = set()
operation_ids = set()
for rule in app.url_map.iter_rules():
endpoint_method = app.view_functions[rule.endpoint]
# Verify that we have a view class for this API method.
if not 'view_class' in dir(endpoint_method):
continue
view_class = endpoint_method.view_class
# Hide the class if it is internal.
internal = method_metadata(view_class, 'internal')
if not include_internal and internal:
continue
# Build the tag.
parts = fully_qualified_name(view_class).split('.')
tag_name = parts[-2]
if not tag_name in tags_added:
tags_added.add(tag_name)
tags.append({
'name': tag_name,
'description': (sys.modules[view_class.__module__].__doc__ or '').strip()
})
# Build the Swagger data for the path.
swagger_path = PARAM_REGEX.sub(r'{\2}', rule.rule)
full_name = fully_qualified_name(view_class)
path_swagger = {
'x-name': full_name,
'x-path': swagger_path,
'x-tag': tag_name
}
related_user_res = method_metadata(view_class, 'related_user_resource')
if related_user_res is not None:
path_swagger['x-user-related'] = fully_qualified_name(related_user_res)
paths[swagger_path] = path_swagger
# Add any global path parameters.
param_data_map = view_class.__api_path_params if '__api_path_params' in dir(
view_class) else {}
if param_data_map:
path_parameters_swagger = []
for path_parameter in param_data_map:
description = param_data_map[path_parameter].get('description')
path_parameters_swagger.append(swagger_parameter(path_parameter, description))
path_swagger['parameters'] = path_parameters_swagger
# Add the individual HTTP operations.
method_names = list(rule.methods.difference(['HEAD', 'OPTIONS']))
for method_name in method_names:
# https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#operation-object
method = getattr(view_class, method_name.lower(), None)
if method is None:
logger.debug('Unable to find method for %s in class %s', method_name, view_class)
continue
operationId = method_metadata(method, 'nickname')
operation_swagger = {
'operationId': operationId,
'parameters': [],
}
if operationId is None:
continue
if operationId in operation_ids:
raise Exception('Duplicate operation Id: %s' % operationId)
operation_ids.add(operationId)
# Mark the method as internal.
internal = method_metadata(method, 'internal')
if internal is not None:
operation_swagger['x-internal'] = True
if include_internal:
requires_fresh_login = method_metadata(method, 'requires_fresh_login')
if requires_fresh_login is not None:
operation_swagger['x-requires-fresh-login'] = True
# Add the path parameters.
if rule.arguments:
for path_parameter in rule.arguments:
description = param_data_map.get(path_parameter, {}).get('description')
operation_swagger['parameters'].append(
swagger_parameter(path_parameter, description))
# Add the query parameters.
if '__api_query_params' in dir(method):
for query_parameter_info in method.__api_query_params:
name = query_parameter_info['name']
description = query_parameter_info['help']
param_type = TYPE_CONVERTER[query_parameter_info['type']]
required = query_parameter_info['required']
operation_swagger['parameters'].append(
swagger_parameter(name, description, kind='query',
param_type=param_type,
required=required,
enum=query_parameter_info['choices']))
# Add the OAuth security block.
# https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#securityRequirementObject
scope = method_metadata(method, 'oauth2_scope')
if scope and not compact:
operation_swagger['security'] = [{'oauth2_implicit': [scope.scope]}]
# Add the responses block.
# https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#responsesObject
response_schema_name = method_metadata(method, 'response_schema')
if not compact:
if response_schema_name:
models[response_schema_name] = view_class.schemas[response_schema_name]
models['ApiError'] = {
'type': 'object',
'properties': {
'status': {
'type': 'integer',
'description': 'Status code of the response.'
},
'type': {
'type': 'string',
'description': 'Reference to the type of the error.'
},
'detail': {
'type': 'string',
'description': 'Details about the specific instance of the error.'
},
'title': {
'type': 'string',
'description': 'Unique error code to identify the type of error.'
},
'error_message': {
'type': 'string',
'description': 'Deprecated; alias for detail'
},
'error_type': {
'type': 'string',
'description': 'Deprecated; alias for detail'
}
},
'required': [
'status',
'type',
'title',
]
}
responses = {
'400': {
'description': 'Bad Request',
},
'401': {
'description': 'Session required',
},
'403': {
'description': 'Unauthorized access',
},
'404': {
'description': 'Not found',
},
}
for _, body in responses.items():
body['schema'] = {'$ref': '#/definitions/ApiError'}
if method_name == 'DELETE':
responses['204'] = {
'description': 'Deleted'
}
elif method_name == 'POST':
responses['201'] = {
'description': 'Successful creation'
}
if schema:
parameter_info["schema"] = {"$ref": "#/definitions/%s" % schema}
else:
responses['200'] = {
'description': 'Successful invocation'
}
parameter_info["type"] = param_type
if response_schema_name:
responses['200']['schema'] = {
'$ref': '#/definitions/%s' % response_schema_name
}
if enum is not None and len(list(enum)) > 0:
parameter_info["enum"] = list(enum)
operation_swagger['responses'] = responses
return parameter_info
# Add the request block.
request_schema_name = method_metadata(method, 'request_schema')
if request_schema_name and not compact:
models[request_schema_name] = view_class.schemas[request_schema_name]
paths = {}
models = {}
tags = []
tags_added = set()
operation_ids = set()
operation_swagger['parameters'].append(
swagger_parameter('body', 'Request body contents.', kind='body',
schema=request_schema_name))
for rule in app.url_map.iter_rules():
endpoint_method = app.view_functions[rule.endpoint]
# Add the operation to the parent path.
if not internal or (internal and include_internal):
path_swagger[method_name.lower()] = operation_swagger
# Verify that we have a view class for this API method.
if not "view_class" in dir(endpoint_method):
continue
tags.sort(key=lambda t: t['name'])
paths = OrderedDict(sorted(paths.items(), key=lambda p: p[1]['x-tag']))
view_class = endpoint_method.view_class
if compact:
return {'paths': paths}
# Hide the class if it is internal.
internal = method_metadata(view_class, "internal")
if not include_internal and internal:
continue
# Build the tag.
parts = fully_qualified_name(view_class).split(".")
tag_name = parts[-2]
if not tag_name in tags_added:
tags_added.add(tag_name)
tags.append(
{
"name": tag_name,
"description": (
sys.modules[view_class.__module__].__doc__ or ""
).strip(),
}
)
# Build the Swagger data for the path.
swagger_path = PARAM_REGEX.sub(r"{\2}", rule.rule)
full_name = fully_qualified_name(view_class)
path_swagger = {"x-name": full_name, "x-path": swagger_path, "x-tag": tag_name}
related_user_res = method_metadata(view_class, "related_user_resource")
if related_user_res is not None:
path_swagger["x-user-related"] = fully_qualified_name(related_user_res)
paths[swagger_path] = path_swagger
# Add any global path parameters.
param_data_map = (
view_class.__api_path_params
if "__api_path_params" in dir(view_class)
else {}
)
if param_data_map:
path_parameters_swagger = []
for path_parameter in param_data_map:
description = param_data_map[path_parameter].get("description")
path_parameters_swagger.append(
swagger_parameter(path_parameter, description)
)
path_swagger["parameters"] = path_parameters_swagger
# Add the individual HTTP operations.
method_names = list(rule.methods.difference(["HEAD", "OPTIONS"]))
for method_name in method_names:
# https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#operation-object
method = getattr(view_class, method_name.lower(), None)
if method is None:
logger.debug(
"Unable to find method for %s in class %s", method_name, view_class
)
continue
operationId = method_metadata(method, "nickname")
operation_swagger = {"operationId": operationId, "parameters": []}
if operationId is None:
continue
if operationId in operation_ids:
raise Exception("Duplicate operation Id: %s" % operationId)
operation_ids.add(operationId)
# Mark the method as internal.
internal = method_metadata(method, "internal")
if internal is not None:
operation_swagger["x-internal"] = True
if include_internal:
requires_fresh_login = method_metadata(method, "requires_fresh_login")
if requires_fresh_login is not None:
operation_swagger["x-requires-fresh-login"] = True
# Add the path parameters.
if rule.arguments:
for path_parameter in rule.arguments:
description = param_data_map.get(path_parameter, {}).get(
"description"
)
operation_swagger["parameters"].append(
swagger_parameter(path_parameter, description)
)
# Add the query parameters.
if "__api_query_params" in dir(method):
for query_parameter_info in method.__api_query_params:
name = query_parameter_info["name"]
description = query_parameter_info["help"]
param_type = TYPE_CONVERTER[query_parameter_info["type"]]
required = query_parameter_info["required"]
operation_swagger["parameters"].append(
swagger_parameter(
name,
description,
kind="query",
param_type=param_type,
required=required,
enum=query_parameter_info["choices"],
)
)
# Add the OAuth security block.
# https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#securityRequirementObject
scope = method_metadata(method, "oauth2_scope")
if scope and not compact:
operation_swagger["security"] = [{"oauth2_implicit": [scope.scope]}]
# Add the responses block.
# https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#responsesObject
response_schema_name = method_metadata(method, "response_schema")
if not compact:
if response_schema_name:
models[response_schema_name] = view_class.schemas[
response_schema_name
]
models["ApiError"] = {
"type": "object",
"properties": {
"status": {
"type": "integer",
"description": "Status code of the response.",
},
"type": {
"type": "string",
"description": "Reference to the type of the error.",
},
"detail": {
"type": "string",
"description": "Details about the specific instance of the error.",
},
"title": {
"type": "string",
"description": "Unique error code to identify the type of error.",
},
"error_message": {
"type": "string",
"description": "Deprecated; alias for detail",
},
"error_type": {
"type": "string",
"description": "Deprecated; alias for detail",
},
},
"required": ["status", "type", "title"],
}
responses = {
"400": {"description": "Bad Request"},
"401": {"description": "Session required"},
"403": {"description": "Unauthorized access"},
"404": {"description": "Not found"},
}
for _, body in responses.items():
body["schema"] = {"$ref": "#/definitions/ApiError"}
if method_name == "DELETE":
responses["204"] = {"description": "Deleted"}
elif method_name == "POST":
responses["201"] = {"description": "Successful creation"}
else:
responses["200"] = {"description": "Successful invocation"}
if response_schema_name:
responses["200"]["schema"] = {
"$ref": "#/definitions/%s" % response_schema_name
}
operation_swagger["responses"] = responses
# Add the request block.
request_schema_name = method_metadata(method, "request_schema")
if request_schema_name and not compact:
models[request_schema_name] = view_class.schemas[request_schema_name]
operation_swagger["parameters"].append(
swagger_parameter(
"body",
"Request body contents.",
kind="body",
schema=request_schema_name,
)
)
# Add the operation to the parent path.
if not internal or (internal and include_internal):
path_swagger[method_name.lower()] = operation_swagger
tags.sort(key=lambda t: t["name"])
paths = OrderedDict(sorted(paths.items(), key=lambda p: p[1]["x-tag"]))
if compact:
return {"paths": paths}

View file

@ -6,138 +6,152 @@ from config_app.config_util.config import get_config_as_kube_secret
from data.database import configure
from config_app.c_app import app, config_provider
from config_app.config_endpoints.api import resource, ApiResource, nickname, kubernetes_only, validate_json_request
from config_app.config_util.k8saccessor import KubernetesAccessorSingleton, K8sApiException
from config_app.config_endpoints.api import (
resource,
ApiResource,
nickname,
kubernetes_only,
validate_json_request,
)
from config_app.config_util.k8saccessor import (
KubernetesAccessorSingleton,
K8sApiException,
)
logger = logging.getLogger(__name__)
@resource('/v1/kubernetes/deployments/')
@resource("/v1/kubernetes/deployments/")
class SuperUserKubernetesDeployment(ApiResource):
""" Resource for the getting the status of Red Hat Quay deployments and cycling them """
schemas = {
'ValidateDeploymentNames': {
'type': 'object',
'description': 'Validates deployment names for cycling',
'required': [
'deploymentNames'
],
'properties': {
'deploymentNames': {
'type': 'array',
'description': 'The names of the deployments to cycle'
},
},
""" Resource for the getting the status of Red Hat Quay deployments and cycling them """
schemas = {
"ValidateDeploymentNames": {
"type": "object",
"description": "Validates deployment names for cycling",
"required": ["deploymentNames"],
"properties": {
"deploymentNames": {
"type": "array",
"description": "The names of the deployments to cycle",
}
},
}
}
}
@kubernetes_only
@nickname('scGetNumDeployments')
def get(self):
return KubernetesAccessorSingleton.get_instance().get_qe_deployments()
@kubernetes_only
@nickname("scGetNumDeployments")
def get(self):
return KubernetesAccessorSingleton.get_instance().get_qe_deployments()
@kubernetes_only
@validate_json_request('ValidateDeploymentNames')
@nickname('scCycleQEDeployments')
def put(self):
deployment_names = request.get_json()['deploymentNames']
return KubernetesAccessorSingleton.get_instance().cycle_qe_deployments(deployment_names)
@kubernetes_only
@validate_json_request("ValidateDeploymentNames")
@nickname("scCycleQEDeployments")
def put(self):
deployment_names = request.get_json()["deploymentNames"]
return KubernetesAccessorSingleton.get_instance().cycle_qe_deployments(
deployment_names
)
@resource('/v1/kubernetes/deployment/<deployment>/status')
@resource("/v1/kubernetes/deployment/<deployment>/status")
class QEDeploymentRolloutStatus(ApiResource):
@kubernetes_only
@nickname('scGetDeploymentRolloutStatus')
def get(self, deployment):
deployment_rollout_status = KubernetesAccessorSingleton.get_instance().get_deployment_rollout_status(deployment)
return {
'status': deployment_rollout_status.status,
'message': deployment_rollout_status.message,
}
@kubernetes_only
@nickname("scGetDeploymentRolloutStatus")
def get(self, deployment):
deployment_rollout_status = KubernetesAccessorSingleton.get_instance().get_deployment_rollout_status(
deployment
)
return {
"status": deployment_rollout_status.status,
"message": deployment_rollout_status.message,
}
@resource('/v1/kubernetes/deployments/rollback')
@resource("/v1/kubernetes/deployments/rollback")
class QEDeploymentRollback(ApiResource):
""" Resource for rolling back deployments """
schemas = {
'ValidateDeploymentNames': {
'type': 'object',
'description': 'Validates deployment names for rolling back',
'required': [
'deploymentNames'
],
'properties': {
'deploymentNames': {
'type': 'array',
'description': 'The names of the deployments to rollback'
},
},
}
}
""" Resource for rolling back deployments """
@kubernetes_only
@nickname('scRollbackDeployments')
@validate_json_request('ValidateDeploymentNames')
def post(self):
"""
schemas = {
"ValidateDeploymentNames": {
"type": "object",
"description": "Validates deployment names for rolling back",
"required": ["deploymentNames"],
"properties": {
"deploymentNames": {
"type": "array",
"description": "The names of the deployments to rollback",
}
},
}
}
@kubernetes_only
@nickname("scRollbackDeployments")
@validate_json_request("ValidateDeploymentNames")
def post(self):
"""
Returns the config to its original state and rolls back deployments
:return:
"""
deployment_names = request.get_json()['deploymentNames']
deployment_names = request.get_json()["deploymentNames"]
# To roll back a deployment, we must do 2 things:
# 1. Roll back the config secret to its old value (discarding changes we made in this session)
# 2. Trigger a rollback to the previous revision, so that the pods will be restarted with
# the old config
old_secret = get_config_as_kube_secret(config_provider.get_old_config_dir())
kube_accessor = KubernetesAccessorSingleton.get_instance()
kube_accessor.replace_qe_secret(old_secret)
# To roll back a deployment, we must do 2 things:
# 1. Roll back the config secret to its old value (discarding changes we made in this session)
# 2. Trigger a rollback to the previous revision, so that the pods will be restarted with
# the old config
old_secret = get_config_as_kube_secret(config_provider.get_old_config_dir())
kube_accessor = KubernetesAccessorSingleton.get_instance()
kube_accessor.replace_qe_secret(old_secret)
try:
for name in deployment_names:
kube_accessor.rollback_deployment(name)
except K8sApiException as e:
logger.exception('Failed to rollback deployment.')
return make_response(e.message, 503)
try:
for name in deployment_names:
kube_accessor.rollback_deployment(name)
except K8sApiException as e:
logger.exception("Failed to rollback deployment.")
return make_response(e.message, 503)
return make_response('Ok', 204)
return make_response("Ok", 204)
@resource('/v1/kubernetes/config')
@resource("/v1/kubernetes/config")
class SuperUserKubernetesConfiguration(ApiResource):
""" Resource for saving the config files to kubernetes secrets. """
""" Resource for saving the config files to kubernetes secrets. """
@kubernetes_only
@nickname('scDeployConfiguration')
def post(self):
try:
new_secret = get_config_as_kube_secret(config_provider.get_config_dir_path())
KubernetesAccessorSingleton.get_instance().replace_qe_secret(new_secret)
except K8sApiException as e:
logger.exception('Failed to deploy qe config secret to kubernetes.')
return make_response(e.message, 503)
@kubernetes_only
@nickname("scDeployConfiguration")
def post(self):
try:
new_secret = get_config_as_kube_secret(
config_provider.get_config_dir_path()
)
KubernetesAccessorSingleton.get_instance().replace_qe_secret(new_secret)
except K8sApiException as e:
logger.exception("Failed to deploy qe config secret to kubernetes.")
return make_response(e.message, 503)
return make_response('Ok', 201)
return make_response("Ok", 201)
@resource('/v1/kubernetes/config/populate')
@resource("/v1/kubernetes/config/populate")
class KubernetesConfigurationPopulator(ApiResource):
""" Resource for populating the local configuration from the cluster's kubernetes secrets. """
""" Resource for populating the local configuration from the cluster's kubernetes secrets. """
@kubernetes_only
@nickname('scKubePopulateConfig')
def post(self):
# Get a clean transient directory to write the config into
config_provider.new_config_dir()
@kubernetes_only
@nickname("scKubePopulateConfig")
def post(self):
# Get a clean transient directory to write the config into
config_provider.new_config_dir()
kube_accessor = KubernetesAccessorSingleton.get_instance()
kube_accessor.save_secret_to_directory(config_provider.get_config_dir_path())
config_provider.create_copy_of_config_dir()
kube_accessor = KubernetesAccessorSingleton.get_instance()
kube_accessor.save_secret_to_directory(config_provider.get_config_dir_path())
config_provider.create_copy_of_config_dir()
# We update the db configuration to connect to their specified one
# (Note, even if this DB isn't valid, it won't affect much in the config app, since we'll report an error,
# and all of the options create a new clean dir, so we'll never pollute configs)
combined = dict(**app.config)
combined.update(config_provider.get_config())
configure(combined)
# We update the db configuration to connect to their specified one
# (Note, even if this DB isn't valid, it won't affect much in the config app, since we'll report an error,
# and all of the options create a new clean dir, so we'll never pollute configs)
combined = dict(**app.config)
combined.update(config_provider.get_config())
configure(combined)
return 200
return 200

View file

@ -2,301 +2,283 @@ import logging
from flask import abort, request
from config_app.config_endpoints.api.suconfig_models_pre_oci import pre_oci_model as model
from config_app.config_endpoints.api import resource, ApiResource, nickname, validate_json_request
from config_app.c_app import (app, config_provider, superusers, ip_resolver,
instance_keys, INIT_SCRIPTS_LOCATION)
from config_app.config_endpoints.api.suconfig_models_pre_oci import (
pre_oci_model as model,
)
from config_app.config_endpoints.api import (
resource,
ApiResource,
nickname,
validate_json_request,
)
from config_app.c_app import (
app,
config_provider,
superusers,
ip_resolver,
instance_keys,
INIT_SCRIPTS_LOCATION,
)
from data.database import configure
from data.runmigration import run_alembic_migration
from util.config.configutil import add_enterprise_config_defaults
from util.config.validator import validate_service_for_config, ValidatorContext, \
is_valid_config_upload_filename
from util.config.validator import (
validate_service_for_config,
ValidatorContext,
is_valid_config_upload_filename,
)
logger = logging.getLogger(__name__)
def database_is_valid():
""" Returns whether the database, as configured, is valid. """
return model.is_valid()
""" Returns whether the database, as configured, is valid. """
return model.is_valid()
def database_has_users():
""" Returns whether the database has any users defined. """
return model.has_users()
""" Returns whether the database has any users defined. """
return model.has_users()
@resource('/v1/superuser/config')
@resource("/v1/superuser/config")
class SuperUserConfig(ApiResource):
""" Resource for fetching and updating the current configuration, if any. """
schemas = {
'UpdateConfig': {
'type': 'object',
'description': 'Updates the YAML config file',
'required': [
'config',
],
'properties': {
'config': {
'type': 'object'
},
'password': {
'type': 'string'
},
},
},
}
""" Resource for fetching and updating the current configuration, if any. """
@nickname('scGetConfig')
def get(self):
""" Returns the currently defined configuration, if any. """
config_object = config_provider.get_config()
return {
'config': config_object
schemas = {
"UpdateConfig": {
"type": "object",
"description": "Updates the YAML config file",
"required": ["config"],
"properties": {
"config": {"type": "object"},
"password": {"type": "string"},
},
}
}
@nickname('scUpdateConfig')
@validate_json_request('UpdateConfig')
def put(self):
""" Updates the config override file. """
# Note: This method is called to set the database configuration before super users exist,
# so we also allow it to be called if there is no valid registry configuration setup.
config_object = request.get_json()['config']
@nickname("scGetConfig")
def get(self):
""" Returns the currently defined configuration, if any. """
config_object = config_provider.get_config()
return {"config": config_object}
# Add any enterprise defaults missing from the config.
add_enterprise_config_defaults(config_object, app.config['SECRET_KEY'])
@nickname("scUpdateConfig")
@validate_json_request("UpdateConfig")
def put(self):
""" Updates the config override file. """
# Note: This method is called to set the database configuration before super users exist,
# so we also allow it to be called if there is no valid registry configuration setup.
config_object = request.get_json()["config"]
# Write the configuration changes to the config override file.
config_provider.save_config(config_object)
# Add any enterprise defaults missing from the config.
add_enterprise_config_defaults(config_object, app.config["SECRET_KEY"])
# now try to connect to the db provided in their config to validate it works
combined = dict(**app.config)
combined.update(config_provider.get_config())
configure(combined, testing=app.config['TESTING'])
# Write the configuration changes to the config override file.
config_provider.save_config(config_object)
return {
'exists': True,
'config': config_object
}
# now try to connect to the db provided in their config to validate it works
combined = dict(**app.config)
combined.update(config_provider.get_config())
configure(combined, testing=app.config["TESTING"])
return {"exists": True, "config": config_object}
@resource('/v1/superuser/registrystatus')
@resource("/v1/superuser/registrystatus")
class SuperUserRegistryStatus(ApiResource):
""" Resource for determining the status of the registry, such as if config exists,
""" Resource for determining the status of the registry, such as if config exists,
if a database is configured, and if it has any defined users.
"""
@nickname('scRegistryStatus')
def get(self):
""" Returns the status of the registry. """
# If there is no config file, we need to setup the database.
if not config_provider.config_exists():
return {
'status': 'config-db'
}
@nickname("scRegistryStatus")
def get(self):
""" Returns the status of the registry. """
# If there is no config file, we need to setup the database.
if not config_provider.config_exists():
return {"status": "config-db"}
# If the database isn't yet valid, then we need to set it up.
if not database_is_valid():
return {
'status': 'setup-db'
}
# If the database isn't yet valid, then we need to set it up.
if not database_is_valid():
return {"status": "setup-db"}
config = config_provider.get_config()
if config and config.get('SETUP_COMPLETE'):
return {
'status': 'config'
}
config = config_provider.get_config()
if config and config.get("SETUP_COMPLETE"):
return {"status": "config"}
return {
'status': 'create-superuser' if not database_has_users() else 'config'
}
return {"status": "create-superuser" if not database_has_users() else "config"}
class _AlembicLogHandler(logging.Handler):
def __init__(self):
super(_AlembicLogHandler, self).__init__()
self.records = []
def __init__(self):
super(_AlembicLogHandler, self).__init__()
self.records = []
def emit(self, record):
self.records.append({
'level': record.levelname,
'message': record.getMessage()
})
def emit(self, record):
self.records.append({"level": record.levelname, "message": record.getMessage()})
def _reload_config():
combined = dict(**app.config)
combined.update(config_provider.get_config())
configure(combined)
return combined
combined = dict(**app.config)
combined.update(config_provider.get_config())
configure(combined)
return combined
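The dict merge above gives the provider's values precedence over the Flask defaults; a toy illustration with made-up settings:

app_defaults = {"DB_URI": "sqlite:///test.db", "TESTING": True}
provider_cfg = {"DB_URI": "postgresql://quay@db/quay"}

combined = dict(**app_defaults)
combined.update(provider_cfg)
assert combined["DB_URI"] == "postgresql://quay@db/quay"  # provider value wins
assert combined["TESTING"] is True                        # defaults survive otherwise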
@resource('/v1/superuser/setupdb')
@resource("/v1/superuser/setupdb")
class SuperUserSetupDatabase(ApiResource):
""" Resource for invoking alembic to setup the database. """
""" Resource for invoking alembic to setup the database. """
@nickname('scSetupDatabase')
def get(self):
""" Invokes the alembic upgrade process. """
# Note: This method is called after the database configured is saved, but before the
# database has any tables. Therefore, we only allow it to be run in that unique case.
if config_provider.config_exists() and not database_is_valid():
combined = _reload_config()
@nickname("scSetupDatabase")
def get(self):
""" Invokes the alembic upgrade process. """
# Note: This method is called after the database configured is saved, but before the
# database has any tables. Therefore, we only allow it to be run in that unique case.
if config_provider.config_exists() and not database_is_valid():
combined = _reload_config()
app.config['DB_URI'] = combined['DB_URI']
db_uri = app.config['DB_URI']
escaped_db_uri = db_uri.replace('%', '%%')
app.config["DB_URI"] = combined["DB_URI"]
db_uri = app.config["DB_URI"]
escaped_db_uri = db_uri.replace("%", "%%")
log_handler = _AlembicLogHandler()
log_handler = _AlembicLogHandler()
try:
run_alembic_migration(escaped_db_uri, log_handler, setup_app=False)
except Exception as ex:
return {
'error': str(ex)
}
try:
run_alembic_migration(escaped_db_uri, log_handler, setup_app=False)
except Exception as ex:
return {"error": str(ex)}
return {
'logs': log_handler.records
}
return {"logs": log_handler.records}
abort(403)
abort(403)
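The % doubling above is presumably there because Alembic's ini handling goes through ConfigParser, which treats a lone % as an interpolation marker; for example:

db_uri = "mysql+pymysql://user:p%40ss@db/quay"  # hypothetical URI with a percent-encoded '@'
escaped_db_uri = db_uri.replace("%", "%%")
assert escaped_db_uri == "mysql+pymysql://user:p%%40ss@db/quay"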
@resource('/v1/superuser/config/createsuperuser')
@resource("/v1/superuser/config/createsuperuser")
class SuperUserCreateInitialSuperUser(ApiResource):
""" Resource for creating the initial super user. """
schemas = {
'CreateSuperUser': {
'type': 'object',
'description': 'Information for creating the initial super user',
'required': [
'username',
'password',
'email'
],
'properties': {
'username': {
'type': 'string',
'description': 'The username for the superuser'
},
'password': {
'type': 'string',
'description': 'The password for the superuser'
},
'email': {
'type': 'string',
'description': 'The e-mail address for the superuser'
},
},
},
}
""" Resource for creating the initial super user. """
@nickname('scCreateInitialSuperuser')
@validate_json_request('CreateSuperUser')
def post(self):
""" Creates the initial super user, updates the underlying configuration and
schemas = {
"CreateSuperUser": {
"type": "object",
"description": "Information for creating the initial super user",
"required": ["username", "password", "email"],
"properties": {
"username": {
"type": "string",
"description": "The username for the superuser",
},
"password": {
"type": "string",
"description": "The password for the superuser",
},
"email": {
"type": "string",
"description": "The e-mail address for the superuser",
},
},
}
}
@nickname("scCreateInitialSuperuser")
@validate_json_request("CreateSuperUser")
def post(self):
""" Creates the initial super user, updates the underlying configuration and
sets the current session to have that super user. """
_reload_config()
_reload_config()
# Special security check: This method is only accessible when:
# - There is a valid config YAML file.
# - There are currently no users in the database (clean install)
#
# We do this special security check because at the point this method is called, the database
# is clean but does not (yet) have any super users for our permissions code to check against.
if config_provider.config_exists() and not database_has_users():
data = request.get_json()
username = data['username']
password = data['password']
email = data['email']
# Special security check: This method is only accessible when:
# - There is a valid config YAML file.
# - There are currently no users in the database (clean install)
#
# We do this special security check because at the point this method is called, the database
# is clean but does not (yet) have any super users for our permissions code to check against.
if config_provider.config_exists() and not database_has_users():
data = request.get_json()
username = data["username"]
password = data["password"]
email = data["email"]
# Create the user in the database.
superuser_uuid = model.create_superuser(username, password, email)
# Create the user in the database.
superuser_uuid = model.create_superuser(username, password, email)
# Add the user to the config.
config_object = config_provider.get_config()
config_object['SUPER_USERS'] = [username]
config_provider.save_config(config_object)
# Add the user to the config.
config_object = config_provider.get_config()
config_object["SUPER_USERS"] = [username]
config_provider.save_config(config_object)
# Update the in-memory config for the new superuser.
superusers.register_superuser(username)
# Update the in-memory config for the new superuser.
superusers.register_superuser(username)
return {
'status': True
}
return {"status": True}
abort(403)
abort(403)
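A hypothetical first-boot call against the endpoint above (host and credentials are placeholders):

import requests

resp = requests.post(
    "https://quay-config.example.com/v1/superuser/config/createsuperuser",
    json={"username": "admin", "password": "s3cret!", "email": "admin@example.com"},
)
assert resp.json() == {"status": True}  # 403 instead once any user exists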
@resource('/v1/superuser/config/validate/<service>')
@resource("/v1/superuser/config/validate/<service>")
class SuperUserConfigValidate(ApiResource):
""" Resource for validating a block of configuration against an external service. """
schemas = {
'ValidateConfig': {
'type': 'object',
'description': 'Validates configuration',
'required': [
'config'
],
'properties': {
'config': {
'type': 'object'
},
'password': {
'type': 'string',
'description': 'The user\'s password, used for auth validation'
""" Resource for validating a block of configuration against an external service. """
schemas = {
"ValidateConfig": {
"type": "object",
"description": "Validates configuration",
"required": ["config"],
"properties": {
"config": {"type": "object"},
"password": {
"type": "string",
"description": "The users password, used for auth validation",
},
},
}
},
},
}
}
@nickname('scValidateConfig')
@validate_json_request('ValidateConfig')
def post(self, service):
""" Validates the given config for the given service. """
# Note: This method is called to validate the database configuration before super users exist,
# so we also allow it to be called if there is no valid registry configuration setup. Note that
# this is also safe since this method does not access any information not given in the request.
config = request.get_json()['config']
validator_context = ValidatorContext.from_app(app, config,
request.get_json().get('password', ''),
instance_keys=instance_keys,
ip_resolver=ip_resolver,
config_provider=config_provider,
init_scripts_location=INIT_SCRIPTS_LOCATION)
@nickname("scValidateConfig")
@validate_json_request("ValidateConfig")
def post(self, service):
""" Validates the given config for the given service. """
# Note: This method is called to validate the database configuration before super users exist,
# so we also allow it to be called if there is no valid registry configuration setup. Note that
# this is also safe since this method does not access any information not given in the request.
config = request.get_json()["config"]
validator_context = ValidatorContext.from_app(
app,
config,
request.get_json().get("password", ""),
instance_keys=instance_keys,
ip_resolver=ip_resolver,
config_provider=config_provider,
init_scripts_location=INIT_SCRIPTS_LOCATION,
)
return validate_service_for_config(service, validator_context)
return validate_service_for_config(service, validator_context)
@resource('/v1/superuser/config/file/<filename>')
@resource("/v1/superuser/config/file/<filename>")
class SuperUserConfigFile(ApiResource):
""" Resource for fetching the status of config files and overriding them. """
""" Resource for fetching the status of config files and overriding them. """
@nickname('scConfigFileExists')
def get(self, filename):
""" Returns whether the configuration file with the given name exists. """
if not is_valid_config_upload_filename(filename):
abort(404)
@nickname("scConfigFileExists")
def get(self, filename):
""" Returns whether the configuration file with the given name exists. """
if not is_valid_config_upload_filename(filename):
abort(404)
return {
'exists': config_provider.volume_file_exists(filename)
}
return {"exists": config_provider.volume_file_exists(filename)}
@nickname('scUpdateConfigFile')
def post(self, filename):
""" Updates the configuration file with the given name. """
if not is_valid_config_upload_filename(filename):
abort(404)
@nickname("scUpdateConfigFile")
def post(self, filename):
""" Updates the configuration file with the given name. """
if not is_valid_config_upload_filename(filename):
abort(404)
# Note: This method can be called before the configuration exists
# to upload the database SSL cert.
uploaded_file = request.files['file']
if not uploaded_file:
abort(400)
# Note: This method can be called before the configuration exists
# to upload the database SSL cert.
uploaded_file = request.files["file"]
if not uploaded_file:
abort(400)
config_provider.save_volume_file(filename, uploaded_file)
return {
'status': True
}
config_provider.save_volume_file(filename, uploaded_file)
return {"status": True}

View file

@ -4,36 +4,36 @@ from six import add_metaclass
@add_metaclass(ABCMeta)
class SuperuserConfigDataInterface(object):
"""
"""
Interface that represents all data store interactions required by the superuser config API.
"""
@abstractmethod
def is_valid(self):
"""
@abstractmethod
def is_valid(self):
"""
Returns true if the configured database is valid.
"""
@abstractmethod
def has_users(self):
"""
@abstractmethod
def has_users(self):
"""
Returns true if there are any users defined.
"""
@abstractmethod
def create_superuser(self, username, password, email):
"""
@abstractmethod
def create_superuser(self, username, password, email):
"""
Creates a new superuser with the given username, password and email. Returns the user's UUID.
"""
@abstractmethod
def has_federated_login(self, username, service_name):
"""
@abstractmethod
def has_federated_login(self, username, service_name):
"""
Returns true if the matching user has a federated login under the matching service.
"""
@abstractmethod
def attach_federated_login(self, username, service_name, federated_username):
"""
@abstractmethod
def attach_federated_login(self, username, service_name, federated_username):
"""
Attaches a federated login to the matching user, under the given service.
"""

View file

@ -1,37 +1,39 @@
from data import model
from data.database import User
from config_app.config_endpoints.api.suconfig_models_interface import SuperuserConfigDataInterface
from config_app.config_endpoints.api.suconfig_models_interface import (
SuperuserConfigDataInterface,
)
class PreOCIModel(SuperuserConfigDataInterface):
# Note: this method is different from has_users: the user select will throw if the user
# table does not exist, whereas has_users assumes the table is valid
def is_valid(self):
try:
list(User.select().limit(1))
return True
except:
return False
# Note: this method is different from has_users: the user select will throw if the user
# table does not exist, whereas has_users assumes the table is valid
def is_valid(self):
try:
list(User.select().limit(1))
return True
except:
return False
def has_users(self):
return bool(list(User.select().limit(1)))
def has_users(self):
return bool(list(User.select().limit(1)))
def create_superuser(self, username, password, email):
return model.user.create_user(username, password, email, auto_verify=True).uuid
def create_superuser(self, username, password, email):
return model.user.create_user(username, password, email, auto_verify=True).uuid
def has_federated_login(self, username, service_name):
user = model.user.get_user(username)
if user is None:
return False
def has_federated_login(self, username, service_name):
user = model.user.get_user(username)
if user is None:
return False
return bool(model.user.lookup_federated_login(user, service_name))
return bool(model.user.lookup_federated_login(user, service_name))
def attach_federated_login(self, username, service_name, federated_username):
user = model.user.get_user(username)
if user is None:
return False
def attach_federated_login(self, username, service_name, federated_username):
user = model.user.get_user(username)
if user is None:
return False
model.user.attach_federated_login(user, service_name, federated_username)
model.user.attach_federated_login(user, service_name, federated_username)
pre_oci_model = PreOCIModel()

View file

@ -12,7 +12,13 @@ from data.model import ServiceKeyDoesNotExist
from util.config.validator import EXTRA_CA_DIRECTORY
from config_app.config_endpoints.exception import InvalidRequest
from config_app.config_endpoints.api import resource, ApiResource, nickname, log_action, validate_json_request
from config_app.config_endpoints.api import (
resource,
ApiResource,
nickname,
log_action,
validate_json_request,
)
from config_app.config_endpoints.api.superuser_models_pre_oci import pre_oci_model
from config_app.config_util.ssl import load_certificate, CertInvalidException
from config_app.c_app import app, config_provider, INIT_SCRIPTS_LOCATION
@ -21,228 +27,233 @@ from config_app.c_app import app, config_provider, INIT_SCRIPTS_LOCATION
logger = logging.getLogger(__name__)
@resource('/v1/superuser/customcerts/<certpath>')
@resource("/v1/superuser/customcerts/<certpath>")
class SuperUserCustomCertificate(ApiResource):
""" Resource for managing a custom certificate. """
""" Resource for managing a custom certificate. """
@nickname('uploadCustomCertificate')
def post(self, certpath):
uploaded_file = request.files['file']
if not uploaded_file:
raise InvalidRequest('Missing certificate file')
@nickname("uploadCustomCertificate")
def post(self, certpath):
uploaded_file = request.files["file"]
if not uploaded_file:
raise InvalidRequest("Missing certificate file")
# Save the certificate.
certpath = pathvalidate.sanitize_filename(certpath)
if not certpath.endswith('.crt'):
raise InvalidRequest('Invalid certificate file: must have suffix `.crt`')
# Save the certificate.
certpath = pathvalidate.sanitize_filename(certpath)
if not certpath.endswith(".crt"):
raise InvalidRequest("Invalid certificate file: must have suffix `.crt`")
logger.debug('Saving custom certificate %s', certpath)
cert_full_path = config_provider.get_volume_path(EXTRA_CA_DIRECTORY, certpath)
config_provider.save_volume_file(cert_full_path, uploaded_file)
logger.debug('Saved custom certificate %s', certpath)
logger.debug("Saving custom certificate %s", certpath)
cert_full_path = config_provider.get_volume_path(EXTRA_CA_DIRECTORY, certpath)
config_provider.save_volume_file(cert_full_path, uploaded_file)
logger.debug("Saved custom certificate %s", certpath)
# Validate the certificate.
try:
logger.debug('Loading custom certificate %s', certpath)
with config_provider.get_volume_file(cert_full_path) as f:
load_certificate(f.read())
except CertInvalidException:
logger.exception('Got certificate invalid error for cert %s', certpath)
return '', 204
except IOError:
logger.exception('Got IO error for cert %s', certpath)
return '', 204
# Validate the certificate.
try:
logger.debug("Loading custom certificate %s", certpath)
with config_provider.get_volume_file(cert_full_path) as f:
load_certificate(f.read())
except CertInvalidException:
logger.exception("Got certificate invalid error for cert %s", certpath)
return "", 204
except IOError:
logger.exception("Got IO error for cert %s", certpath)
return "", 204
# Call the update script with config dir location to install the certificate immediately.
if not app.config['TESTING']:
cert_dir = os.path.join(config_provider.get_config_dir_path(), EXTRA_CA_DIRECTORY)
if subprocess.call([os.path.join(INIT_SCRIPTS_LOCATION, 'certs_install.sh')], env={ 'CERTDIR': cert_dir }) != 0:
raise Exception('Could not install certificates')
# Call the update script with config dir location to install the certificate immediately.
if not app.config["TESTING"]:
cert_dir = os.path.join(
config_provider.get_config_dir_path(), EXTRA_CA_DIRECTORY
)
if (
subprocess.call(
[os.path.join(INIT_SCRIPTS_LOCATION, "certs_install.sh")],
env={"CERTDIR": cert_dir},
)
!= 0
):
raise Exception("Could not install certificates")
return '', 204
return "", 204
@nickname('deleteCustomCertificate')
def delete(self, certpath):
cert_full_path = config_provider.get_volume_path(EXTRA_CA_DIRECTORY, certpath)
config_provider.remove_volume_file(cert_full_path)
return '', 204
@nickname("deleteCustomCertificate")
def delete(self, certpath):
cert_full_path = config_provider.get_volume_path(EXTRA_CA_DIRECTORY, certpath)
config_provider.remove_volume_file(cert_full_path)
return "", 204
@resource('/v1/superuser/customcerts')
@resource("/v1/superuser/customcerts")
class SuperUserCustomCertificates(ApiResource):
""" Resource for managing custom certificates. """
""" Resource for managing custom certificates. """
@nickname('getCustomCertificates')
def get(self):
has_extra_certs_path = config_provider.volume_file_exists(EXTRA_CA_DIRECTORY)
extra_certs_found = config_provider.list_volume_directory(EXTRA_CA_DIRECTORY)
if extra_certs_found is None:
return {
'status': 'file' if has_extra_certs_path else 'none',
}
@nickname("getCustomCertificates")
def get(self):
has_extra_certs_path = config_provider.volume_file_exists(EXTRA_CA_DIRECTORY)
extra_certs_found = config_provider.list_volume_directory(EXTRA_CA_DIRECTORY)
if extra_certs_found is None:
return {"status": "file" if has_extra_certs_path else "none"}
cert_views = []
for extra_cert_path in extra_certs_found:
try:
cert_full_path = config_provider.get_volume_path(EXTRA_CA_DIRECTORY, extra_cert_path)
with config_provider.get_volume_file(cert_full_path) as f:
certificate = load_certificate(f.read())
cert_views.append({
'path': extra_cert_path,
'names': list(certificate.names),
'expired': certificate.expired,
})
except CertInvalidException as cie:
cert_views.append({
'path': extra_cert_path,
'error': cie.message,
})
except IOError as ioe:
cert_views.append({
'path': extra_cert_path,
'error': ioe.message,
})
cert_views = []
for extra_cert_path in extra_certs_found:
try:
cert_full_path = config_provider.get_volume_path(
EXTRA_CA_DIRECTORY, extra_cert_path
)
with config_provider.get_volume_file(cert_full_path) as f:
certificate = load_certificate(f.read())
cert_views.append(
{
"path": extra_cert_path,
"names": list(certificate.names),
"expired": certificate.expired,
}
)
except CertInvalidException as cie:
cert_views.append({"path": extra_cert_path, "error": cie.message})
except IOError as ioe:
cert_views.append({"path": extra_cert_path, "error": ioe.message})
return {
'status': 'directory',
'certs': cert_views,
}
return {"status": "directory", "certs": cert_views}
@resource('/v1/superuser/keys')
@resource("/v1/superuser/keys")
class SuperUserServiceKeyManagement(ApiResource):
""" Resource for managing service keys."""
schemas = {
'CreateServiceKey': {
'id': 'CreateServiceKey',
'type': 'object',
'description': 'Description of creation of a service key',
'required': ['service', 'expiration'],
'properties': {
'service': {
'type': 'string',
'description': 'The service authenticating with this key',
},
'name': {
'type': 'string',
'description': 'The friendly name of a service key',
},
'metadata': {
'type': 'object',
'description': 'The key/value pairs of this key\'s metadata',
},
'notes': {
'type': 'string',
'description': 'If specified, the extra notes for the key',
},
'expiration': {
'description': 'The expiration date as a unix timestamp',
'anyOf': [{'type': 'number'}, {'type': 'null'}],
},
},
},
}
""" Resource for managing service keys."""
@nickname('listServiceKeys')
def get(self):
keys = pre_oci_model.list_all_service_keys()
return jsonify({
'keys': [key.to_dict() for key in keys],
})
@nickname('createServiceKey')
@validate_json_request('CreateServiceKey')
def post(self):
body = request.get_json()
# Ensure we have a valid expiration date if specified.
expiration_date = body.get('expiration', None)
if expiration_date is not None:
try:
expiration_date = datetime.utcfromtimestamp(float(expiration_date))
except ValueError as ve:
raise InvalidRequest('Invalid expiration date: %s' % ve)
if expiration_date <= datetime.now():
raise InvalidRequest('Expiration date cannot be in the past')
# Create the metadata for the key.
metadata = body.get('metadata', {})
metadata.update({
'created_by': 'Quay Superuser Panel',
'ip': request.remote_addr,
})
# Generate a key with a private key that we *never save*.
(private_key, key_id) = pre_oci_model.generate_service_key(body['service'], expiration_date,
metadata=metadata,
name=body.get('name', ''))
# Auto-approve the service key.
pre_oci_model.approve_service_key(key_id, ServiceKeyApprovalType.SUPERUSER,
notes=body.get('notes', ''))
# Log the creation and auto-approval of the service key.
key_log_metadata = {
'kid': key_id,
'preshared': True,
'service': body['service'],
'name': body.get('name', ''),
'expiration_date': expiration_date,
'auto_approved': True,
schemas = {
"CreateServiceKey": {
"id": "CreateServiceKey",
"type": "object",
"description": "Description of creation of a service key",
"required": ["service", "expiration"],
"properties": {
"service": {
"type": "string",
"description": "The service authenticating with this key",
},
"name": {
"type": "string",
"description": "The friendly name of a service key",
},
"metadata": {
"type": "object",
"description": "The key/value pairs of this key's metadata",
},
"notes": {
"type": "string",
"description": "If specified, the extra notes for the key",
},
"expiration": {
"description": "The expiration date as a unix timestamp",
"anyOf": [{"type": "number"}, {"type": "null"}],
},
},
}
}
log_action('service_key_create', None, key_log_metadata)
log_action('service_key_approve', None, key_log_metadata)
@nickname("listServiceKeys")
def get(self):
keys = pre_oci_model.list_all_service_keys()
return jsonify({
'kid': key_id,
'name': body.get('name', ''),
'service': body['service'],
'public_key': private_key.publickey().exportKey('PEM'),
'private_key': private_key.exportKey('PEM'),
})
return jsonify({"keys": [key.to_dict() for key in keys]})
@resource('/v1/superuser/approvedkeys/<kid>')
@nickname("createServiceKey")
@validate_json_request("CreateServiceKey")
def post(self):
body = request.get_json()
# Ensure we have a valid expiration date if specified.
expiration_date = body.get("expiration", None)
if expiration_date is not None:
try:
expiration_date = datetime.utcfromtimestamp(float(expiration_date))
except ValueError as ve:
raise InvalidRequest("Invalid expiration date: %s" % ve)
if expiration_date <= datetime.now():
raise InvalidRequest("Expiration date cannot be in the past")
# Create the metadata for the key.
metadata = body.get("metadata", {})
metadata.update(
{"created_by": "Quay Superuser Panel", "ip": request.remote_addr}
)
# Generate a key with a private key that we *never save*.
(private_key, key_id) = pre_oci_model.generate_service_key(
body["service"],
expiration_date,
metadata=metadata,
name=body.get("name", ""),
)
# Auto-approve the service key.
pre_oci_model.approve_service_key(
key_id, ServiceKeyApprovalType.SUPERUSER, notes=body.get("notes", "")
)
# Log the creation and auto-approval of the service key.
key_log_metadata = {
"kid": key_id,
"preshared": True,
"service": body["service"],
"name": body.get("name", ""),
"expiration_date": expiration_date,
"auto_approved": True,
}
log_action("service_key_create", None, key_log_metadata)
log_action("service_key_approve", None, key_log_metadata)
return jsonify(
{
"kid": key_id,
"name": body.get("name", ""),
"service": body["service"],
"public_key": private_key.publickey().exportKey("PEM"),
"private_key": private_key.exportKey("PEM"),
}
)
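The expiration handling converts the supplied Unix timestamp before the past-date check; a quick worked example:

from datetime import datetime

expiration_date = datetime.utcfromtimestamp(float("1893456000"))
assert expiration_date == datetime(2030, 1, 1)   # 1893456000 == 2030-01-01T00:00:00Z
assert expiration_date > datetime.now()          # so the InvalidRequest above is not raised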
@resource("/v1/superuser/approvedkeys/<kid>")
class SuperUserServiceKeyApproval(ApiResource):
""" Resource for approving service keys. """
""" Resource for approving service keys. """
schemas = {
'ApproveServiceKey': {
'id': 'ApproveServiceKey',
'type': 'object',
'description': 'Information for approving service keys',
'properties': {
'notes': {
'type': 'string',
'description': 'Optional approval notes',
},
},
},
}
schemas = {
"ApproveServiceKey": {
"id": "ApproveServiceKey",
"type": "object",
"description": "Information for approving service keys",
"properties": {
"notes": {"type": "string", "description": "Optional approval notes"}
},
}
}
@nickname('approveServiceKey')
@validate_json_request('ApproveServiceKey')
def post(self, kid):
notes = request.get_json().get('notes', '')
try:
key = pre_oci_model.approve_service_key(kid, ServiceKeyApprovalType.SUPERUSER, notes=notes)
@nickname("approveServiceKey")
@validate_json_request("ApproveServiceKey")
def post(self, kid):
notes = request.get_json().get("notes", "")
try:
key = pre_oci_model.approve_service_key(
kid, ServiceKeyApprovalType.SUPERUSER, notes=notes
)
# Log the approval of the service key.
key_log_metadata = {
'kid': kid,
'service': key.service,
'name': key.name,
'expiration_date': key.expiration_date,
}
# Log the approval of the service key.
key_log_metadata = {
"kid": kid,
"service": key.service,
"name": key.name,
"expiration_date": key.expiration_date,
}
# Note: this may not actually be the current person modifying the config, but if they're in the config tool,
# they have full access to the DB and could pretend to be any user, so pulling any superuser is likely fine
super_user = app.config.get('SUPER_USERS', [None])[0]
log_action('service_key_approve', super_user, key_log_metadata)
except ServiceKeyDoesNotExist:
raise NotFound()
except ServiceKeyAlreadyApproved:
pass
# Note: this may not actually be the current person modifying the config, but if they're in the config tool,
# they have full access to the DB and could pretend to be any user, so pulling any superuser is likely fine
super_user = app.config.get("SUPER_USERS", [None])[0]
log_action("service_key_approve", super_user, key_log_metadata)
except ServiceKeyDoesNotExist:
raise NotFound()
except ServiceKeyAlreadyApproved:
pass
return make_response('', 201)
return make_response("", 201)

View file

@ -6,21 +6,33 @@ from config_app.config_endpoints.api import format_date
def user_view(user):
return {
'name': user.username,
'kind': 'user',
'is_robot': user.robot,
}
return {"name": user.username, "kind": "user", "is_robot": user.robot}
class RepositoryBuild(namedtuple('RepositoryBuild',
['uuid', 'logs_archived', 'repository_namespace_user_username',
'repository_name',
'can_write', 'can_read', 'pull_robot', 'resource_key', 'trigger',
'display_name',
'started', 'job_config', 'phase', 'status', 'error',
'archive_url'])):
"""
class RepositoryBuild(
namedtuple(
"RepositoryBuild",
[
"uuid",
"logs_archived",
"repository_namespace_user_username",
"repository_name",
"can_write",
"can_read",
"pull_robot",
"resource_key",
"trigger",
"display_name",
"started",
"job_config",
"phase",
"status",
"error",
"archive_url",
],
)
):
"""
RepositoryBuild represents a build associated with a repository
:type uuid: string
:type logs_archived: boolean
@ -40,42 +52,46 @@ class RepositoryBuild(namedtuple('RepositoryBuild',
:type archive_url: string
"""
def to_dict(self):
def to_dict(self):
resp = {
'id': self.uuid,
'phase': self.phase,
'started': format_date(self.started),
'display_name': self.display_name,
'status': self.status or {},
'subdirectory': self.job_config.get('build_subdir', ''),
'dockerfile_path': self.job_config.get('build_subdir', ''),
'context': self.job_config.get('context', ''),
'tags': self.job_config.get('docker_tags', []),
'manual_user': self.job_config.get('manual_user', None),
'is_writer': self.can_write,
'trigger': self.trigger.to_dict(),
'trigger_metadata': self.job_config.get('trigger_metadata', None) if self.can_read else None,
'resource_key': self.resource_key,
'pull_robot': user_view(self.pull_robot) if self.pull_robot else None,
'repository': {
'namespace': self.repository_namespace_user_username,
'name': self.repository_name
},
'error': self.error,
}
resp = {
"id": self.uuid,
"phase": self.phase,
"started": format_date(self.started),
"display_name": self.display_name,
"status": self.status or {},
"subdirectory": self.job_config.get("build_subdir", ""),
"dockerfile_path": self.job_config.get("build_subdir", ""),
"context": self.job_config.get("context", ""),
"tags": self.job_config.get("docker_tags", []),
"manual_user": self.job_config.get("manual_user", None),
"is_writer": self.can_write,
"trigger": self.trigger.to_dict(),
"trigger_metadata": self.job_config.get("trigger_metadata", None)
if self.can_read
else None,
"resource_key": self.resource_key,
"pull_robot": user_view(self.pull_robot) if self.pull_robot else None,
"repository": {
"namespace": self.repository_namespace_user_username,
"name": self.repository_name,
},
"error": self.error,
}
if self.can_write:
if self.resource_key is not None:
resp['archive_url'] = self.archive_url
elif self.job_config.get('archive_url', None):
resp['archive_url'] = self.job_config['archive_url']
if self.can_write:
if self.resource_key is not None:
resp["archive_url"] = self.archive_url
elif self.job_config.get("archive_url", None):
resp["archive_url"] = self.job_config["archive_url"]
return resp
return resp
class Approval(namedtuple('Approval', ['approver', 'approval_type', 'approved_date', 'notes'])):
"""
class Approval(
namedtuple("Approval", ["approver", "approval_type", "approved_date", "notes"])
):
"""
Approval represents whether a key has been approved or not
:type approver: User
:type approval_type: string
@ -83,19 +99,32 @@ class Approval(namedtuple('Approval', ['approver', 'approval_type', 'approved_da
:type notes: string
"""
def to_dict(self):
return {
'approver': self.approver.to_dict() if self.approver else None,
'approval_type': self.approval_type,
'approved_date': self.approved_date,
'notes': self.notes,
}
def to_dict(self):
return {
"approver": self.approver.to_dict() if self.approver else None,
"approval_type": self.approval_type,
"approved_date": self.approved_date,
"notes": self.notes,
}
class ServiceKey(
namedtuple('ServiceKey', ['name', 'kid', 'service', 'jwk', 'metadata', 'created_date',
'expiration_date', 'rotation_duration', 'approval'])):
"""
namedtuple(
"ServiceKey",
[
"name",
"kid",
"service",
"jwk",
"metadata",
"created_date",
"expiration_date",
"rotation_duration",
"approval",
],
)
):
"""
ServiceKey is an apostille signing key
:type name: string
:type kid: int
@ -109,22 +138,22 @@ class ServiceKey(
"""
def to_dict(self):
return {
'name': self.name,
'kid': self.kid,
'service': self.service,
'jwk': self.jwk,
'metadata': self.metadata,
'created_date': self.created_date,
'expiration_date': self.expiration_date,
'rotation_duration': self.rotation_duration,
'approval': self.approval.to_dict() if self.approval is not None else None,
}
def to_dict(self):
return {
"name": self.name,
"kid": self.kid,
"service": self.service,
"jwk": self.jwk,
"metadata": self.metadata,
"created_date": self.created_date,
"expiration_date": self.expiration_date,
"rotation_duration": self.rotation_duration,
"approval": self.approval.to_dict() if self.approval is not None else None,
}
class User(namedtuple('User', ['username', 'email', 'verified', 'enabled', 'robot'])):
"""
class User(namedtuple("User", ["username", "email", "verified", "enabled", "robot"])):
"""
User represents a single user.
:type username: string
:type email: string
@ -133,41 +162,38 @@ class User(namedtuple('User', ['username', 'email', 'verified', 'enabled', 'robo
:type robot: User
"""
def to_dict(self):
user_data = {
'kind': 'user',
'name': self.username,
'username': self.username,
'email': self.email,
'verified': self.verified,
'enabled': self.enabled,
}
def to_dict(self):
user_data = {
"kind": "user",
"name": self.username,
"username": self.username,
"email": self.email,
"verified": self.verified,
"enabled": self.enabled,
}
return user_data
return user_data
class Organization(namedtuple('Organization', ['username', 'email'])):
"""
class Organization(namedtuple("Organization", ["username", "email"])):
"""
Organization represents a single org.
:type username: string
:type email: string
"""
def to_dict(self):
return {
'name': self.username,
'email': self.email,
}
def to_dict(self):
return {"name": self.username, "email": self.email}
@add_metaclass(ABCMeta)
class SuperuserDataInterface(object):
"""
"""
Interface that represents all data store interactions required by a superuser api.
"""
@abstractmethod
def list_all_service_keys(self):
"""
@abstractmethod
def list_all_service_keys(self):
"""
Returns a list of service keys
"""

View file

@ -1,60 +1,85 @@
from data import model
from config_app.config_endpoints.api.superuser_models_interface import (SuperuserDataInterface, User, ServiceKey,
Approval)
from config_app.config_endpoints.api.superuser_models_interface import (
SuperuserDataInterface,
User,
ServiceKey,
Approval,
)
def _create_user(user):
if user is None:
return None
return User(user.username, user.email, user.verified, user.enabled, user.robot)
if user is None:
return None
return User(user.username, user.email, user.verified, user.enabled, user.robot)
def _create_key(key):
approval = None
if key.approval is not None:
approval = Approval(_create_user(key.approval.approver), key.approval.approval_type,
key.approval.approved_date,
key.approval.notes)
approval = None
if key.approval is not None:
approval = Approval(
_create_user(key.approval.approver),
key.approval.approval_type,
key.approval.approved_date,
key.approval.notes,
)
return ServiceKey(key.name, key.kid, key.service, key.jwk, key.metadata, key.created_date,
key.expiration_date,
key.rotation_duration, approval)
return ServiceKey(
key.name,
key.kid,
key.service,
key.jwk,
key.metadata,
key.created_date,
key.expiration_date,
key.rotation_duration,
approval,
)
class ServiceKeyDoesNotExist(Exception):
pass
pass
class ServiceKeyAlreadyApproved(Exception):
pass
pass
class PreOCIModel(SuperuserDataInterface):
"""
"""
PreOCIModel implements the data model for the SuperUser using a database schema
before it was changed to support the OCI specification.
"""
def list_all_service_keys(self):
keys = model.service_keys.list_all_keys()
return [_create_key(key) for key in keys]
def list_all_service_keys(self):
keys = model.service_keys.list_all_keys()
return [_create_key(key) for key in keys]
def approve_service_key(self, kid, approval_type, notes=''):
try:
key = model.service_keys.approve_service_key(kid, approval_type, notes=notes)
return _create_key(key)
except model.ServiceKeyDoesNotExist:
raise ServiceKeyDoesNotExist
except model.ServiceKeyAlreadyApproved:
raise ServiceKeyAlreadyApproved
def approve_service_key(self, kid, approval_type, notes=""):
try:
key = model.service_keys.approve_service_key(
kid, approval_type, notes=notes
)
return _create_key(key)
except model.ServiceKeyDoesNotExist:
raise ServiceKeyDoesNotExist
except model.ServiceKeyAlreadyApproved:
raise ServiceKeyAlreadyApproved
def generate_service_key(self, service, expiration_date, kid=None, name='', metadata=None,
rotation_duration=None):
(private_key, key) = model.service_keys.generate_service_key(service, expiration_date,
metadata=metadata, name=name)
def generate_service_key(
self,
service,
expiration_date,
kid=None,
name="",
metadata=None,
rotation_duration=None,
):
(private_key, key) = model.service_keys.generate_service_key(
service, expiration_date, metadata=metadata, name=name
)
return private_key, key.kid
return private_key, key.kid
pre_oci_model = PreOCIModel()

View file

@ -10,53 +10,59 @@ from data.database import configure
from config_app.c_app import app, config_provider
from config_app.config_endpoints.api import resource, ApiResource, nickname
from config_app.config_util.tar import tarinfo_filter_partial, strip_absolute_path_and_add_trailing_dir
from config_app.config_util.tar import (
tarinfo_filter_partial,
strip_absolute_path_and_add_trailing_dir,
)
@resource('/v1/configapp/initialization')
@resource("/v1/configapp/initialization")
class ConfigInitialization(ApiResource):
"""
"""
Resource for dealing with any initialization logic for the config app
"""
@nickname('scStartNewConfig')
def post(self):
config_provider.new_config_dir()
return make_response('OK')
@nickname("scStartNewConfig")
def post(self):
config_provider.new_config_dir()
return make_response("OK")
@resource('/v1/configapp/tarconfig')
@resource("/v1/configapp/tarconfig")
class TarConfigLoader(ApiResource):
"""
"""
Resource for dealing with configuration as a tarball,
including loading and generating functions
"""
@nickname('scGetConfigTarball')
def get(self):
config_path = config_provider.get_config_dir_path()
tar_dir_prefix = strip_absolute_path_and_add_trailing_dir(config_path)
temp = tempfile.NamedTemporaryFile()
@nickname("scGetConfigTarball")
def get(self):
config_path = config_provider.get_config_dir_path()
tar_dir_prefix = strip_absolute_path_and_add_trailing_dir(config_path)
temp = tempfile.NamedTemporaryFile()
with closing(tarfile.open(temp.name, mode="w|gz")) as tar:
for name in os.listdir(config_path):
tar.add(os.path.join(config_path, name), filter=tarinfo_filter_partial(tar_dir_prefix))
return send_file(temp.name, mimetype='application/gzip')
with closing(tarfile.open(temp.name, mode="w|gz")) as tar:
for name in os.listdir(config_path):
tar.add(
os.path.join(config_path, name),
filter=tarinfo_filter_partial(tar_dir_prefix),
)
return send_file(temp.name, mimetype="application/gzip")
@nickname('scUploadTarballConfig')
def put(self):
""" Loads tarball config into the config provider """
# Generate a new empty dir to load the config into
config_provider.new_config_dir()
input_stream = request.stream
with tarfile.open(mode="r|gz", fileobj=input_stream) as tar_stream:
tar_stream.extractall(config_provider.get_config_dir_path())
@nickname("scUploadTarballConfig")
def put(self):
""" Loads tarball config into the config provider """
# Generate a new empty dir to load the config into
config_provider.new_config_dir()
input_stream = request.stream
with tarfile.open(mode="r|gz", fileobj=input_stream) as tar_stream:
tar_stream.extractall(config_provider.get_config_dir_path())
config_provider.create_copy_of_config_dir()
config_provider.create_copy_of_config_dir()
# now try to connect to the db provided in their config to validate it works
combined = dict(**app.config)
combined.update(config_provider.get_config())
configure(combined)
# now try to connect to the db provided in their config to validate it works
combined = dict(**app.config)
combined.update(config_provider.get_config())
configure(combined)
return make_response('OK')
return make_response("OK")

View file

@ -3,12 +3,12 @@ from config_app.config_endpoints.api import resource, ApiResource, nickname
from config_app.config_endpoints.api.superuser_models_interface import user_view
@resource('/v1/user/')
@resource("/v1/user/")
class User(ApiResource):
""" Operations related to users. """
""" Operations related to users. """
@nickname('getLoggedInUser')
def get(self):
""" Get user information for the authenticated user. """
user = get_authenticated_user()
return user_view(user)
@nickname("getLoggedInUser")
def get(self):
""" Get user information for the authenticated user. """
user = get_authenticated_user()
return user_view(user)

View file

@ -14,60 +14,72 @@ from config_app.config_util.k8sconfig import get_k8s_namespace
def truthy_bool(param):
return param not in {False, 'false', 'False', '0', 'FALSE', '', 'null'}
return param not in {False, "false", "False", "0", "FALSE", "", "null"}
DEFAULT_JS_BUNDLE_NAME = 'configapp'
PARAM_REGEX = re.compile(r'<([^:>]+:)*([\w]+)>')
DEFAULT_JS_BUNDLE_NAME = "configapp"
PARAM_REGEX = re.compile(r"<([^:>]+:)*([\w]+)>")
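PARAM_REGEX captures Flask route parameters so they can be rewritten into Swagger's {name} form; a plausible use, mirroring how the discovery code rewrites paths (the sub pattern here is illustrative):

import re

PARAM_REGEX = re.compile(r"<([^:>]+:)*([\w]+)>")
path = "/v1/kubernetes/deployment/<deployment>/status"
assert PARAM_REGEX.sub(r"{\2}", path) == "/v1/kubernetes/deployment/{deployment}/status"
# Converter prefixes are dropped too: <string:certpath> -> {certpath}
assert (
    PARAM_REGEX.sub(r"{\2}", "/v1/superuser/customcerts/<string:certpath>")
    == "/v1/superuser/customcerts/{certpath}"
)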
logger = logging.getLogger(__name__)
TYPE_CONVERTER = {
truthy_bool: 'boolean',
str: 'string',
basestring: 'string',
reqparse.text_type: 'string',
int: 'integer',
truthy_bool: "boolean",
str: "string",
basestring: "string",
reqparse.text_type: "string",
int: "integer",
}
def _list_files(path, extension, contains=""):
""" Returns a list of all the files with the given extension found under the given path. """
""" Returns a list of all the files with the given extension found under the given path. """
def matches(f):
return os.path.splitext(f)[1] == '.' + extension and contains in os.path.splitext(f)[0]
def matches(f):
return (
os.path.splitext(f)[1] == "." + extension
and contains in os.path.splitext(f)[0]
)
def join_path(dp, f):
# Remove the static/ prefix. It is added in the template.
return os.path.join(dp, f)[len(ROOT_DIR) + 1 + len('config_app/static/'):]
def join_path(dp, f):
# Remove the static/ prefix. It is added in the template.
return os.path.join(dp, f)[len(ROOT_DIR) + 1 + len("config_app/static/") :]
filepath = os.path.join(os.path.join(ROOT_DIR, 'config_app/static/'), path)
return [join_path(dp, f) for dp, _, files in os.walk(filepath) for f in files if matches(f)]
filepath = os.path.join(os.path.join(ROOT_DIR, "config_app/static/"), path)
return [
join_path(dp, f)
for dp, _, files in os.walk(filepath)
for f in files
if matches(f)
]
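Assuming a webpack build has emitted a configapp bundle, a call might resolve like this (filenames hypothetical):

main_scripts = _list_files("build", "js", "configapp")
# -> e.g. ["build/configapp.bundle.js"]; paths come back relative to
#    config_app/static/ because join_path strips that prefix.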
FONT_AWESOME_4 = 'netdna.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.css'
FONT_AWESOME_4 = "netdna.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.css"
def render_page_template(name, route_data=None, js_bundle_name=DEFAULT_JS_BUNDLE_NAME, **kwargs):
""" Renders the page template with the given name as the response and returns its contents. """
main_scripts = _list_files('build', 'js', js_bundle_name)
def render_page_template(
name, route_data=None, js_bundle_name=DEFAULT_JS_BUNDLE_NAME, **kwargs
):
""" Renders the page template with the given name as the response and returns its contents. """
main_scripts = _list_files("build", "js", js_bundle_name)
use_cdn = os.getenv('TESTING') == 'true'
use_cdn = os.getenv("TESTING") == "true"
external_styles = get_external_css(local=not use_cdn, exclude=FONT_AWESOME_4)
external_scripts = get_external_javascript(local=not use_cdn)
external_styles = get_external_css(local=not use_cdn, exclude=FONT_AWESOME_4)
external_scripts = get_external_javascript(local=not use_cdn)
contents = render_template(name,
route_data=route_data,
main_scripts=main_scripts,
external_styles=external_styles,
external_scripts=external_scripts,
config_set=frontend_visible_config(app.config),
kubernetes_namespace=IS_KUBERNETES and get_k8s_namespace(),
**kwargs)
contents = render_template(
name,
route_data=route_data,
main_scripts=main_scripts,
external_styles=external_styles,
external_scripts=external_scripts,
config_set=frontend_visible_config(app.config),
kubernetes_namespace=IS_KUBERNETES and get_k8s_namespace(),
**kwargs
)
resp = make_response(contents)
resp.headers['X-FRAME-OPTIONS'] = 'DENY'
return resp
resp = make_response(contents)
resp.headers["X-FRAME-OPTIONS"] = "DENY"
return resp
def fully_qualified_name(method_view_class):
return '%s.%s' % (method_view_class.__module__, method_view_class.__name__)
return "%s.%s" % (method_view_class.__module__, method_view_class.__name__)

View file

@ -5,11 +5,11 @@ from werkzeug.exceptions import HTTPException
class ApiErrorType(Enum):
invalid_request = 'invalid_request'
invalid_request = "invalid_request"
class ApiException(HTTPException):
"""
"""
Represents an error in the application/problem+json format.
See: https://tools.ietf.org/html/rfc7807
@@ -31,36 +31,42 @@ class ApiException(HTTPException):
information if dereferenced.
"""
def __init__(self, error_type, status_code, error_description, payload=None):
Exception.__init__(self)
self.error_description = error_description
self.code = status_code
self.payload = payload
self.error_type = error_type
self.data = self.to_dict()
def __init__(self, error_type, status_code, error_description, payload=None):
Exception.__init__(self)
self.error_description = error_description
self.code = status_code
self.payload = payload
self.error_type = error_type
self.data = self.to_dict()
super(ApiException, self).__init__(error_description, None)
super(ApiException, self).__init__(error_description, None)
def to_dict(self):
rv = dict(self.payload or ())
def to_dict(self):
rv = dict(self.payload or ())
if self.error_description is not None:
rv['detail'] = self.error_description
rv['error_message'] = self.error_description # TODO: deprecate
if self.error_description is not None:
rv["detail"] = self.error_description
rv["error_message"] = self.error_description # TODO: deprecate
rv['error_type'] = self.error_type.value # TODO: deprecate
rv['title'] = self.error_type.value
rv['type'] = url_for('api.error', error_type=self.error_type.value, _external=True)
rv['status'] = self.code
rv["error_type"] = self.error_type.value # TODO: deprecate
rv["title"] = self.error_type.value
rv["type"] = url_for(
"api.error", error_type=self.error_type.value, _external=True
)
rv["status"] = self.code
return rv
return rv
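For illustration, the RFC 7807 problem document produced for a hypothetical InvalidRequest; the type URL is deployment-specific, shown here with a placeholder host:

# InvalidRequest("Missing required field: name").to_dict() would yield roughly:
# {
#     "detail": "Missing required field: name",
#     "error_message": "Missing required field: name",  # deprecated alias
#     "error_type": "invalid_request",                  # deprecated alias
#     "title": "invalid_request",
#     "type": "https://quay.example/api/v1/error/invalid_request",
#     "status": 400,
# }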
class InvalidRequest(ApiException):
def __init__(self, error_description, payload=None):
ApiException.__init__(self, ApiErrorType.invalid_request, 400, error_description, payload)
def __init__(self, error_description, payload=None):
ApiException.__init__(
self, ApiErrorType.invalid_request, 400, error_description, payload
)
class InvalidResponse(ApiException):
def __init__(self, error_description, payload=None):
ApiException.__init__(self, ApiErrorType.invalid_response, 400, error_description, payload)
def __init__(self, error_description, payload=None):
ApiException.__init__(
self, ApiErrorType.invalid_response, 400, error_description, payload
)


@@ -5,19 +5,21 @@ from config_app.config_endpoints.common import render_page_template
from config_app.config_endpoints.api.discovery import generate_route_data
from config_app.config_endpoints.api import no_cache
setup_web = Blueprint('setup_web', __name__, template_folder='templates')
setup_web = Blueprint("setup_web", __name__, template_folder="templates")
@lru_cache(maxsize=1)
def _get_route_data():
return generate_route_data()
return generate_route_data()
def render_page_template_with_routedata(name, *args, **kwargs):
return render_page_template(name, _get_route_data(), *args, **kwargs)
return render_page_template(name, _get_route_data(), *args, **kwargs)
@no_cache
@setup_web.route('/', methods=['GET'], defaults={'path': ''})
@setup_web.route("/", methods=["GET"], defaults={"path": ""})
def index(path, **kwargs):
return render_page_template_with_routedata('index.html', js_bundle_name='configapp', **kwargs)
return render_page_template_with_routedata(
"index.html", js_bundle_name="configapp", **kwargs
)
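Since generate_route_data() takes no arguments, @lru_cache(maxsize=1) amounts to compute-once memoization: the route data is built on the first request and every later call returns the same cached object.

# Sketch: repeated calls share one cached result.
assert _get_route_data() is _get_route_data()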


@@ -5,204 +5,242 @@ from data import database, model
from util.security.test.test_ssl_util import generate_test_cert
from config_app.c_app import app
from config_app.config_test import ApiTestCase, all_queues, ADMIN_ACCESS_USER, ADMIN_ACCESS_EMAIL
from config_app.config_test import (
ApiTestCase,
all_queues,
ADMIN_ACCESS_USER,
ADMIN_ACCESS_EMAIL,
)
from config_app.config_endpoints.api import api_bp
from config_app.config_endpoints.api.superuser import SuperUserCustomCertificate, SuperUserCustomCertificates
from config_app.config_endpoints.api.suconfig import SuperUserConfig, SuperUserCreateInitialSuperUser, \
SuperUserConfigFile, SuperUserRegistryStatus
from config_app.config_endpoints.api.superuser import (
SuperUserCustomCertificate,
SuperUserCustomCertificates,
)
from config_app.config_endpoints.api.suconfig import (
SuperUserConfig,
SuperUserCreateInitialSuperUser,
SuperUserConfigFile,
SuperUserRegistryStatus,
)
try:
app.register_blueprint(api_bp, url_prefix='/api')
app.register_blueprint(api_bp, url_prefix="/api")
except ValueError:
# This blueprint was already registered
pass
# This blueprint was already registered
pass
class TestSuperUserCreateInitialSuperUser(ApiTestCase):
def test_create_superuser(self):
data = {
'username': 'newsuper',
'password': 'password',
'email': 'jschorr+fake@devtable.com',
}
def test_create_superuser(self):
data = {
"username": "newsuper",
"password": "password",
"email": "jschorr+fake@devtable.com",
}
# Add some fake config.
fake_config = {
'AUTHENTICATION_TYPE': 'Database',
'SECRET_KEY': 'fakekey',
}
# Add some fake config.
fake_config = {"AUTHENTICATION_TYPE": "Database", "SECRET_KEY": "fakekey"}
self.putJsonResponse(SuperUserConfig, data=dict(config=fake_config, hostname='fakehost'))
self.putJsonResponse(
SuperUserConfig, data=dict(config=fake_config, hostname="fakehost")
)
# Try to write with config. Should 403 since there are users in the DB.
self.postResponse(SuperUserCreateInitialSuperUser, data=data, expected_code=403)
# Try to write with config. Should 403 since there are users in the DB.
self.postResponse(SuperUserCreateInitialSuperUser, data=data, expected_code=403)
# Delete all users in the DB.
for user in list(database.User.select()):
model.user.delete_user(user, all_queues)
# Delete all users in the DB.
for user in list(database.User.select()):
model.user.delete_user(user, all_queues)
# Create the superuser.
self.postJsonResponse(SuperUserCreateInitialSuperUser, data=data)
# Create the superuser.
self.postJsonResponse(SuperUserCreateInitialSuperUser, data=data)
# Ensure the user exists in the DB.
self.assertIsNotNone(model.user.get_user('newsuper'))
# Ensure the user exists in the DB.
self.assertIsNotNone(model.user.get_user("newsuper"))
# Ensure that the current user is a superuser in the config.
json = self.getJsonResponse(SuperUserConfig)
self.assertEquals(['newsuper'], json['config']['SUPER_USERS'])
# Ensure that the current user is a superuser in the config.
json = self.getJsonResponse(SuperUserConfig)
self.assertEquals(["newsuper"], json["config"]["SUPER_USERS"])
# Ensure that the current user is a superuser in memory by trying to call an API
# that will fail otherwise.
self.getResponse(SuperUserConfigFile, params=dict(filename='ssl.cert'))
# Ensure that the current user is a superuser in memory by trying to call an API
# that will fail otherwise.
self.getResponse(SuperUserConfigFile, params=dict(filename="ssl.cert"))
class TestSuperUserConfig(ApiTestCase):
def test_get_status_update_config(self):
# With no config the status should be 'config-db'.
json = self.getJsonResponse(SuperUserRegistryStatus)
self.assertEquals('config-db', json['status'])
def test_get_status_update_config(self):
# With no config the status should be 'config-db'.
json = self.getJsonResponse(SuperUserRegistryStatus)
self.assertEquals("config-db", json["status"])
# Add some fake config.
fake_config = {
'AUTHENTICATION_TYPE': 'Database',
'SECRET_KEY': 'fakekey',
}
# Add some fake config.
fake_config = {"AUTHENTICATION_TYPE": "Database", "SECRET_KEY": "fakekey"}
json = self.putJsonResponse(SuperUserConfig, data=dict(config=fake_config,
hostname='fakehost'))
self.assertEquals('fakekey', json['config']['SECRET_KEY'])
self.assertEquals('fakehost', json['config']['SERVER_HOSTNAME'])
self.assertEquals('Database', json['config']['AUTHENTICATION_TYPE'])
json = self.putJsonResponse(
SuperUserConfig, data=dict(config=fake_config, hostname="fakehost")
)
self.assertEquals("fakekey", json["config"]["SECRET_KEY"])
self.assertEquals("fakehost", json["config"]["SERVER_HOSTNAME"])
self.assertEquals("Database", json["config"]["AUTHENTICATION_TYPE"])
# With config the status should be 'setup-db'.
# TODO: fix this test
# json = self.getJsonResponse(SuperUserRegistryStatus)
# self.assertEquals('setup-db', json['status'])
# With config the status should be 'setup-db'.
# TODO: fix this test
# json = self.getJsonResponse(SuperUserRegistryStatus)
# self.assertEquals('setup-db', json['status'])
def test_config_file(self):
# Try for an invalid file. Should 404.
self.getResponse(SuperUserConfigFile, params=dict(filename='foobar'), expected_code=404)
def test_config_file(self):
# Try for an invalid file. Should 404.
self.getResponse(
SuperUserConfigFile, params=dict(filename="foobar"), expected_code=404
)
# Try for a valid filename. Should not exist.
json = self.getJsonResponse(SuperUserConfigFile, params=dict(filename='ssl.cert'))
self.assertFalse(json['exists'])
# Try for a valid filename. Should not exist.
json = self.getJsonResponse(
SuperUserConfigFile, params=dict(filename="ssl.cert")
)
self.assertFalse(json["exists"])
# Add the file.
self.postResponse(SuperUserConfigFile, params=dict(filename='ssl.cert'),
file=(StringIO('my file contents'), 'ssl.cert'))
# Add the file.
self.postResponse(
SuperUserConfigFile,
params=dict(filename="ssl.cert"),
file=(StringIO("my file contents"), "ssl.cert"),
)
# Should now exist.
json = self.getJsonResponse(SuperUserConfigFile, params=dict(filename='ssl.cert'))
self.assertTrue(json['exists'])
# Should now exist.
json = self.getJsonResponse(
SuperUserConfigFile, params=dict(filename="ssl.cert")
)
self.assertTrue(json["exists"])
def test_update_with_external_auth(self):
# Run a mock LDAP.
mockldap = MockLdap({
'dc=quay,dc=io': {
'dc': ['quay', 'io']
},
'ou=employees,dc=quay,dc=io': {
'dc': ['quay', 'io'],
'ou': 'employees'
},
'uid=' + ADMIN_ACCESS_USER + ',ou=employees,dc=quay,dc=io': {
'dc': ['quay', 'io'],
'ou': 'employees',
'uid': [ADMIN_ACCESS_USER],
'userPassword': ['password'],
'mail': [ADMIN_ACCESS_EMAIL],
},
})
def test_update_with_external_auth(self):
# Run a mock LDAP.
mockldap = MockLdap(
{
"dc=quay,dc=io": {"dc": ["quay", "io"]},
"ou=employees,dc=quay,dc=io": {"dc": ["quay", "io"], "ou": "employees"},
"uid="
+ ADMIN_ACCESS_USER
+ ",ou=employees,dc=quay,dc=io": {
"dc": ["quay", "io"],
"ou": "employees",
"uid": [ADMIN_ACCESS_USER],
"userPassword": ["password"],
"mail": [ADMIN_ACCESS_EMAIL],
},
}
)
config = {
'AUTHENTICATION_TYPE': 'LDAP',
'LDAP_BASE_DN': ['dc=quay', 'dc=io'],
'LDAP_ADMIN_DN': 'uid=devtable,ou=employees,dc=quay,dc=io',
'LDAP_ADMIN_PASSWD': 'password',
'LDAP_USER_RDN': ['ou=employees'],
'LDAP_UID_ATTR': 'uid',
'LDAP_EMAIL_ATTR': 'mail',
}
config = {
"AUTHENTICATION_TYPE": "LDAP",
"LDAP_BASE_DN": ["dc=quay", "dc=io"],
"LDAP_ADMIN_DN": "uid=devtable,ou=employees,dc=quay,dc=io",
"LDAP_ADMIN_PASSWD": "password",
"LDAP_USER_RDN": ["ou=employees"],
"LDAP_UID_ATTR": "uid",
"LDAP_EMAIL_ATTR": "mail",
}
mockldap.start()
try:
# Write the config with the valid password.
self.putResponse(SuperUserConfig,
data={'config': config,
'password': 'password',
'hostname': 'foo'}, expected_code=200)
mockldap.start()
try:
# Write the config with the valid password.
self.putResponse(
SuperUserConfig,
data={"config": config, "password": "password", "hostname": "foo"},
expected_code=200,
)
# Ensure that the user row has been linked.
# TODO: fix this test
# self.assertEquals(ADMIN_ACCESS_USER,
# model.user.verify_federated_login('ldap', ADMIN_ACCESS_USER).username)
finally:
mockldap.stop()
# Ensure that the user row has been linked.
# TODO: fix this test
# self.assertEquals(ADMIN_ACCESS_USER,
# model.user.verify_federated_login('ldap', ADMIN_ACCESS_USER).username)
finally:
mockldap.stop()
class TestSuperUserCustomCertificates(ApiTestCase):
def test_custom_certificates(self):
def test_custom_certificates(self):
# Upload a certificate.
cert_contents, _ = generate_test_cert(hostname='somecoolhost', san_list=['DNS:bar', 'DNS:baz'])
self.postResponse(SuperUserCustomCertificate, params=dict(certpath='testcert.crt'),
file=(StringIO(cert_contents), 'testcert.crt'), expected_code=204)
# Upload a certificate.
cert_contents, _ = generate_test_cert(
hostname="somecoolhost", san_list=["DNS:bar", "DNS:baz"]
)
self.postResponse(
SuperUserCustomCertificate,
params=dict(certpath="testcert.crt"),
file=(StringIO(cert_contents), "testcert.crt"),
expected_code=204,
)
# Make sure it is present.
json = self.getJsonResponse(SuperUserCustomCertificates)
self.assertEquals(1, len(json['certs']))
# Make sure it is present.
json = self.getJsonResponse(SuperUserCustomCertificates)
self.assertEquals(1, len(json["certs"]))
cert_info = json['certs'][0]
self.assertEquals('testcert.crt', cert_info['path'])
cert_info = json["certs"][0]
self.assertEquals("testcert.crt", cert_info["path"])
self.assertEquals(set(['somecoolhost', 'bar', 'baz']), set(cert_info['names']))
self.assertFalse(cert_info['expired'])
self.assertEquals(set(["somecoolhost", "bar", "baz"]), set(cert_info["names"]))
self.assertFalse(cert_info["expired"])
# Remove the certificate.
self.deleteResponse(SuperUserCustomCertificate, params=dict(certpath='testcert.crt'))
# Remove the certificate.
self.deleteResponse(
SuperUserCustomCertificate, params=dict(certpath="testcert.crt")
)
# Make sure it is gone.
json = self.getJsonResponse(SuperUserCustomCertificates)
self.assertEquals(0, len(json['certs']))
# Make sure it is gone.
json = self.getJsonResponse(SuperUserCustomCertificates)
self.assertEquals(0, len(json["certs"]))
def test_expired_custom_certificate(self):
# Upload a certificate.
cert_contents, _ = generate_test_cert(hostname='somecoolhost', expires=-10)
self.postResponse(SuperUserCustomCertificate, params=dict(certpath='testcert.crt'),
file=(StringIO(cert_contents), 'testcert.crt'), expected_code=204)
def test_expired_custom_certificate(self):
# Upload a certificate.
cert_contents, _ = generate_test_cert(hostname="somecoolhost", expires=-10)
self.postResponse(
SuperUserCustomCertificate,
params=dict(certpath="testcert.crt"),
file=(StringIO(cert_contents), "testcert.crt"),
expected_code=204,
)
# Make sure it is present.
json = self.getJsonResponse(SuperUserCustomCertificates)
self.assertEquals(1, len(json['certs']))
# Make sure it is present.
json = self.getJsonResponse(SuperUserCustomCertificates)
self.assertEquals(1, len(json["certs"]))
cert_info = json['certs'][0]
self.assertEquals('testcert.crt', cert_info['path'])
cert_info = json["certs"][0]
self.assertEquals("testcert.crt", cert_info["path"])
self.assertEquals(set(['somecoolhost']), set(cert_info['names']))
self.assertTrue(cert_info['expired'])
self.assertEquals(set(["somecoolhost"]), set(cert_info["names"]))
self.assertTrue(cert_info["expired"])
def test_invalid_custom_certificate(self):
# Upload an invalid certificate.
self.postResponse(SuperUserCustomCertificate, params=dict(certpath='testcert.crt'),
file=(StringIO('some contents'), 'testcert.crt'), expected_code=204)
def test_invalid_custom_certificate(self):
# Upload an invalid certificate.
self.postResponse(
SuperUserCustomCertificate,
params=dict(certpath="testcert.crt"),
file=(StringIO("some contents"), "testcert.crt"),
expected_code=204,
)
# Make sure it is present but invalid.
json = self.getJsonResponse(SuperUserCustomCertificates)
self.assertEquals(1, len(json['certs']))
# Make sure it is present but invalid.
json = self.getJsonResponse(SuperUserCustomCertificates)
self.assertEquals(1, len(json["certs"]))
cert_info = json['certs'][0]
self.assertEquals('testcert.crt', cert_info['path'])
self.assertEquals('no start line', cert_info['error'])
cert_info = json["certs"][0]
self.assertEquals("testcert.crt", cert_info["path"])
self.assertEquals("no start line", cert_info["error"])
def test_path_sanitization(self):
# Upload a certificate.
cert_contents, _ = generate_test_cert(hostname='somecoolhost', expires=-10)
self.postResponse(SuperUserCustomCertificate, params=dict(certpath='testcert/../foobar.crt'),
file=(StringIO(cert_contents), 'testcert/../foobar.crt'), expected_code=204)
def test_path_sanitization(self):
# Upload a certificate.
cert_contents, _ = generate_test_cert(hostname="somecoolhost", expires=-10)
self.postResponse(
SuperUserCustomCertificate,
params=dict(certpath="testcert/../foobar.crt"),
file=(StringIO(cert_contents), "testcert/../foobar.crt"),
expected_code=204,
)
# Make sure it is present.
json = self.getJsonResponse(SuperUserCustomCertificates)
self.assertEquals(1, len(json['certs']))
cert_info = json['certs'][0]
self.assertEquals('foobar.crt', cert_info['path'])
# Make sure it is present.
json = self.getJsonResponse(SuperUserCustomCertificates)
self.assertEquals(1, len(json["certs"]))
cert_info = json["certs"][0]
self.assertEquals("foobar.crt", cert_info["path"])


@@ -4,176 +4,235 @@ import mock
from data.database import User
from data import model
from config_app.config_endpoints.api.suconfig import SuperUserConfig, SuperUserConfigValidate, SuperUserConfigFile, \
SuperUserRegistryStatus, SuperUserCreateInitialSuperUser
from config_app.config_endpoints.api.suconfig import (
SuperUserConfig,
SuperUserConfigValidate,
SuperUserConfigFile,
SuperUserRegistryStatus,
SuperUserCreateInitialSuperUser,
)
from config_app.config_endpoints.api import api_bp
from config_app.config_test import ApiTestCase, READ_ACCESS_USER, ADMIN_ACCESS_USER
from config_app.c_app import app, config_provider
try:
app.register_blueprint(api_bp, url_prefix='/api')
app.register_blueprint(api_bp, url_prefix="/api")
except ValueError:
# This blueprint was already registered
pass
# This blueprint was already registered
pass
# OVERRIDES FROM PORTING FROM OLD APP:
all_queues = [] # the config app doesn't have any queues
all_queues = [] # the config app doesn't have any queues
class FreshConfigProvider(object):
def __enter__(self):
config_provider.reset_for_test()
return config_provider
def __enter__(self):
config_provider.reset_for_test()
return config_provider
def __exit__(self, type, value, traceback):
config_provider.reset_for_test()
def __exit__(self, type, value, traceback):
config_provider.reset_for_test()
class TestSuperUserRegistryStatus(ApiTestCase):
def test_registry_status_no_config(self):
with FreshConfigProvider():
json = self.getJsonResponse(SuperUserRegistryStatus)
self.assertEquals('config-db', json['status'])
def test_registry_status_no_config(self):
with FreshConfigProvider():
json = self.getJsonResponse(SuperUserRegistryStatus)
self.assertEquals("config-db", json["status"])
@mock.patch("config_app.config_endpoints.api.suconfig.database_is_valid", mock.Mock(return_value=False))
def test_registry_status_no_database(self):
with FreshConfigProvider():
config_provider.save_config({'key': 'value'})
json = self.getJsonResponse(SuperUserRegistryStatus)
self.assertEquals('setup-db', json['status'])
@mock.patch(
"config_app.config_endpoints.api.suconfig.database_is_valid",
mock.Mock(return_value=False),
)
def test_registry_status_no_database(self):
with FreshConfigProvider():
config_provider.save_config({"key": "value"})
json = self.getJsonResponse(SuperUserRegistryStatus)
self.assertEquals("setup-db", json["status"])
@mock.patch("config_app.config_endpoints.api.suconfig.database_is_valid", mock.Mock(return_value=True))
def test_registry_status_db_has_superuser(self):
with FreshConfigProvider():
config_provider.save_config({'key': 'value'})
json = self.getJsonResponse(SuperUserRegistryStatus)
self.assertEquals('config', json['status'])
@mock.patch(
"config_app.config_endpoints.api.suconfig.database_is_valid",
mock.Mock(return_value=True),
)
def test_registry_status_db_has_superuser(self):
with FreshConfigProvider():
config_provider.save_config({"key": "value"})
json = self.getJsonResponse(SuperUserRegistryStatus)
self.assertEquals("config", json["status"])
@mock.patch("config_app.config_endpoints.api.suconfig.database_is_valid", mock.Mock(return_value=True))
@mock.patch("config_app.config_endpoints.api.suconfig.database_has_users", mock.Mock(return_value=False))
def test_registry_status_db_no_superuser(self):
with FreshConfigProvider():
config_provider.save_config({'key': 'value'})
json = self.getJsonResponse(SuperUserRegistryStatus)
self.assertEquals('create-superuser', json['status'])
@mock.patch(
"config_app.config_endpoints.api.suconfig.database_is_valid",
mock.Mock(return_value=True),
)
@mock.patch(
"config_app.config_endpoints.api.suconfig.database_has_users",
mock.Mock(return_value=False),
)
def test_registry_status_db_no_superuser(self):
with FreshConfigProvider():
config_provider.save_config({"key": "value"})
json = self.getJsonResponse(SuperUserRegistryStatus)
self.assertEquals("create-superuser", json["status"])
@mock.patch(
"config_app.config_endpoints.api.suconfig.database_is_valid",
mock.Mock(return_value=True),
)
@mock.patch(
"config_app.config_endpoints.api.suconfig.database_has_users",
mock.Mock(return_value=True),
)
def test_registry_status_setup_complete(self):
with FreshConfigProvider():
config_provider.save_config({"key": "value", "SETUP_COMPLETE": True})
json = self.getJsonResponse(SuperUserRegistryStatus)
self.assertEquals("config", json["status"])
@mock.patch("config_app.config_endpoints.api.suconfig.database_is_valid", mock.Mock(return_value=True))
@mock.patch("config_app.config_endpoints.api.suconfig.database_has_users", mock.Mock(return_value=True))
def test_registry_status_setup_complete(self):
with FreshConfigProvider():
config_provider.save_config({'key': 'value', 'SETUP_COMPLETE': True})
json = self.getJsonResponse(SuperUserRegistryStatus)
self.assertEquals('config', json['status'])
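Taken together, these cases pin down the registry-status decision tree. A hedged reconstruction for the reader (the actual logic lives in suconfig, outside this diff):

def registry_status_sketch(has_config, db_valid, db_has_users):
    if not has_config:
        return "config-db"         # no config.yaml yet
    if not db_valid:
        return "setup-db"          # config exists, database not yet usable
    if not db_has_users:
        return "create-superuser"  # database ready, no users to promote
    return "config"                # setup done (or SETUP_COMPLETE is set)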
class TestSuperUserConfigFile(ApiTestCase):
def test_get_superuser_invalid_filename(self):
with FreshConfigProvider():
self.getResponse(SuperUserConfigFile, params=dict(filename='somefile'), expected_code=404)
def test_get_superuser_invalid_filename(self):
with FreshConfigProvider():
self.getResponse(
SuperUserConfigFile, params=dict(filename="somefile"), expected_code=404
)
def test_get_superuser(self):
with FreshConfigProvider():
result = self.getJsonResponse(SuperUserConfigFile, params=dict(filename='ssl.cert'))
self.assertFalse(result['exists'])
def test_get_superuser(self):
with FreshConfigProvider():
result = self.getJsonResponse(
SuperUserConfigFile, params=dict(filename="ssl.cert")
)
self.assertFalse(result["exists"])
def test_post_no_file(self):
with FreshConfigProvider():
# No file
self.postResponse(SuperUserConfigFile, params=dict(filename='ssl.cert'), expected_code=400)
def test_post_no_file(self):
with FreshConfigProvider():
# No file
self.postResponse(
SuperUserConfigFile, params=dict(filename="ssl.cert"), expected_code=400
)
def test_post_superuser_invalid_filename(self):
with FreshConfigProvider():
self.postResponse(SuperUserConfigFile, params=dict(filename='somefile'), expected_code=404)
def test_post_superuser_invalid_filename(self):
with FreshConfigProvider():
self.postResponse(
SuperUserConfigFile, params=dict(filename="somefile"), expected_code=404
)
def test_post_superuser(self):
with FreshConfigProvider():
self.postResponse(SuperUserConfigFile, params=dict(filename='ssl.cert'), expected_code=400)
def test_post_superuser(self):
with FreshConfigProvider():
self.postResponse(
SuperUserConfigFile, params=dict(filename="ssl.cert"), expected_code=400
)
class TestSuperUserCreateInitialSuperUser(ApiTestCase):
def test_no_config_file(self):
with FreshConfigProvider():
      # If there is no config.yaml, then this method should fail for security reasons.
data = dict(username='cooluser', password='password', email='fake@example.com')
self.postResponse(SuperUserCreateInitialSuperUser, data=data, expected_code=403)
def test_no_config_file(self):
with FreshConfigProvider():
            # If there is no config.yaml, then this method should fail for security reasons.
data = dict(
username="cooluser", password="password", email="fake@example.com"
)
self.postResponse(
SuperUserCreateInitialSuperUser, data=data, expected_code=403
)
def test_config_file_with_db_users(self):
with FreshConfigProvider():
# Write some config.
self.putJsonResponse(SuperUserConfig, data=dict(config={}, hostname='foobar'))
def test_config_file_with_db_users(self):
with FreshConfigProvider():
# Write some config.
self.putJsonResponse(
SuperUserConfig, data=dict(config={}, hostname="foobar")
)
      # If there is a config.yaml but DB users already exist, then this method should
      # fail for security reasons.
data = dict(username='cooluser', password='password', email='fake@example.com')
self.postResponse(SuperUserCreateInitialSuperUser, data=data, expected_code=403)
            # If there is a config.yaml but DB users already exist, then this method should
            # fail for security reasons.
data = dict(
username="cooluser", password="password", email="fake@example.com"
)
self.postResponse(
SuperUserCreateInitialSuperUser, data=data, expected_code=403
)
def test_config_file_with_no_db_users(self):
with FreshConfigProvider():
# Write some config.
self.putJsonResponse(SuperUserConfig, data=dict(config={}, hostname='foobar'))
def test_config_file_with_no_db_users(self):
with FreshConfigProvider():
# Write some config.
self.putJsonResponse(
SuperUserConfig, data=dict(config={}, hostname="foobar")
)
# Delete all the users in the DB.
for user in list(User.select()):
model.user.delete_user(user, all_queues)
# Delete all the users in the DB.
for user in list(User.select()):
model.user.delete_user(user, all_queues)
# This method should now succeed.
data = dict(username='cooluser', password='password', email='fake@example.com')
result = self.postJsonResponse(SuperUserCreateInitialSuperUser, data=data)
self.assertTrue(result['status'])
# This method should now succeed.
data = dict(
username="cooluser", password="password", email="fake@example.com"
)
result = self.postJsonResponse(SuperUserCreateInitialSuperUser, data=data)
self.assertTrue(result["status"])
# Verify the superuser was created.
User.get(User.username == 'cooluser')
# Verify the superuser was created.
User.get(User.username == "cooluser")
# Verify the superuser was placed into the config.
result = self.getJsonResponse(SuperUserConfig)
self.assertEquals(['cooluser'], result['config']['SUPER_USERS'])
# Verify the superuser was placed into the config.
result = self.getJsonResponse(SuperUserConfig)
self.assertEquals(["cooluser"], result["config"]["SUPER_USERS"])
class TestSuperUserConfigValidate(ApiTestCase):
def test_nonsuperuser_noconfig(self):
with FreshConfigProvider():
result = self.postJsonResponse(SuperUserConfigValidate, params=dict(service='someservice'),
data=dict(config={}))
def test_nonsuperuser_noconfig(self):
with FreshConfigProvider():
result = self.postJsonResponse(
SuperUserConfigValidate,
params=dict(service="someservice"),
data=dict(config={}),
)
self.assertFalse(result['status'])
self.assertFalse(result["status"])
def test_nonsuperuser_config(self):
with FreshConfigProvider():
# The validate config call works if there is no config.yaml OR the user is a superuser.
# Add a config, and verify it breaks when unauthenticated.
json = self.putJsonResponse(
SuperUserConfig, data=dict(config={}, hostname="foobar")
)
self.assertTrue(json["exists"])
def test_nonsuperuser_config(self):
with FreshConfigProvider():
# The validate config call works if there is no config.yaml OR the user is a superuser.
# Add a config, and verify it breaks when unauthenticated.
json = self.putJsonResponse(SuperUserConfig, data=dict(config={}, hostname='foobar'))
self.assertTrue(json['exists'])
result = self.postJsonResponse(
SuperUserConfigValidate,
params=dict(service="someservice"),
data=dict(config={}),
)
result = self.postJsonResponse(SuperUserConfigValidate, params=dict(service='someservice'),
data=dict(config={}))
self.assertFalse(result['status'])
self.assertFalse(result["status"])
class TestSuperUserConfig(ApiTestCase):
def test_get_superuser(self):
with FreshConfigProvider():
json = self.getJsonResponse(SuperUserConfig)
def test_get_superuser(self):
with FreshConfigProvider():
json = self.getJsonResponse(SuperUserConfig)
      # Note: We expect the config to be None because a config.yaml should never be checked into
      # the directory.
self.assertIsNone(json['config'])
            # Note: We expect the config to be None because a config.yaml should never be checked into
            # the directory.
self.assertIsNone(json["config"])
def test_put(self):
with FreshConfigProvider() as config:
json = self.putJsonResponse(SuperUserConfig, data=dict(config={}, hostname='foobar'))
self.assertTrue(json['exists'])
def test_put(self):
with FreshConfigProvider() as config:
json = self.putJsonResponse(
SuperUserConfig, data=dict(config={}, hostname="foobar")
)
self.assertTrue(json["exists"])
# Verify the config file exists.
self.assertTrue(config.config_exists())
# Verify the config file exists.
self.assertTrue(config.config_exists())
# This should succeed.
json = self.putJsonResponse(SuperUserConfig, data=dict(config={}, hostname='barbaz'))
self.assertTrue(json['exists'])
# This should succeed.
json = self.putJsonResponse(
SuperUserConfig, data=dict(config={}, hostname="barbaz")
)
self.assertTrue(json["exists"])
json = self.getJsonResponse(SuperUserConfig)
self.assertIsNotNone(json['config'])
json = self.getJsonResponse(SuperUserConfig)
self.assertIsNotNone(json["config"])
if __name__ == '__main__':
unittest.main()
if __name__ == "__main__":
unittest.main()


@@ -5,58 +5,63 @@ from backports.tempfile import TemporaryDirectory
from config_app.config_util.config.fileprovider import FileConfigProvider
OLD_CONFIG_SUBDIR = 'old/'
OLD_CONFIG_SUBDIR = "old/"
class TransientDirectoryProvider(FileConfigProvider):
""" Implementation of the config provider that reads and writes the data
""" Implementation of the config provider that reads and writes the data
from/to the file system, only using temporary directories,
deleting old dirs and creating new ones as requested.
"""
def __init__(self, config_volume, yaml_filename, py_filename):
# Create a temp directory that will be cleaned up when we change the config path
# This should ensure we have no "pollution" of different configs:
# no uploaded config should ever affect subsequent config modifications/creations
temp_dir = TemporaryDirectory()
self.temp_dir = temp_dir
self.old_config_dir = None
super(TransientDirectoryProvider, self).__init__(temp_dir.name, yaml_filename, py_filename)
def __init__(self, config_volume, yaml_filename, py_filename):
# Create a temp directory that will be cleaned up when we change the config path
# This should ensure we have no "pollution" of different configs:
# no uploaded config should ever affect subsequent config modifications/creations
temp_dir = TemporaryDirectory()
self.temp_dir = temp_dir
self.old_config_dir = None
super(TransientDirectoryProvider, self).__init__(
temp_dir.name, yaml_filename, py_filename
)
@property
def provider_id(self):
return 'transient'
@property
def provider_id(self):
return "transient"
def new_config_dir(self):
"""
def new_config_dir(self):
"""
Update the path with a new temporary directory, deleting the old one in the process
"""
self.temp_dir.cleanup()
temp_dir = TemporaryDirectory()
self.temp_dir.cleanup()
temp_dir = TemporaryDirectory()
self.config_volume = temp_dir.name
self.temp_dir = temp_dir
self.yaml_path = os.path.join(temp_dir.name, self.yaml_filename)
self.config_volume = temp_dir.name
self.temp_dir = temp_dir
self.yaml_path = os.path.join(temp_dir.name, self.yaml_filename)
def create_copy_of_config_dir(self):
"""
def create_copy_of_config_dir(self):
"""
Create a directory to store loaded/populated configuration (for rollback if necessary)
"""
if self.old_config_dir is not None:
self.old_config_dir.cleanup()
if self.old_config_dir is not None:
self.old_config_dir.cleanup()
temp_dir = TemporaryDirectory()
self.old_config_dir = temp_dir
temp_dir = TemporaryDirectory()
self.old_config_dir = temp_dir
# Python 2.7's shutil.copy() doesn't allow for copying to existing directories,
# so when copying/reading to the old saved config, we have to talk to a subdirectory,
# and use the shutil.copytree() function
copytree(self.config_volume, os.path.join(temp_dir.name, OLD_CONFIG_SUBDIR))
# Python 2.7's shutil.copy() doesn't allow for copying to existing directories,
# so when copying/reading to the old saved config, we have to talk to a subdirectory,
# and use the shutil.copytree() function
copytree(self.config_volume, os.path.join(temp_dir.name, OLD_CONFIG_SUBDIR))
def get_config_dir_path(self):
return self.config_volume
def get_config_dir_path(self):
return self.config_volume
def get_old_config_dir(self):
if self.old_config_dir is None:
      raise Exception('Cannot return old configuration: no old configuration was saved')
def get_old_config_dir(self):
if self.old_config_dir is None:
            raise Exception(
                "Cannot return old configuration: no old configuration was saved"
            )
return os.path.join(self.old_config_dir.name, OLD_CONFIG_SUBDIR)
return os.path.join(self.old_config_dir.name, OLD_CONFIG_SUBDIR)
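Typical lifecycle, as a hedged sketch; note that the config_volume argument is effectively ignored in favor of a fresh temporary directory:

provider = TransientDirectoryProvider("unused", "config.yaml", "config.py")
provider.create_copy_of_config_dir()          # snapshot current contents for rollback
provider.new_config_dir()                     # swap in an empty temp directory
rollback_dir = provider.get_old_config_dir()  # .../old/ copy of the snapshot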


@@ -3,37 +3,40 @@ import os
from config_app.config_util.config.fileprovider import FileConfigProvider
from config_app.config_util.config.testprovider import TestConfigProvider
from config_app.config_util.config.TransientDirectoryProvider import TransientDirectoryProvider
from config_app.config_util.config.TransientDirectoryProvider import (
TransientDirectoryProvider,
)
from util.config.validator import EXTRA_CA_DIRECTORY, EXTRA_CA_DIRECTORY_PREFIX
def get_config_provider(config_volume, yaml_filename, py_filename, testing=False):
""" Loads and returns the config provider for the current environment. """
""" Loads and returns the config provider for the current environment. """
if testing:
return TestConfigProvider()
if testing:
return TestConfigProvider()
return TransientDirectoryProvider(config_volume, yaml_filename, py_filename)
return TransientDirectoryProvider(config_volume, yaml_filename, py_filename)
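A hedged usage example (the volume path is an assumption; outside of tests it is ignored by the TransientDirectoryProvider anyway):

provider = get_config_provider("/conf/stack", "config.yaml", "config.py", testing=False)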
def get_config_as_kube_secret(config_path):
data = {}
data = {}
# Kubernetes secrets don't have sub-directories, so for the extra_ca_certs dir
# we have to put the extra certs in with a prefix, and then one of our init scripts
# (02_get_kube_certs.sh) will expand the prefixed certs into the equivalent directory
# so that they'll be installed correctly on startup by the certs_install script
certs_dir = os.path.join(config_path, EXTRA_CA_DIRECTORY)
if os.path.exists(certs_dir):
for extra_cert in os.listdir(certs_dir):
with open(os.path.join(certs_dir, extra_cert)) as f:
data[EXTRA_CA_DIRECTORY_PREFIX + extra_cert] = base64.b64encode(f.read())
# Kubernetes secrets don't have sub-directories, so for the extra_ca_certs dir
# we have to put the extra certs in with a prefix, and then one of our init scripts
# (02_get_kube_certs.sh) will expand the prefixed certs into the equivalent directory
# so that they'll be installed correctly on startup by the certs_install script
certs_dir = os.path.join(config_path, EXTRA_CA_DIRECTORY)
if os.path.exists(certs_dir):
for extra_cert in os.listdir(certs_dir):
with open(os.path.join(certs_dir, extra_cert)) as f:
data[EXTRA_CA_DIRECTORY_PREFIX + extra_cert] = base64.b64encode(
f.read()
)
for name in os.listdir(config_path):
file_path = os.path.join(config_path, name)
if not os.path.isdir(file_path):
with open(file_path) as f:
data[name] = base64.b64encode(f.read())
for name in os.listdir(config_path):
file_path = os.path.join(config_path, name)
if not os.path.isdir(file_path):
with open(file_path) as f:
data[name] = base64.b64encode(f.read())
return data
return data
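The returned mapping slots directly into a Kubernetes Secret's data field, which expects base64-encoded values. A sketch with a hypothetical secret name:

secret = {
    "apiVersion": "v1",
    "kind": "Secret",
    "metadata": {"name": "quay-config-secret"},  # name assumed
    "data": get_config_as_kube_secret("/conf/stack"),
}
# Files under extra_ca_certs/ appear flattened with EXTRA_CA_DIRECTORY_PREFIX
# prepended; an init script re-expands them at startup (see comment above).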

Some files were not shown because too many files have changed in this diff.