diff --git a/_init.py b/_init.py
index 84a574b0f..216f47e15 100644
--- a/_init.py
+++ b/_init.py
@@ -7,40 +7,45 @@ from util.config.provider import get_config_provider
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
CONF_DIR = os.getenv("QUAYCONF", os.path.join(ROOT_DIR, "conf/"))
-STATIC_DIR = os.path.join(ROOT_DIR, 'static/')
-STATIC_LDN_DIR = os.path.join(STATIC_DIR, 'ldn/')
-STATIC_FONTS_DIR = os.path.join(STATIC_DIR, 'fonts/')
-STATIC_WEBFONTS_DIR = os.path.join(STATIC_DIR, 'webfonts/')
-TEMPLATE_DIR = os.path.join(ROOT_DIR, 'templates/')
+STATIC_DIR = os.path.join(ROOT_DIR, "static/")
+STATIC_LDN_DIR = os.path.join(STATIC_DIR, "ldn/")
+STATIC_FONTS_DIR = os.path.join(STATIC_DIR, "fonts/")
+STATIC_WEBFONTS_DIR = os.path.join(STATIC_DIR, "webfonts/")
+TEMPLATE_DIR = os.path.join(ROOT_DIR, "templates/")
-IS_TESTING = 'TEST' in os.environ
-IS_BUILDING = 'BUILDING' in os.environ
-IS_KUBERNETES = 'KUBERNETES_SERVICE_HOST' in os.environ
-OVERRIDE_CONFIG_DIRECTORY = os.path.join(CONF_DIR, 'stack/')
+IS_TESTING = "TEST" in os.environ
+IS_BUILDING = "BUILDING" in os.environ
+IS_KUBERNETES = "KUBERNETES_SERVICE_HOST" in os.environ
+OVERRIDE_CONFIG_DIRECTORY = os.path.join(CONF_DIR, "stack/")
-config_provider = get_config_provider(OVERRIDE_CONFIG_DIRECTORY, 'config.yaml', 'config.py',
- testing=IS_TESTING, kubernetes=IS_KUBERNETES)
+config_provider = get_config_provider(
+ OVERRIDE_CONFIG_DIRECTORY,
+ "config.yaml",
+ "config.py",
+ testing=IS_TESTING,
+ kubernetes=IS_KUBERNETES,
+)
def _get_version_number_changelog():
- try:
- with open(os.path.join(ROOT_DIR, 'CHANGELOG.md')) as f:
- return re.search(r'(v[0-9]+\.[0-9]+\.[0-9]+)', f.readline()).group(0)
- except IOError:
- return ''
+ try:
+ with open(os.path.join(ROOT_DIR, "CHANGELOG.md")) as f:
+ return re.search(r"(v[0-9]+\.[0-9]+\.[0-9]+)", f.readline()).group(0)
+ except IOError:
+ return ""
def _get_git_sha():
- if os.path.exists("GIT_HEAD"):
- with open(os.path.join(ROOT_DIR, "GIT_HEAD")) as f:
- return f.read()
- else:
- try:
- return subprocess.check_output(["git", "rev-parse", "HEAD"]).strip()[0:8]
- except (OSError, subprocess.CalledProcessError, Exception):
- pass
- return "unknown"
+ if os.path.exists("GIT_HEAD"):
+ with open(os.path.join(ROOT_DIR, "GIT_HEAD")) as f:
+ return f.read()
+ else:
+ try:
+ return subprocess.check_output(["git", "rev-parse", "HEAD"]).strip()[0:8]
+ except (OSError, subprocess.CalledProcessError, Exception):
+ pass
+ return "unknown"
__version__ = _get_version_number_changelog()
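
For readers skimming the reformat: _get_version_number_changelog() above just pulls the first vX.Y.Z tag out of the first line of CHANGELOG.md. A minimal standalone sketch of that regex, not part of the patch (the sample line is hypothetical):

import re

first_line = "## v3.0.0 (2019-05-01)"  # hypothetical CHANGELOG first line
match = re.search(r"(v[0-9]+\.[0-9]+\.[0-9]+)", first_line)
print(match.group(0) if match else "")  # -> v3.0.0
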
diff --git a/active_migration.py b/active_migration.py
index 693bcaac6..c80e239ed 100644
--- a/active_migration.py
+++ b/active_migration.py
@@ -1,22 +1,30 @@
from enum import Enum, unique
from data.migrationutil import DefinedDataMigration, MigrationPhase
+
@unique
class ERTMigrationFlags(Enum):
- """ Flags for the encrypted robot token migration. """
- READ_OLD_FIELDS = 'read-old'
- WRITE_OLD_FIELDS = 'write-old'
+ """ Flags for the encrypted robot token migration. """
+
+ READ_OLD_FIELDS = "read-old"
+ WRITE_OLD_FIELDS = "write-old"
ActiveDataMigration = DefinedDataMigration(
- 'encrypted_robot_tokens',
- 'ENCRYPTED_ROBOT_TOKEN_MIGRATION_PHASE',
- [
- MigrationPhase('add-new-fields', 'c13c8052f7a6', [ERTMigrationFlags.READ_OLD_FIELDS,
- ERTMigrationFlags.WRITE_OLD_FIELDS]),
- MigrationPhase('backfill-then-read-only-new',
- '703298a825c2', [ERTMigrationFlags.WRITE_OLD_FIELDS]),
- MigrationPhase('stop-writing-both', '703298a825c2', []),
- MigrationPhase('remove-old-fields', 'c059b952ed76', []),
- ]
+ "encrypted_robot_tokens",
+ "ENCRYPTED_ROBOT_TOKEN_MIGRATION_PHASE",
+ [
+ MigrationPhase(
+ "add-new-fields",
+ "c13c8052f7a6",
+ [ERTMigrationFlags.READ_OLD_FIELDS, ERTMigrationFlags.WRITE_OLD_FIELDS],
+ ),
+ MigrationPhase(
+ "backfill-then-read-only-new",
+ "703298a825c2",
+ [ERTMigrationFlags.WRITE_OLD_FIELDS],
+ ),
+ MigrationPhase("stop-writing-both", "703298a825c2", []),
+ MigrationPhase("remove-old-fields", "c059b952ed76", []),
+ ],
)
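
As a reading aid (not part of the patch), the migration definition above written out as plain data: each phase name maps to an alembic revision plus the legacy robot-token flags still honored in that phase.

# Illustration only; DefinedDataMigration/MigrationPhase are the real API used above.
ERT_PHASES = {
    "add-new-fields": ("c13c8052f7a6", {"read-old", "write-old"}),
    "backfill-then-read-only-new": ("703298a825c2", {"write-old"}),
    "stop-writing-both": ("703298a825c2", set()),
    "remove-old-fields": ("c059b952ed76", set()),
}
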
diff --git a/app.py b/app.py
index 33245bee1..677ab9f95 100644
--- a/app.py
+++ b/app.py
@@ -16,8 +16,14 @@ from werkzeug.exceptions import HTTPException
import features
-from _init import (config_provider, CONF_DIR, IS_KUBERNETES, IS_TESTING, OVERRIDE_CONFIG_DIRECTORY,
- IS_BUILDING)
+from _init import (
+ config_provider,
+ CONF_DIR,
+ IS_KUBERNETES,
+ IS_TESTING,
+ OVERRIDE_CONFIG_DIRECTORY,
+ IS_BUILDING,
+)
from auth.auth_context import get_authenticated_user
from avatars.avatars import Avatar
@@ -35,7 +41,11 @@ from data.userevent import UserEventsBuilderModule
from data.userfiles import Userfiles
from data.users import UserAuthentication
from data.registry_model import registry_model
-from path_converters import RegexConverter, RepositoryPathConverter, APIRepositoryPathConverter
+from path_converters import (
+ RegexConverter,
+ RepositoryPathConverter,
+ APIRepositoryPathConverter,
+)
from oauth.services.github import GithubOAuthService
from oauth.services.gitlab import GitLabOAuthService
from oauth.loginmanager import OAuthLoginManager
@@ -62,13 +72,13 @@ from util.security.instancekeys import InstanceKeys
from util.security.signing import Signer
-OVERRIDE_CONFIG_YAML_FILENAME = os.path.join(CONF_DIR, 'stack/config.yaml')
-OVERRIDE_CONFIG_PY_FILENAME = os.path.join(CONF_DIR, 'stack/config.py')
+OVERRIDE_CONFIG_YAML_FILENAME = os.path.join(CONF_DIR, "stack/config.yaml")
+OVERRIDE_CONFIG_PY_FILENAME = os.path.join(CONF_DIR, "stack/config.py")
-OVERRIDE_CONFIG_KEY = 'QUAY_OVERRIDE_CONFIG'
+OVERRIDE_CONFIG_KEY = "QUAY_OVERRIDE_CONFIG"
-DOCKER_V2_SIGNINGKEY_FILENAME = 'docker_v2.pem'
-INIT_SCRIPTS_LOCATION = '/conf/init/'
+DOCKER_V2_SIGNINGKEY_FILENAME = "docker_v2.pem"
+INIT_SCRIPTS_LOCATION = "/conf/init/"
app = Flask(__name__)
logger = logging.getLogger(__name__)
@@ -79,62 +89,75 @@ is_kubernetes = IS_KUBERNETES
is_building = IS_BUILDING
if is_testing:
- from test.testconfig import TestConfig
- logger.debug('Loading test config.')
- app.config.from_object(TestConfig())
+ from test.testconfig import TestConfig
+
+ logger.debug("Loading test config.")
+ app.config.from_object(TestConfig())
else:
- from config import DefaultConfig
- logger.debug('Loading default config.')
- app.config.from_object(DefaultConfig())
- app.teardown_request(database.close_db_filter)
+ from config import DefaultConfig
+
+ logger.debug("Loading default config.")
+ app.config.from_object(DefaultConfig())
+ app.teardown_request(database.close_db_filter)
# Load the override config via the provider.
config_provider.update_app_config(app.config)
# Update any configuration found in the override environment variable.
-environ_config = json.loads(os.environ.get(OVERRIDE_CONFIG_KEY, '{}'))
+environ_config = json.loads(os.environ.get(OVERRIDE_CONFIG_KEY, "{}"))
app.config.update(environ_config)
# Fix remote address handling for Flask.
-if app.config.get('PROXY_COUNT', 1):
- app.wsgi_app = ProxyFix(app.wsgi_app, num_proxies=app.config.get('PROXY_COUNT', 1))
+if app.config.get("PROXY_COUNT", 1):
+ app.wsgi_app = ProxyFix(app.wsgi_app, num_proxies=app.config.get("PROXY_COUNT", 1))
# Ensure the V3 upgrade key is specified correctly. If not, simply fail.
# TODO: Remove for V3.1.
-if not is_testing and not is_building and app.config.get('SETUP_COMPLETE', False):
- v3_upgrade_mode = app.config.get('V3_UPGRADE_MODE')
- if v3_upgrade_mode is None:
- raise Exception('Configuration flag `V3_UPGRADE_MODE` must be set. Please check the upgrade docs')
+if not is_testing and not is_building and app.config.get("SETUP_COMPLETE", False):
+ v3_upgrade_mode = app.config.get("V3_UPGRADE_MODE")
+ if v3_upgrade_mode is None:
+ raise Exception(
+ "Configuration flag `V3_UPGRADE_MODE` must be set. Please check the upgrade docs"
+ )
- if (v3_upgrade_mode != 'background'
- and v3_upgrade_mode != 'complete'
- and v3_upgrade_mode != 'production-transition'
- and v3_upgrade_mode != 'post-oci-rollout'
- and v3_upgrade_mode != 'post-oci-roll-back-compat'):
- raise Exception('Invalid value for config `V3_UPGRADE_MODE`. Please check the upgrade docs')
+ if (
+ v3_upgrade_mode != "background"
+ and v3_upgrade_mode != "complete"
+ and v3_upgrade_mode != "production-transition"
+ and v3_upgrade_mode != "post-oci-rollout"
+ and v3_upgrade_mode != "post-oci-roll-back-compat"
+ ):
+ raise Exception(
+ "Invalid value for config `V3_UPGRADE_MODE`. Please check the upgrade docs"
+ )
# Split the registry model based on config.
# TODO: Remove once we are fully on the OCI data model.
-registry_model.setup_split(app.config.get('OCI_NAMESPACE_PROPORTION') or 0,
- app.config.get('OCI_NAMESPACE_WHITELIST') or set(),
- app.config.get('V22_NAMESPACE_WHITELIST') or set(),
- app.config.get('V3_UPGRADE_MODE'))
+registry_model.setup_split(
+ app.config.get("OCI_NAMESPACE_PROPORTION") or 0,
+ app.config.get("OCI_NAMESPACE_WHITELIST") or set(),
+ app.config.get("V22_NAMESPACE_WHITELIST") or set(),
+ app.config.get("V3_UPGRADE_MODE"),
+)
# Allow user to define a custom storage preference for the local instance.
-_distributed_storage_preference = os.environ.get('QUAY_DISTRIBUTED_STORAGE_PREFERENCE', '').split()
+_distributed_storage_preference = os.environ.get(
+ "QUAY_DISTRIBUTED_STORAGE_PREFERENCE", ""
+).split()
if _distributed_storage_preference:
- app.config['DISTRIBUTED_STORAGE_PREFERENCE'] = _distributed_storage_preference
+ app.config["DISTRIBUTED_STORAGE_PREFERENCE"] = _distributed_storage_preference
# Generate a secret key if none was specified.
-if app.config['SECRET_KEY'] is None:
- logger.debug('Generating in-memory secret key')
- app.config['SECRET_KEY'] = generate_secret_key()
+if app.config["SECRET_KEY"] is None:
+ logger.debug("Generating in-memory secret key")
+ app.config["SECRET_KEY"] = generate_secret_key()
# If the "preferred" scheme is https, then http is not allowed. Therefore, ensure we have a secure
# session cookie.
-if (app.config['PREFERRED_URL_SCHEME'] == 'https' and
- not app.config.get('FORCE_NONSECURE_SESSION_COOKIE', False)):
- app.config['SESSION_COOKIE_SECURE'] = True
+if app.config["PREFERRED_URL_SCHEME"] == "https" and not app.config.get(
+ "FORCE_NONSECURE_SESSION_COOKIE", False
+):
+ app.config["SESSION_COOKIE_SECURE"] = True
# Load features from config.
features.import_features(app.config)
@@ -145,65 +168,77 @@ logger.debug("Loaded config", extra={"config": app.config})
class RequestWithId(Request):
- request_gen = staticmethod(urn_generator(['request']))
+ request_gen = staticmethod(urn_generator(["request"]))
- def __init__(self, *args, **kwargs):
- super(RequestWithId, self).__init__(*args, **kwargs)
- self.request_id = self.request_gen()
+ def __init__(self, *args, **kwargs):
+ super(RequestWithId, self).__init__(*args, **kwargs)
+ self.request_id = self.request_gen()
@app.before_request
def _request_start():
- if os.getenv('PYDEV_DEBUG', None):
- import pydevd
- host, port = os.getenv('PYDEV_DEBUG').split(':')
- pydevd.settrace(host, port=int(port), stdoutToServer=True, stderrToServer=True, suspend=False)
+ if os.getenv("PYDEV_DEBUG", None):
+ import pydevd
- logger.debug('Starting request: %s (%s)', request.request_id, request.path,
- extra={"request_id": request.request_id})
+ host, port = os.getenv("PYDEV_DEBUG").split(":")
+ pydevd.settrace(
+ host,
+ port=int(port),
+ stdoutToServer=True,
+ stderrToServer=True,
+ suspend=False,
+ )
+
+ logger.debug(
+ "Starting request: %s (%s)",
+ request.request_id,
+ request.path,
+ extra={"request_id": request.request_id},
+ )
-DEFAULT_FILTER = lambda x: '[FILTERED]'
+DEFAULT_FILTER = lambda x: "[FILTERED]"
FILTERED_VALUES = [
- {'key': ['password'], 'fn': DEFAULT_FILTER},
- {'key': ['user', 'password'], 'fn': DEFAULT_FILTER},
- {'key': ['blob'], 'fn': lambda x: x[0:8]}
+ {"key": ["password"], "fn": DEFAULT_FILTER},
+ {"key": ["user", "password"], "fn": DEFAULT_FILTER},
+ {"key": ["blob"], "fn": lambda x: x[0:8]},
]
@app.after_request
def _request_end(resp):
- try:
- jsonbody = request.get_json(force=True, silent=True)
- except HTTPException:
- jsonbody = None
+ try:
+ jsonbody = request.get_json(force=True, silent=True)
+ except HTTPException:
+ jsonbody = None
- values = request.values.to_dict()
+ values = request.values.to_dict()
- if jsonbody and not isinstance(jsonbody, dict):
- jsonbody = {'_parsererror': jsonbody}
+ if jsonbody and not isinstance(jsonbody, dict):
+ jsonbody = {"_parsererror": jsonbody}
- if isinstance(values, dict):
- filter_logs(values, FILTERED_VALUES)
+ if isinstance(values, dict):
+ filter_logs(values, FILTERED_VALUES)
- extra = {
- "endpoint": request.endpoint,
- "request_id" : request.request_id,
- "remote_addr": request.remote_addr,
- "http_method": request.method,
- "original_url": request.url,
- "path": request.path,
- "parameters": values,
- "json_body": jsonbody,
- "confsha": CONFIG_DIGEST,
- }
+ extra = {
+ "endpoint": request.endpoint,
+ "request_id": request.request_id,
+ "remote_addr": request.remote_addr,
+ "http_method": request.method,
+ "original_url": request.url,
+ "path": request.path,
+ "parameters": values,
+ "json_body": jsonbody,
+ "confsha": CONFIG_DIGEST,
+ }
- if request.user_agent is not None:
- extra["user-agent"] = request.user_agent.string
-
- logger.debug('Ending request: %s (%s)', request.request_id, request.path, extra=extra)
- return resp
+ if request.user_agent is not None:
+ extra["user-agent"] = request.user_agent.string
+ logger.debug(
+ "Ending request: %s (%s)", request.request_id, request.path, extra=extra
+ )
+ return resp
root_logger = logging.getLogger()
@@ -211,13 +246,13 @@ root_logger = logging.getLogger()
app.request_class = RequestWithId
# Register custom converters.
-app.url_map.converters['regex'] = RegexConverter
-app.url_map.converters['repopath'] = RepositoryPathConverter
-app.url_map.converters['apirepopath'] = APIRepositoryPathConverter
+app.url_map.converters["regex"] = RegexConverter
+app.url_map.converters["repopath"] = RepositoryPathConverter
+app.url_map.converters["apirepopath"] = APIRepositoryPathConverter
Principal(app, use_sessions=False)
-tf = app.config['DB_TRANSACTION_FACTORY']
+tf = app.config["DB_TRANSACTION_FACTORY"]
model_cache = get_model_cache(app.config)
avatar = Avatar(app)
@@ -225,10 +260,14 @@ login_manager = LoginManager(app)
mail = Mail(app)
prometheus = PrometheusPlugin(app)
metric_queue = MetricQueue(prometheus)
-chunk_cleanup_queue = WorkQueue(app.config['CHUNK_CLEANUP_QUEUE_NAME'], tf, metric_queue=metric_queue)
+chunk_cleanup_queue = WorkQueue(
+ app.config["CHUNK_CLEANUP_QUEUE_NAME"], tf, metric_queue=metric_queue
+)
instance_keys = InstanceKeys(app)
ip_resolver = IPResolver(app)
-storage = Storage(app, metric_queue, chunk_cleanup_queue, instance_keys, config_provider, ip_resolver)
+storage = Storage(
+ app, metric_queue, chunk_cleanup_queue, instance_keys, config_provider, ip_resolver
+)
userfiles = Userfiles(app, storage)
log_archive = LogArchive(app, storage)
analytics = Analytics(app)
@@ -246,55 +285,99 @@ build_canceller = BuildCanceller(app)
start_cloudwatch_sender(metric_queue, app)
-github_trigger = GithubOAuthService(app.config, 'GITHUB_TRIGGER_CONFIG')
-gitlab_trigger = GitLabOAuthService(app.config, 'GITLAB_TRIGGER_CONFIG')
+github_trigger = GithubOAuthService(app.config, "GITHUB_TRIGGER_CONFIG")
+gitlab_trigger = GitLabOAuthService(app.config, "GITLAB_TRIGGER_CONFIG")
oauth_login = OAuthLoginManager(app.config)
oauth_apps = [github_trigger, gitlab_trigger]
-image_replication_queue = WorkQueue(app.config['REPLICATION_QUEUE_NAME'], tf,
- has_namespace=False, metric_queue=metric_queue)
-dockerfile_build_queue = WorkQueue(app.config['DOCKERFILE_BUILD_QUEUE_NAME'], tf,
- metric_queue=metric_queue,
- reporter=BuildMetricQueueReporter(metric_queue),
- has_namespace=True)
-notification_queue = WorkQueue(app.config['NOTIFICATION_QUEUE_NAME'], tf, has_namespace=True,
- metric_queue=metric_queue)
-secscan_notification_queue = WorkQueue(app.config['SECSCAN_NOTIFICATION_QUEUE_NAME'], tf,
- has_namespace=False,
- metric_queue=metric_queue)
-export_action_logs_queue = WorkQueue(app.config['EXPORT_ACTION_LOGS_QUEUE_NAME'], tf,
- has_namespace=True,
- metric_queue=metric_queue)
+image_replication_queue = WorkQueue(
+ app.config["REPLICATION_QUEUE_NAME"],
+ tf,
+ has_namespace=False,
+ metric_queue=metric_queue,
+)
+dockerfile_build_queue = WorkQueue(
+ app.config["DOCKERFILE_BUILD_QUEUE_NAME"],
+ tf,
+ metric_queue=metric_queue,
+ reporter=BuildMetricQueueReporter(metric_queue),
+ has_namespace=True,
+)
+notification_queue = WorkQueue(
+ app.config["NOTIFICATION_QUEUE_NAME"],
+ tf,
+ has_namespace=True,
+ metric_queue=metric_queue,
+)
+secscan_notification_queue = WorkQueue(
+ app.config["SECSCAN_NOTIFICATION_QUEUE_NAME"],
+ tf,
+ has_namespace=False,
+ metric_queue=metric_queue,
+)
+export_action_logs_queue = WorkQueue(
+ app.config["EXPORT_ACTION_LOGS_QUEUE_NAME"],
+ tf,
+ has_namespace=True,
+ metric_queue=metric_queue,
+)
# Note: We set `has_namespace` to `False` here, as we explicitly want this queue to not be emptied
# when a namespace is marked for deletion.
-namespace_gc_queue = WorkQueue(app.config['NAMESPACE_GC_QUEUE_NAME'], tf, has_namespace=False,
- metric_queue=metric_queue)
+namespace_gc_queue = WorkQueue(
+ app.config["NAMESPACE_GC_QUEUE_NAME"],
+ tf,
+ has_namespace=False,
+ metric_queue=metric_queue,
+)
-all_queues = [image_replication_queue, dockerfile_build_queue, notification_queue,
- secscan_notification_queue, chunk_cleanup_queue, namespace_gc_queue]
+all_queues = [
+ image_replication_queue,
+ dockerfile_build_queue,
+ notification_queue,
+ secscan_notification_queue,
+ chunk_cleanup_queue,
+ namespace_gc_queue,
+]
-url_scheme_and_hostname = URLSchemeAndHostname(app.config['PREFERRED_URL_SCHEME'], app.config['SERVER_HOSTNAME'])
-secscan_api = SecurityScannerAPI(app.config, storage, app.config['SERVER_HOSTNAME'], app.config['HTTPCLIENT'],
- uri_creator=get_blob_download_uri_getter(app.test_request_context('/'), url_scheme_and_hostname),
- instance_keys=instance_keys)
+url_scheme_and_hostname = URLSchemeAndHostname(
+ app.config["PREFERRED_URL_SCHEME"], app.config["SERVER_HOSTNAME"]
+)
+secscan_api = SecurityScannerAPI(
+ app.config,
+ storage,
+ app.config["SERVER_HOSTNAME"],
+ app.config["HTTPCLIENT"],
+ uri_creator=get_blob_download_uri_getter(
+ app.test_request_context("/"), url_scheme_and_hostname
+ ),
+ instance_keys=instance_keys,
+)
-repo_mirror_api = RepoMirrorAPI(app.config, app.config['SERVER_HOSTNAME'], app.config['HTTPCLIENT'],
- instance_keys=instance_keys)
+repo_mirror_api = RepoMirrorAPI(
+ app.config,
+ app.config["SERVER_HOSTNAME"],
+ app.config["HTTPCLIENT"],
+ instance_keys=instance_keys,
+)
tuf_metadata_api = TUFMetadataAPI(app, app.config)
# Check for a key in config. If none found, generate a new signing key for Docker V2 manifests.
_v2_key_path = os.path.join(OVERRIDE_CONFIG_DIRECTORY, DOCKER_V2_SIGNINGKEY_FILENAME)
if os.path.exists(_v2_key_path):
- docker_v2_signing_key = RSAKey().load(_v2_key_path)
+ docker_v2_signing_key = RSAKey().load(_v2_key_path)
else:
- docker_v2_signing_key = RSAKey(key=RSA.generate(2048))
+ docker_v2_signing_key = RSAKey(key=RSA.generate(2048))
# Configure the database.
-if app.config.get('DATABASE_SECRET_KEY') is None and app.config.get('SETUP_COMPLETE', False):
- raise Exception('Missing DATABASE_SECRET_KEY in config; did you perhaps forget to add it?')
+if app.config.get("DATABASE_SECRET_KEY") is None and app.config.get(
+ "SETUP_COMPLETE", False
+):
+ raise Exception(
+ "Missing DATABASE_SECRET_KEY in config; did you perhaps forget to add it?"
+ )
database.configure(app.config)
@@ -306,8 +389,9 @@ model.config.register_repo_cleanup_callback(tuf_metadata_api.delete_metadata)
@login_manager.user_loader
def load_user(user_uuid):
- logger.debug('User loader loading deferred user with uuid: %s', user_uuid)
- return LoginWrappedDBUser(user_uuid)
+ logger.debug("User loader loading deferred user with uuid: %s", user_uuid)
+ return LoginWrappedDBUser(user_uuid)
+
logs_model.configure(app.config)
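
One note on the V3_UPGRADE_MODE block reformatted above: the patch keeps the original chain of != comparisons. For comparison only, an equivalent set-membership version (values and messages taken from the diff; this is a sketch, not what the patch applies):

_VALID_V3_UPGRADE_MODES = {
    "background",
    "complete",
    "production-transition",
    "post-oci-rollout",
    "post-oci-roll-back-compat",
}

def validate_v3_upgrade_mode(v3_upgrade_mode):
    # Mirrors the checks in app.py above, as a standalone function.
    if v3_upgrade_mode is None:
        raise Exception(
            "Configuration flag `V3_UPGRADE_MODE` must be set. Please check the upgrade docs"
        )
    if v3_upgrade_mode not in _VALID_V3_UPGRADE_MODES:
        raise Exception(
            "Invalid value for config `V3_UPGRADE_MODE`. Please check the upgrade docs"
        )
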
diff --git a/application.py b/application.py
index b7f478841..1a0c799fa 100644
--- a/application.py
+++ b/application.py
@@ -1,5 +1,6 @@
# NOTE: Must be before we import or call anything that may be synchronous.
from gevent import monkey
+
monkey.patch_all()
import os
@@ -17,6 +18,6 @@ import registry
import secscan
-if __name__ == '__main__':
- logging.config.fileConfig(logfile_path(debug=True), disable_existing_loggers=False)
- application.run(port=5000, debug=True, threaded=True, host='0.0.0.0')
+if __name__ == "__main__":
+ logging.config.fileConfig(logfile_path(debug=True), disable_existing_loggers=False)
+ application.run(port=5000, debug=True, threaded=True, host="0.0.0.0")
diff --git a/auth/auth_context.py b/auth/auth_context.py
index 8cb57f691..375d3d62a 100644
--- a/auth/auth_context.py
+++ b/auth/auth_context.py
@@ -1,21 +1,25 @@
from flask import _request_ctx_stack
+
def get_authenticated_context():
- """ Returns the auth context for the current request context, if any. """
- return getattr(_request_ctx_stack.top, 'authenticated_context', None)
+ """ Returns the auth context for the current request context, if any. """
+ return getattr(_request_ctx_stack.top, "authenticated_context", None)
+
def get_authenticated_user():
- """ Returns the authenticated user, if any, or None if none. """
- context = get_authenticated_context()
- return context.authed_user if context else None
+ """ Returns the authenticated user, if any, or None if none. """
+ context = get_authenticated_context()
+ return context.authed_user if context else None
+
def get_validated_oauth_token():
- """ Returns the authenticated and validated OAuth access token, if any, or None if none. """
- context = get_authenticated_context()
- return context.authed_oauth_token if context else None
+ """ Returns the authenticated and validated OAuth access token, if any, or None if none. """
+ context = get_authenticated_context()
+ return context.authed_oauth_token if context else None
+
def set_authenticated_context(auth_context):
- """ Sets the auth context for the current request context to that given. """
- ctx = _request_ctx_stack.top
- ctx.authenticated_context = auth_context
- return auth_context
+ """ Sets the auth context for the current request context to that given. """
+ ctx = _request_ctx_stack.top
+ ctx.authenticated_context = auth_context
+ return auth_context
diff --git a/auth/auth_context_type.py b/auth/auth_context_type.py
index 012222243..c35083f8d 100644
--- a/auth/auth_context_type.py
+++ b/auth/auth_context_type.py
@@ -16,422 +16,446 @@ from auth.scopes import scopes_from_scope_string
logger = logging.getLogger(__name__)
+
@add_metaclass(ABCMeta)
class AuthContext(object):
- """
+ """
Interface that represents the current context of authentication.
"""
- @property
- @abstractmethod
- def entity_kind(self):
- """ Returns the kind of the entity in this auth context. """
- pass
+ @property
+ @abstractmethod
+ def entity_kind(self):
+ """ Returns the kind of the entity in this auth context. """
+ pass
- @property
- @abstractmethod
- def is_anonymous(self):
- """ Returns true if this is an anonymous context. """
- pass
+ @property
+ @abstractmethod
+ def is_anonymous(self):
+ """ Returns true if this is an anonymous context. """
+ pass
- @property
- @abstractmethod
- def authed_oauth_token(self):
- """ Returns the authenticated OAuth token, if any. """
- pass
+ @property
+ @abstractmethod
+ def authed_oauth_token(self):
+ """ Returns the authenticated OAuth token, if any. """
+ pass
- @property
- @abstractmethod
- def authed_user(self):
- """ Returns the authenticated user, whether directly, or via an OAuth or access token. Note that
+ @property
+ @abstractmethod
+ def authed_user(self):
+ """ Returns the authenticated user, whether directly, or via an OAuth or access token. Note that
this property will also return robot accounts.
"""
- pass
+ pass
- @property
- @abstractmethod
- def has_nonrobot_user(self):
- """ Returns whether a user (not a robot) was authenticated successfully. """
- pass
+ @property
+ @abstractmethod
+ def has_nonrobot_user(self):
+ """ Returns whether a user (not a robot) was authenticated successfully. """
+ pass
- @property
- @abstractmethod
- def identity(self):
- """ Returns the identity for the auth context. """
- pass
+ @property
+ @abstractmethod
+ def identity(self):
+ """ Returns the identity for the auth context. """
+ pass
- @property
- @abstractmethod
- def description(self):
- """ Returns a human-readable and *public* description of the current auth context. """
- pass
+ @property
+ @abstractmethod
+ def description(self):
+ """ Returns a human-readable and *public* description of the current auth context. """
+ pass
- @property
- @abstractmethod
- def credential_username(self):
- """ Returns the username to create credentials for this context's entity, if any. """
- pass
+ @property
+ @abstractmethod
+ def credential_username(self):
+ """ Returns the username to create credentials for this context's entity, if any. """
+ pass
- @abstractmethod
- def analytics_id_and_public_metadata(self):
- """ Returns the analytics ID and public log metadata for this auth context. """
- pass
+ @abstractmethod
+ def analytics_id_and_public_metadata(self):
+ """ Returns the analytics ID and public log metadata for this auth context. """
+ pass
- @abstractmethod
- def apply_to_request_context(self):
- """ Applies this auth result to the auth context and Flask-Principal. """
- pass
+ @abstractmethod
+ def apply_to_request_context(self):
+ """ Applies this auth result to the auth context and Flask-Principal. """
+ pass
- @abstractmethod
- def to_signed_dict(self):
- """ Serializes the auth context into a dictionary suitable for inclusion in a JWT or other
+ @abstractmethod
+ def to_signed_dict(self):
+ """ Serializes the auth context into a dictionary suitable for inclusion in a JWT or other
form of signed serialization.
"""
- pass
+ pass
- @property
- @abstractmethod
- def unique_key(self):
- """ Returns a key that is unique to this auth context type and its data. For example, an
+ @property
+ @abstractmethod
+ def unique_key(self):
+ """ Returns a key that is unique to this auth context type and its data. For example, an
instance of the auth context type for the user might be a string of the form
`user-{user-uuid}`. Callers should treat this key as opaque and not rely on the contents
for anything besides uniqueness. This is typically used by callers when they'd like to
check cache but not hit the database to get a fully validated auth context.
"""
- pass
+ pass
class ValidatedAuthContext(AuthContext):
- """ ValidatedAuthContext represents the loaded, authenticated and validated auth information
+ """ ValidatedAuthContext represents the loaded, authenticated and validated auth information
for the current request context.
"""
- def __init__(self, user=None, token=None, oauthtoken=None, robot=None, appspecifictoken=None,
- signed_data=None):
- # Note: These field names *MUST* match the string values of the kinds defined in
- # ContextEntityKind.
- self.user = user
- self.robot = robot
- self.token = token
- self.oauthtoken = oauthtoken
- self.appspecifictoken = appspecifictoken
- self.signed_data = signed_data
- def tuple(self):
- return vars(self).values()
+ def __init__(
+ self,
+ user=None,
+ token=None,
+ oauthtoken=None,
+ robot=None,
+ appspecifictoken=None,
+ signed_data=None,
+ ):
+ # Note: These field names *MUST* match the string values of the kinds defined in
+ # ContextEntityKind.
+ self.user = user
+ self.robot = robot
+ self.token = token
+ self.oauthtoken = oauthtoken
+ self.appspecifictoken = appspecifictoken
+ self.signed_data = signed_data
- def __eq__(self, other):
- return self.tuple() == other.tuple()
+ def tuple(self):
+ return vars(self).values()
- @property
- def entity_kind(self):
- """ Returns the kind of the entity in this auth context. """
- for kind in ContextEntityKind:
- if hasattr(self, kind.value) and getattr(self, kind.value):
- return kind
+ def __eq__(self, other):
+ return self.tuple() == other.tuple()
- return ContextEntityKind.anonymous
+ @property
+ def entity_kind(self):
+ """ Returns the kind of the entity in this auth context. """
+ for kind in ContextEntityKind:
+ if hasattr(self, kind.value) and getattr(self, kind.value):
+ return kind
- @property
- def authed_user(self):
- """ Returns the authenticated user, whether directly, or via an OAuth token. Note that this
+ return ContextEntityKind.anonymous
+
+ @property
+ def authed_user(self):
+ """ Returns the authenticated user, whether directly, or via an OAuth token. Note that this
will also return robot accounts.
"""
- authed_user = self._authed_user()
- if authed_user is not None and not authed_user.enabled:
- logger.warning('Attempt to reference a disabled user/robot: %s', authed_user.username)
- return None
+ authed_user = self._authed_user()
+ if authed_user is not None and not authed_user.enabled:
+ logger.warning(
+ "Attempt to reference a disabled user/robot: %s", authed_user.username
+ )
+ return None
- return authed_user
+ return authed_user
- @property
- def authed_oauth_token(self):
- return self.oauthtoken
+ @property
+ def authed_oauth_token(self):
+ return self.oauthtoken
- def _authed_user(self):
- if self.oauthtoken:
- return self.oauthtoken.authorized_user
+ def _authed_user(self):
+ if self.oauthtoken:
+ return self.oauthtoken.authorized_user
- if self.appspecifictoken:
- return self.appspecifictoken.user
+ if self.appspecifictoken:
+ return self.appspecifictoken.user
- if self.signed_data:
- return model.user.get_user(self.signed_data['user_context'])
+ if self.signed_data:
+ return model.user.get_user(self.signed_data["user_context"])
- return self.user if self.user else self.robot
+ return self.user if self.user else self.robot
- @property
- def is_anonymous(self):
- """ Returns true if this is an anonymous context. """
- return not self.authed_user and not self.token and not self.signed_data
+ @property
+ def is_anonymous(self):
+ """ Returns true if this is an anonymous context. """
+ return not self.authed_user and not self.token and not self.signed_data
- @property
- def has_nonrobot_user(self):
- """ Returns whether a user (not a robot) was authenticated successfully. """
- return bool(self.authed_user and not self.robot)
+ @property
+ def has_nonrobot_user(self):
+ """ Returns whether a user (not a robot) was authenticated successfully. """
+ return bool(self.authed_user and not self.robot)
- @property
- def identity(self):
- """ Returns the identity for the auth context. """
- if self.oauthtoken:
- scope_set = scopes_from_scope_string(self.oauthtoken.scope)
- return QuayDeferredPermissionUser.for_user(self.oauthtoken.authorized_user, scope_set)
+ @property
+ def identity(self):
+ """ Returns the identity for the auth context. """
+ if self.oauthtoken:
+ scope_set = scopes_from_scope_string(self.oauthtoken.scope)
+ return QuayDeferredPermissionUser.for_user(
+ self.oauthtoken.authorized_user, scope_set
+ )
- if self.authed_user:
- return QuayDeferredPermissionUser.for_user(self.authed_user)
+ if self.authed_user:
+ return QuayDeferredPermissionUser.for_user(self.authed_user)
- if self.token:
- return Identity(self.token.get_code(), 'token')
+ if self.token:
+ return Identity(self.token.get_code(), "token")
- if self.signed_data:
- identity = Identity(None, 'signed_grant')
- identity.provides.update(self.signed_data['grants'])
- return identity
+ if self.signed_data:
+ identity = Identity(None, "signed_grant")
+ identity.provides.update(self.signed_data["grants"])
+ return identity
- return None
+ return None
- @property
- def entity_reference(self):
- """ Returns the DB object reference for this context's entity. """
- if self.entity_kind == ContextEntityKind.anonymous:
- return None
+ @property
+ def entity_reference(self):
+ """ Returns the DB object reference for this context's entity. """
+ if self.entity_kind == ContextEntityKind.anonymous:
+ return None
- return getattr(self, self.entity_kind.value)
+ return getattr(self, self.entity_kind.value)
- @property
- def description(self):
- """ Returns a human-readable and *public* description of the current auth context. """
- handler = CONTEXT_ENTITY_HANDLERS[self.entity_kind]()
- return handler.description(self.entity_reference)
+ @property
+ def description(self):
+ """ Returns a human-readable and *public* description of the current auth context. """
+ handler = CONTEXT_ENTITY_HANDLERS[self.entity_kind]()
+ return handler.description(self.entity_reference)
- @property
- def credential_username(self):
- """ Returns the username to create credentials for this context's entity, if any. """
- handler = CONTEXT_ENTITY_HANDLERS[self.entity_kind]()
- return handler.credential_username(self.entity_reference)
+ @property
+ def credential_username(self):
+ """ Returns the username to create credentials for this context's entity, if any. """
+ handler = CONTEXT_ENTITY_HANDLERS[self.entity_kind]()
+ return handler.credential_username(self.entity_reference)
- def analytics_id_and_public_metadata(self):
- """ Returns the analytics ID and public log metadata for this auth context. """
- handler = CONTEXT_ENTITY_HANDLERS[self.entity_kind]()
- return handler.analytics_id_and_public_metadata(self.entity_reference)
+ def analytics_id_and_public_metadata(self):
+ """ Returns the analytics ID and public log metadata for this auth context. """
+ handler = CONTEXT_ENTITY_HANDLERS[self.entity_kind]()
+ return handler.analytics_id_and_public_metadata(self.entity_reference)
- def apply_to_request_context(self):
- """ Applies this auth result to the auth context and Flask-Principal. """
- # Save to the request context.
- set_authenticated_context(self)
+ def apply_to_request_context(self):
+ """ Applies this auth result to the auth context and Flask-Principal. """
+ # Save to the request context.
+ set_authenticated_context(self)
- # Set the identity for Flask-Principal.
- if self.identity:
- identity_changed.send(app, identity=self.identity)
+ # Set the identity for Flask-Principal.
+ if self.identity:
+ identity_changed.send(app, identity=self.identity)
- @property
- def unique_key(self):
- signed_dict = self.to_signed_dict()
- return '%s-%s' % (signed_dict['entity_kind'], signed_dict.get('entity_reference', '(anon)'))
+ @property
+ def unique_key(self):
+ signed_dict = self.to_signed_dict()
+ return "%s-%s" % (
+ signed_dict["entity_kind"],
+ signed_dict.get("entity_reference", "(anon)"),
+ )
- def to_signed_dict(self):
- """ Serializes the auth context into a dictionary suitable for inclusion in a JWT or other
+ def to_signed_dict(self):
+ """ Serializes the auth context into a dictionary suitable for inclusion in a JWT or other
form of signed serialization.
"""
- dict_data = {
- 'version': 2,
- 'entity_kind': self.entity_kind.value,
- }
+ dict_data = {"version": 2, "entity_kind": self.entity_kind.value}
- if self.entity_kind != ContextEntityKind.anonymous:
- handler = CONTEXT_ENTITY_HANDLERS[self.entity_kind]()
- dict_data.update({
- 'entity_reference': handler.get_serialized_entity_reference(self.entity_reference),
- })
+ if self.entity_kind != ContextEntityKind.anonymous:
+ handler = CONTEXT_ENTITY_HANDLERS[self.entity_kind]()
+ dict_data.update(
+ {
+ "entity_reference": handler.get_serialized_entity_reference(
+ self.entity_reference
+ )
+ }
+ )
- # Add legacy information.
- # TODO: Remove this all once the new code is fully deployed.
- if self.token:
- dict_data.update({
- 'kind': 'token',
- 'token': self.token.code,
- })
+ # Add legacy information.
+ # TODO: Remove this all once the new code is fully deployed.
+ if self.token:
+ dict_data.update({"kind": "token", "token": self.token.code})
- if self.oauthtoken:
- dict_data.update({
- 'kind': 'oauth',
- 'oauth': self.oauthtoken.uuid,
- 'user': self.authed_user.username,
- })
+ if self.oauthtoken:
+ dict_data.update(
+ {
+ "kind": "oauth",
+ "oauth": self.oauthtoken.uuid,
+ "user": self.authed_user.username,
+ }
+ )
- if self.user or self.robot:
- dict_data.update({
- 'kind': 'user',
- 'user': self.authed_user.username,
- })
+ if self.user or self.robot:
+ dict_data.update({"kind": "user", "user": self.authed_user.username})
- if self.appspecifictoken:
- dict_data.update({
- 'kind': 'user',
- 'user': self.authed_user.username,
- })
+ if self.appspecifictoken:
+ dict_data.update({"kind": "user", "user": self.authed_user.username})
- if self.is_anonymous:
- dict_data.update({
- 'kind': 'anonymous',
- })
+ if self.is_anonymous:
+ dict_data.update({"kind": "anonymous"})
+
+ # End of legacy information.
+ return dict_data
- # End of legacy information.
- return dict_data
class SignedAuthContext(AuthContext):
- """ SignedAuthContext represents an auth context loaded from a signed token of some kind,
+ """ SignedAuthContext represents an auth context loaded from a signed token of some kind,
such as a JWT. Unlike ValidatedAuthContext, SignedAuthContext operates lazily, only loading
the actual {user, robot, token, etc} when requested. This allows registry operations that
only need to check if *some* entity is present to do so, without hitting the database.
"""
- def __init__(self, kind, signed_data, v1_dict_format):
- self.kind = kind
- self.signed_data = signed_data
- self.v1_dict_format = v1_dict_format
- @property
- def unique_key(self):
- if self.v1_dict_format:
- # Since V1 data format is verbose, just use the validated version to get the key.
- return self._get_validated().unique_key
+ def __init__(self, kind, signed_data, v1_dict_format):
+ self.kind = kind
+ self.signed_data = signed_data
+ self.v1_dict_format = v1_dict_format
- signed_dict = self.signed_data
- return '%s-%s' % (signed_dict['entity_kind'], signed_dict.get('entity_reference', '(anon)'))
+ @property
+ def unique_key(self):
+ if self.v1_dict_format:
+ # Since V1 data format is verbose, just use the validated version to get the key.
+ return self._get_validated().unique_key
- @classmethod
- def build_from_signed_dict(cls, dict_data, v1_dict_format=False):
- if not v1_dict_format:
- entity_kind = ContextEntityKind(dict_data.get('entity_kind', 'anonymous'))
- return SignedAuthContext(entity_kind, dict_data, v1_dict_format)
+ signed_dict = self.signed_data
+ return "%s-%s" % (
+ signed_dict["entity_kind"],
+ signed_dict.get("entity_reference", "(anon)"),
+ )
- # Legacy handling.
- # TODO: Remove this all once the new code is fully deployed.
- kind_string = dict_data.get('kind', 'anonymous')
- if kind_string == 'oauth':
- kind_string = 'oauthtoken'
+ @classmethod
+ def build_from_signed_dict(cls, dict_data, v1_dict_format=False):
+ if not v1_dict_format:
+ entity_kind = ContextEntityKind(dict_data.get("entity_kind", "anonymous"))
+ return SignedAuthContext(entity_kind, dict_data, v1_dict_format)
- kind = ContextEntityKind(kind_string)
- return SignedAuthContext(kind, dict_data, v1_dict_format)
+ # Legacy handling.
+ # TODO: Remove this all once the new code is fully deployed.
+ kind_string = dict_data.get("kind", "anonymous")
+ if kind_string == "oauth":
+ kind_string = "oauthtoken"
- @lru_cache(maxsize=1)
- def _get_validated(self):
- """ Returns a ValidatedAuthContext for this signed context, resolving all the necessary
+ kind = ContextEntityKind(kind_string)
+ return SignedAuthContext(kind, dict_data, v1_dict_format)
+
+ @lru_cache(maxsize=1)
+ def _get_validated(self):
+ """ Returns a ValidatedAuthContext for this signed context, resolving all the necessary
references.
"""
- if not self.v1_dict_format:
- if self.kind == ContextEntityKind.anonymous:
- return ValidatedAuthContext()
+ if not self.v1_dict_format:
+ if self.kind == ContextEntityKind.anonymous:
+ return ValidatedAuthContext()
- serialized_entity_reference = self.signed_data['entity_reference']
- handler = CONTEXT_ENTITY_HANDLERS[self.kind]()
- entity_reference = handler.deserialize_entity_reference(serialized_entity_reference)
- if entity_reference is None:
- logger.debug('Could not deserialize entity reference `%s` under kind `%s`',
- serialized_entity_reference, self.kind)
- return ValidatedAuthContext()
+ serialized_entity_reference = self.signed_data["entity_reference"]
+ handler = CONTEXT_ENTITY_HANDLERS[self.kind]()
+ entity_reference = handler.deserialize_entity_reference(
+ serialized_entity_reference
+ )
+ if entity_reference is None:
+ logger.debug(
+ "Could not deserialize entity reference `%s` under kind `%s`",
+ serialized_entity_reference,
+ self.kind,
+ )
+ return ValidatedAuthContext()
- return ValidatedAuthContext(**{self.kind.value: entity_reference})
+ return ValidatedAuthContext(**{self.kind.value: entity_reference})
- # Legacy handling.
- # TODO: Remove this all once the new code is fully deployed.
- kind_string = self.signed_data.get('kind', 'anonymous')
- if kind_string == 'oauth':
- kind_string = 'oauthtoken'
+ # Legacy handling.
+ # TODO: Remove this all once the new code is fully deployed.
+ kind_string = self.signed_data.get("kind", "anonymous")
+ if kind_string == "oauth":
+ kind_string = "oauthtoken"
- kind = ContextEntityKind(kind_string)
- if kind == ContextEntityKind.anonymous:
- return ValidatedAuthContext()
+ kind = ContextEntityKind(kind_string)
+ if kind == ContextEntityKind.anonymous:
+ return ValidatedAuthContext()
- if kind == ContextEntityKind.user or kind == ContextEntityKind.robot:
- user = model.user.get_user(self.signed_data.get('user', ''))
- if not user:
- return None
+ if kind == ContextEntityKind.user or kind == ContextEntityKind.robot:
+ user = model.user.get_user(self.signed_data.get("user", ""))
+ if not user:
+ return None
- return ValidatedAuthContext(robot=user) if user.robot else ValidatedAuthContext(user=user)
+ return (
+ ValidatedAuthContext(robot=user)
+ if user.robot
+ else ValidatedAuthContext(user=user)
+ )
- if kind == ContextEntityKind.token:
- token = model.token.load_token_data(self.signed_data.get('token'))
- if not token:
- return None
+ if kind == ContextEntityKind.token:
+ token = model.token.load_token_data(self.signed_data.get("token"))
+ if not token:
+ return None
- return ValidatedAuthContext(token=token)
+ return ValidatedAuthContext(token=token)
- if kind == ContextEntityKind.oauthtoken:
- user = model.user.get_user(self.signed_data.get('user', ''))
- if not user:
- return None
+ if kind == ContextEntityKind.oauthtoken:
+ user = model.user.get_user(self.signed_data.get("user", ""))
+ if not user:
+ return None
- token_uuid = self.signed_data.get('oauth', '')
- oauthtoken = model.oauth.lookup_access_token_for_user(user, token_uuid)
- if not oauthtoken:
- return None
+ token_uuid = self.signed_data.get("oauth", "")
+ oauthtoken = model.oauth.lookup_access_token_for_user(user, token_uuid)
+ if not oauthtoken:
+ return None
- return ValidatedAuthContext(oauthtoken=oauthtoken)
+ return ValidatedAuthContext(oauthtoken=oauthtoken)
- raise Exception('Unknown auth context kind `%s` when deserializing %s' % (kind,
- self.signed_data))
- # End of legacy handling.
+ raise Exception(
+ "Unknown auth context kind `%s` when deserializing %s"
+ % (kind, self.signed_data)
+ )
+ # End of legacy handling.
- @property
- def entity_kind(self):
- """ Returns the kind of the entity in this auth context. """
- return self.kind
+ @property
+ def entity_kind(self):
+ """ Returns the kind of the entity in this auth context. """
+ return self.kind
- @property
- def is_anonymous(self):
- """ Returns true if this is an anonymous context. """
- return self.kind == ContextEntityKind.anonymous
+ @property
+ def is_anonymous(self):
+ """ Returns true if this is an anonymous context. """
+ return self.kind == ContextEntityKind.anonymous
- @property
- def authed_user(self):
- """ Returns the authenticated user, whether directly, or via an OAuth or access token. Note that
+ @property
+ def authed_user(self):
+ """ Returns the authenticated user, whether directly, or via an OAuth or access token. Note that
this property will also return robot accounts.
"""
- if self.kind == ContextEntityKind.anonymous:
- return None
+ if self.kind == ContextEntityKind.anonymous:
+ return None
- return self._get_validated().authed_user
+ return self._get_validated().authed_user
- @property
- def authed_oauth_token(self):
- if self.kind == ContextEntityKind.anonymous:
- return None
+ @property
+ def authed_oauth_token(self):
+ if self.kind == ContextEntityKind.anonymous:
+ return None
- return self._get_validated().authed_oauth_token
+ return self._get_validated().authed_oauth_token
- @property
- def has_nonrobot_user(self):
- """ Returns whether a user (not a robot) was authenticated successfully. """
- if self.kind == ContextEntityKind.anonymous:
- return False
+ @property
+ def has_nonrobot_user(self):
+ """ Returns whether a user (not a robot) was authenticated successfully. """
+ if self.kind == ContextEntityKind.anonymous:
+ return False
- return self._get_validated().has_nonrobot_user
+ return self._get_validated().has_nonrobot_user
- @property
- def identity(self):
- """ Returns the identity for the auth context. """
- return self._get_validated().identity
+ @property
+ def identity(self):
+ """ Returns the identity for the auth context. """
+ return self._get_validated().identity
- @property
- def description(self):
- """ Returns a human-readable and *public* description of the current auth context. """
- return self._get_validated().description
+ @property
+ def description(self):
+ """ Returns a human-readable and *public* description of the current auth context. """
+ return self._get_validated().description
- @property
- def credential_username(self):
- """ Returns the username to create credentials for this context's entity, if any. """
- return self._get_validated().credential_username
+ @property
+ def credential_username(self):
+ """ Returns the username to create credentials for this context's entity, if any. """
+ return self._get_validated().credential_username
- def analytics_id_and_public_metadata(self):
- """ Returns the analytics ID and public log metadata for this auth context. """
- return self._get_validated().analytics_id_and_public_metadata()
+ def analytics_id_and_public_metadata(self):
+ """ Returns the analytics ID and public log metadata for this auth context. """
+ return self._get_validated().analytics_id_and_public_metadata()
- def apply_to_request_context(self):
- """ Applies this auth result to the auth context and Flask-Principal. """
- return self._get_validated().apply_to_request_context()
+ def apply_to_request_context(self):
+ """ Applies this auth result to the auth context and Flask-Principal. """
+ return self._get_validated().apply_to_request_context()
- def to_signed_dict(self):
- """ Serializes the auth context into a dictionary suitable for inclusion in a JWT or other
+ def to_signed_dict(self):
+ """ Serializes the auth context into a dictionary suitable for inclusion in a JWT or other
form of signed serialization.
"""
- return self.signed_data
+ return self.signed_data
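
For orientation (not part of the patch), the payload produced by ValidatedAuthContext.to_signed_dict() above has roughly this shape for an OAuth-token context; the keys come from the code, the values here are placeholders:

example_signed_dict = {
    "version": 2,
    "entity_kind": "oauthtoken",       # string value of a ContextEntityKind member
    "entity_reference": "token-uuid",  # serialized by the kind's entity handler
    # Legacy fields, kept until the new format is fully deployed (see the TODOs above):
    "kind": "oauth",
    "oauth": "token-uuid",
    "user": "someuser",
}
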
diff --git a/auth/basic.py b/auth/basic.py
index 926450ad6..49d0150a4 100644
--- a/auth/basic.py
+++ b/auth/basic.py
@@ -8,51 +8,54 @@ from auth.validateresult import ValidateResult, AuthKind
logger = logging.getLogger(__name__)
+
def has_basic_auth(username):
- """ Returns true if a basic auth header exists with a username and password pair that validates
+ """ Returns true if a basic auth header exists with a username and password pair that validates
against the internal authentication system. Returns True on full success and False on any
failure (missing header, invalid header, invalid credentials, etc).
"""
- auth_header = request.headers.get('authorization', '')
- result = validate_basic_auth(auth_header)
- return result.has_nonrobot_user and result.context.user.username == username
+ auth_header = request.headers.get("authorization", "")
+ result = validate_basic_auth(auth_header)
+ return result.has_nonrobot_user and result.context.user.username == username
def validate_basic_auth(auth_header):
- """ Validates the specified basic auth header, returning whether its credentials point
+ """ Validates the specified basic auth header, returning whether its credentials point
to a valid user or token.
"""
- if not auth_header:
- return ValidateResult(AuthKind.basic, missing=True)
+ if not auth_header:
+ return ValidateResult(AuthKind.basic, missing=True)
- logger.debug('Attempt to process basic auth header')
+ logger.debug("Attempt to process basic auth header")
- # Parse the basic auth header.
- assert isinstance(auth_header, basestring)
- credentials, err = _parse_basic_auth_header(auth_header)
- if err is not None:
- logger.debug('Got invalid basic auth header: %s', auth_header)
- return ValidateResult(AuthKind.basic, missing=True)
+ # Parse the basic auth header.
+ assert isinstance(auth_header, basestring)
+ credentials, err = _parse_basic_auth_header(auth_header)
+ if err is not None:
+ logger.debug("Got invalid basic auth header: %s", auth_header)
+ return ValidateResult(AuthKind.basic, missing=True)
- auth_username, auth_password_or_token = credentials
- result, _ = validate_credentials(auth_username, auth_password_or_token)
- return result.with_kind(AuthKind.basic)
+ auth_username, auth_password_or_token = credentials
+ result, _ = validate_credentials(auth_username, auth_password_or_token)
+ return result.with_kind(AuthKind.basic)
def _parse_basic_auth_header(auth):
- """ Parses the given basic auth header, returning the credentials found inside.
+ """ Parses the given basic auth header, returning the credentials found inside.
"""
- normalized = [part.strip() for part in auth.split(' ') if part]
- if normalized[0].lower() != 'basic' or len(normalized) != 2:
- return None, 'Invalid basic auth header'
+ normalized = [part.strip() for part in auth.split(" ") if part]
+ if normalized[0].lower() != "basic" or len(normalized) != 2:
+ return None, "Invalid basic auth header"
- try:
- credentials = [part.decode('utf-8') for part in b64decode(normalized[1]).split(':', 1)]
- except (TypeError, UnicodeDecodeError, ValueError):
- logger.exception('Exception when parsing basic auth header: %s', auth)
- return None, 'Could not parse basic auth header'
+ try:
+ credentials = [
+ part.decode("utf-8") for part in b64decode(normalized[1]).split(":", 1)
+ ]
+ except (TypeError, UnicodeDecodeError, ValueError):
+ logger.exception("Exception when parsing basic auth header: %s", auth)
+ return None, "Could not parse basic auth header"
- if len(credentials) != 2:
- return None, 'Unexpected number of credentials found in basic auth header'
+ if len(credentials) != 2:
+ return None, "Unexpected number of credentials found in basic auth header"
- return credentials, None
+ return credentials, None
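
A standalone sketch of the parsing done by _parse_basic_auth_header() above, outside Flask (the header value is hypothetical; not part of the patch):

from base64 import b64decode, b64encode

auth_header = "Basic " + b64encode(b"someuser:somepassword").decode("ascii")

parts = [part.strip() for part in auth_header.split(" ") if part]
assert parts[0].lower() == "basic" and len(parts) == 2

username, password = [p.decode("utf-8") for p in b64decode(parts[1]).split(b":", 1)]
print(username, password)  # -> someuser somepassword
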
diff --git a/auth/context_entity.py b/auth/context_entity.py
index 038624b0c..7c52dbe8d 100644
--- a/auth/context_entity.py
+++ b/auth/context_entity.py
@@ -4,200 +4,210 @@ from enum import Enum
from data import model
-from auth.credential_consts import (ACCESS_TOKEN_USERNAME, OAUTH_TOKEN_USERNAME,
- APP_SPECIFIC_TOKEN_USERNAME)
+from auth.credential_consts import (
+ ACCESS_TOKEN_USERNAME,
+ OAUTH_TOKEN_USERNAME,
+ APP_SPECIFIC_TOKEN_USERNAME,
+)
+
class ContextEntityKind(Enum):
- """ Defines the various kinds of entities in an auth context. Note that the string values of
+ """ Defines the various kinds of entities in an auth context. Note that the string values of
these fields *must* match the names of the fields in the ValidatedAuthContext class, as
we fill them in directly based on the string names here.
"""
- anonymous = 'anonymous'
- user = 'user'
- robot = 'robot'
- token = 'token'
- oauthtoken = 'oauthtoken'
- appspecifictoken = 'appspecifictoken'
- signed_data = 'signed_data'
+
+ anonymous = "anonymous"
+ user = "user"
+ robot = "robot"
+ token = "token"
+ oauthtoken = "oauthtoken"
+ appspecifictoken = "appspecifictoken"
+ signed_data = "signed_data"
@add_metaclass(ABCMeta)
class ContextEntityHandler(object):
- """
+ """
Interface that represents handling specific kinds of entities under an auth context.
"""
- @abstractmethod
- def credential_username(self, entity_reference):
- """ Returns the username to create credentials for this entity, if any. """
- pass
+ @abstractmethod
+ def credential_username(self, entity_reference):
+ """ Returns the username to create credentials for this entity, if any. """
+ pass
- @abstractmethod
- def get_serialized_entity_reference(self, entity_reference):
- """ Returns the entity reference for this kind of auth context, serialized into a form that can
+ @abstractmethod
+ def get_serialized_entity_reference(self, entity_reference):
+ """ Returns the entity reference for this kind of auth context, serialized into a form that can
be placed into a JSON object and put into a JWT. This is typically a DB UUID or another
unique identifier for the object in the DB.
"""
- pass
+ pass
- @abstractmethod
- def deserialize_entity_reference(self, serialized_entity_reference):
- """ Returns the deserialized reference to the entity in the database, or None if none. """
- pass
+ @abstractmethod
+ def deserialize_entity_reference(self, serialized_entity_reference):
+ """ Returns the deserialized reference to the entity in the database, or None if none. """
+ pass
- @abstractmethod
- def description(self, entity_reference):
- """ Returns a human-readable and *public* description of the current entity. """
- pass
+ @abstractmethod
+ def description(self, entity_reference):
+ """ Returns a human-readable and *public* description of the current entity. """
+ pass
- @abstractmethod
- def analytics_id_and_public_metadata(self, entity_reference):
- """ Returns the analyitics ID and a dict of public metadata for the current entity. """
- pass
+ @abstractmethod
+ def analytics_id_and_public_metadata(self, entity_reference):
+ """ Returns the analyitics ID and a dict of public metadata for the current entity. """
+ pass
class AnonymousEntityHandler(ContextEntityHandler):
- def credential_username(self, entity_reference):
- return None
+ def credential_username(self, entity_reference):
+ return None
- def get_serialized_entity_reference(self, entity_reference):
- return None
+ def get_serialized_entity_reference(self, entity_reference):
+ return None
- def deserialize_entity_reference(self, serialized_entity_reference):
- return None
+ def deserialize_entity_reference(self, serialized_entity_reference):
+ return None
- def description(self, entity_reference):
- return "anonymous"
+ def description(self, entity_reference):
+ return "anonymous"
- def analytics_id_and_public_metadata(self, entity_reference):
- return "anonymous", {}
+ def analytics_id_and_public_metadata(self, entity_reference):
+ return "anonymous", {}
class UserEntityHandler(ContextEntityHandler):
- def credential_username(self, entity_reference):
- return entity_reference.username
+ def credential_username(self, entity_reference):
+ return entity_reference.username
- def get_serialized_entity_reference(self, entity_reference):
- return entity_reference.uuid
+ def get_serialized_entity_reference(self, entity_reference):
+ return entity_reference.uuid
- def deserialize_entity_reference(self, serialized_entity_reference):
- return model.user.get_user_by_uuid(serialized_entity_reference)
+ def deserialize_entity_reference(self, serialized_entity_reference):
+ return model.user.get_user_by_uuid(serialized_entity_reference)
- def description(self, entity_reference):
- return "user %s" % entity_reference.username
+ def description(self, entity_reference):
+ return "user %s" % entity_reference.username
- def analytics_id_and_public_metadata(self, entity_reference):
- return entity_reference.username, {
- 'username': entity_reference.username,
- }
+ def analytics_id_and_public_metadata(self, entity_reference):
+ return entity_reference.username, {"username": entity_reference.username}
class RobotEntityHandler(ContextEntityHandler):
- def credential_username(self, entity_reference):
- return entity_reference.username
+ def credential_username(self, entity_reference):
+ return entity_reference.username
- def get_serialized_entity_reference(self, entity_reference):
- return entity_reference.username
+ def get_serialized_entity_reference(self, entity_reference):
+ return entity_reference.username
- def deserialize_entity_reference(self, serialized_entity_reference):
- return model.user.lookup_robot(serialized_entity_reference)
+ def deserialize_entity_reference(self, serialized_entity_reference):
+ return model.user.lookup_robot(serialized_entity_reference)
- def description(self, entity_reference):
- return "robot %s" % entity_reference.username
+ def description(self, entity_reference):
+ return "robot %s" % entity_reference.username
- def analytics_id_and_public_metadata(self, entity_reference):
- return entity_reference.username, {
- 'username': entity_reference.username,
- 'is_robot': True,
- }
+ def analytics_id_and_public_metadata(self, entity_reference):
+ return (
+ entity_reference.username,
+ {"username": entity_reference.username, "is_robot": True},
+ )
class TokenEntityHandler(ContextEntityHandler):
- def credential_username(self, entity_reference):
- return ACCESS_TOKEN_USERNAME
+ def credential_username(self, entity_reference):
+ return ACCESS_TOKEN_USERNAME
- def get_serialized_entity_reference(self, entity_reference):
- return entity_reference.get_code()
+ def get_serialized_entity_reference(self, entity_reference):
+ return entity_reference.get_code()
- def deserialize_entity_reference(self, serialized_entity_reference):
- return model.token.load_token_data(serialized_entity_reference)
+ def deserialize_entity_reference(self, serialized_entity_reference):
+ return model.token.load_token_data(serialized_entity_reference)
- def description(self, entity_reference):
- return "token %s" % entity_reference.friendly_name
+ def description(self, entity_reference):
+ return "token %s" % entity_reference.friendly_name
- def analytics_id_and_public_metadata(self, entity_reference):
- return 'token:%s' % entity_reference.id, {
- 'token': entity_reference.friendly_name,
- }
+ def analytics_id_and_public_metadata(self, entity_reference):
+ return (
+ "token:%s" % entity_reference.id,
+ {"token": entity_reference.friendly_name},
+ )
class OAuthTokenEntityHandler(ContextEntityHandler):
- def credential_username(self, entity_reference):
- return OAUTH_TOKEN_USERNAME
+ def credential_username(self, entity_reference):
+ return OAUTH_TOKEN_USERNAME
- def get_serialized_entity_reference(self, entity_reference):
- return entity_reference.uuid
+ def get_serialized_entity_reference(self, entity_reference):
+ return entity_reference.uuid
- def deserialize_entity_reference(self, serialized_entity_reference):
- return model.oauth.lookup_access_token_by_uuid(serialized_entity_reference)
+ def deserialize_entity_reference(self, serialized_entity_reference):
+ return model.oauth.lookup_access_token_by_uuid(serialized_entity_reference)
- def description(self, entity_reference):
- return "oauthtoken for user %s" % entity_reference.authorized_user.username
+ def description(self, entity_reference):
+ return "oauthtoken for user %s" % entity_reference.authorized_user.username
- def analytics_id_and_public_metadata(self, entity_reference):
- return 'oauthtoken:%s' % entity_reference.id, {
- 'oauth_token_id': entity_reference.id,
- 'oauth_token_application_id': entity_reference.application.client_id,
- 'oauth_token_application': entity_reference.application.name,
- 'username': entity_reference.authorized_user.username,
- }
+ def analytics_id_and_public_metadata(self, entity_reference):
+ return (
+ "oauthtoken:%s" % entity_reference.id,
+ {
+ "oauth_token_id": entity_reference.id,
+ "oauth_token_application_id": entity_reference.application.client_id,
+ "oauth_token_application": entity_reference.application.name,
+ "username": entity_reference.authorized_user.username,
+ },
+ )
class AppSpecificTokenEntityHandler(ContextEntityHandler):
- def credential_username(self, entity_reference):
- return APP_SPECIFIC_TOKEN_USERNAME
+ def credential_username(self, entity_reference):
+ return APP_SPECIFIC_TOKEN_USERNAME
- def get_serialized_entity_reference(self, entity_reference):
- return entity_reference.uuid
+ def get_serialized_entity_reference(self, entity_reference):
+ return entity_reference.uuid
- def deserialize_entity_reference(self, serialized_entity_reference):
- return model.appspecifictoken.get_token_by_uuid(serialized_entity_reference)
+ def deserialize_entity_reference(self, serialized_entity_reference):
+ return model.appspecifictoken.get_token_by_uuid(serialized_entity_reference)
- def description(self, entity_reference):
- tpl = (entity_reference.title, entity_reference.user.username)
- return "app specific token %s for user %s" % tpl
+ def description(self, entity_reference):
+ tpl = (entity_reference.title, entity_reference.user.username)
+ return "app specific token %s for user %s" % tpl
- def analytics_id_and_public_metadata(self, entity_reference):
- return 'appspecifictoken:%s' % entity_reference.id, {
- 'app_specific_token': entity_reference.uuid,
- 'app_specific_token_title': entity_reference.title,
- 'username': entity_reference.user.username,
- }
+ def analytics_id_and_public_metadata(self, entity_reference):
+ return (
+ "appspecifictoken:%s" % entity_reference.id,
+ {
+ "app_specific_token": entity_reference.uuid,
+ "app_specific_token_title": entity_reference.title,
+ "username": entity_reference.user.username,
+ },
+ )
class SignedDataEntityHandler(ContextEntityHandler):
- def credential_username(self, entity_reference):
- return None
+ def credential_username(self, entity_reference):
+ return None
- def get_serialized_entity_reference(self, entity_reference):
- raise NotImplementedError
+ def get_serialized_entity_reference(self, entity_reference):
+ raise NotImplementedError
- def deserialize_entity_reference(self, serialized_entity_reference):
- raise NotImplementedError
+ def deserialize_entity_reference(self, serialized_entity_reference):
+ raise NotImplementedError
- def description(self, entity_reference):
- return "signed"
+ def description(self, entity_reference):
+ return "signed"
- def analytics_id_and_public_metadata(self, entity_reference):
- return 'signed', {'signed': entity_reference}
+ def analytics_id_and_public_metadata(self, entity_reference):
+ return "signed", {"signed": entity_reference}
CONTEXT_ENTITY_HANDLERS = {
- ContextEntityKind.anonymous: AnonymousEntityHandler,
- ContextEntityKind.user: UserEntityHandler,
- ContextEntityKind.robot: RobotEntityHandler,
- ContextEntityKind.token: TokenEntityHandler,
- ContextEntityKind.oauthtoken: OAuthTokenEntityHandler,
- ContextEntityKind.appspecifictoken: AppSpecificTokenEntityHandler,
- ContextEntityKind.signed_data: SignedDataEntityHandler,
+ ContextEntityKind.anonymous: AnonymousEntityHandler,
+ ContextEntityKind.user: UserEntityHandler,
+ ContextEntityKind.robot: RobotEntityHandler,
+ ContextEntityKind.token: TokenEntityHandler,
+ ContextEntityKind.oauthtoken: OAuthTokenEntityHandler,
+ ContextEntityKind.appspecifictoken: AppSpecificTokenEntityHandler,
+ ContextEntityKind.signed_data: SignedDataEntityHandler,
}
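
Note: the handler registry above is a straightforward kind-to-class dispatch; callers look up the handler for a ContextEntityKind and delegate to an instance of it. Below is a minimal, self-contained sketch of the same pattern — FakeUser, Kind and the two handlers are stand-ins, not Quay's real models.

# Illustrative sketch only; names below are hypothetical stand-ins.
from collections import namedtuple
from enum import Enum

FakeUser = namedtuple("FakeUser", ["username", "uuid"])


class Kind(Enum):
    anonymous = "anonymous"
    user = "user"


class AnonymousHandler(object):
    def description(self, entity_reference):
        return "anonymous"


class UserHandler(object):
    def description(self, entity_reference):
        return "user %s" % entity_reference.username


HANDLERS = {Kind.anonymous: AnonymousHandler, Kind.user: UserHandler}


def describe(kind, entity_reference):
    # Look up the handler class for the entity kind and delegate to an instance.
    return HANDLERS[kind]().description(entity_reference)


assert describe(Kind.user, FakeUser("alice", "1234")) == "user alice"
assert describe(Kind.anonymous, None) == "anonymous"
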
diff --git a/auth/cookie.py b/auth/cookie.py
index 68ed0f8ee..839183f32 100644
--- a/auth/cookie.py
+++ b/auth/cookie.py
@@ -7,31 +7,40 @@ from auth.validateresult import AuthKind, ValidateResult
logger = logging.getLogger(__name__)
+
def validate_session_cookie(auth_header_unusued=None):
- """ Attempts to load a user from a session cookie. """
- if current_user.is_anonymous:
- return ValidateResult(AuthKind.cookie, missing=True)
+ """ Attempts to load a user from a session cookie. """
+ if current_user.is_anonymous:
+ return ValidateResult(AuthKind.cookie, missing=True)
- try:
- # Attempt to parse the user uuid to make sure the cookie has the right value type
- UUID(current_user.get_id())
- except ValueError:
- logger.debug('Got non-UUID for session cookie user: %s', current_user.get_id())
- return ValidateResult(AuthKind.cookie, error_message='Invalid session cookie format')
+ try:
+ # Attempt to parse the user uuid to make sure the cookie has the right value type
+ UUID(current_user.get_id())
+ except ValueError:
+ logger.debug("Got non-UUID for session cookie user: %s", current_user.get_id())
+ return ValidateResult(
+ AuthKind.cookie, error_message="Invalid session cookie format"
+ )
- logger.debug('Loading user from cookie: %s', current_user.get_id())
- db_user = current_user.db_user()
- if db_user is None:
- return ValidateResult(AuthKind.cookie, error_message='Could not find matching user')
+ logger.debug("Loading user from cookie: %s", current_user.get_id())
+ db_user = current_user.db_user()
+ if db_user is None:
+ return ValidateResult(
+ AuthKind.cookie, error_message="Could not find matching user"
+ )
- # Don't allow disabled users to login.
- if not db_user.enabled:
- logger.debug('User %s in session cookie is disabled', db_user.username)
- return ValidateResult(AuthKind.cookie, error_message='User account is disabled')
+ # Don't allow disabled users to login.
+ if not db_user.enabled:
+ logger.debug("User %s in session cookie is disabled", db_user.username)
+ return ValidateResult(AuthKind.cookie, error_message="User account is disabled")
- # Don't allow organizations to "login".
- if db_user.organization:
- logger.debug('User %s in session cookie is in-fact organization', db_user.username)
- return ValidateResult(AuthKind.cookie, error_message='Cannot login to organization')
+ # Don't allow organizations to "login".
+ if db_user.organization:
+ logger.debug(
+ "User %s in session cookie is in-fact organization", db_user.username
+ )
+ return ValidateResult(
+ AuthKind.cookie, error_message="Cannot login to organization"
+ )
- return ValidateResult(AuthKind.cookie, user=db_user)
+ return ValidateResult(AuthKind.cookie, user=db_user)
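
Note: the cookie validator above is a short ladder of checks, each returning early with a ValidateResult. A rough, self-contained sketch of that ladder follows, with a plain dict standing in for flask_login's current_user and a stub Result standing in for ValidateResult.

from uuid import UUID


class Result(object):
    # Stub standing in for auth.validateresult.ValidateResult.
    def __init__(self, missing=False, error=None, user=None):
        self.missing, self.error, self.user = missing, error, user


def check_cookie(current_user):
    if current_user is None:
        return Result(missing=True)      # anonymous session: nothing to validate
    try:
        UUID(current_user["id"])         # the cookie must carry a user UUID
    except ValueError:
        return Result(error="Invalid session cookie format")
    if not current_user["enabled"]:
        return Result(error="User account is disabled")
    if current_user["organization"]:
        return Result(error="Cannot login to organization")
    return Result(user=current_user)


ok = check_cookie({"id": "12345678-1234-1234-1234-123456789012", "enabled": True, "organization": False})
assert ok.user is not None and ok.error is None
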
diff --git a/auth/credential_consts.py b/auth/credential_consts.py
index dda9834d1..93287d833 100644
--- a/auth/credential_consts.py
+++ b/auth/credential_consts.py
@@ -1,3 +1,3 @@
-ACCESS_TOKEN_USERNAME = '$token'
-OAUTH_TOKEN_USERNAME = '$oauthtoken'
-APP_SPECIFIC_TOKEN_USERNAME = '$app'
+ACCESS_TOKEN_USERNAME = "$token"
+OAUTH_TOKEN_USERNAME = "$oauthtoken"
+APP_SPECIFIC_TOKEN_USERNAME = "$app"
diff --git a/auth/credentials.py b/auth/credentials.py
index 5d8c8b4dd..f56f6a540 100644
--- a/auth/credentials.py
+++ b/auth/credentials.py
@@ -7,8 +7,11 @@ import features
from app import authentication
from auth.oauth import validate_oauth_token
from auth.validateresult import ValidateResult, AuthKind
-from auth.credential_consts import (ACCESS_TOKEN_USERNAME, OAUTH_TOKEN_USERNAME,
- APP_SPECIFIC_TOKEN_USERNAME)
+from auth.credential_consts import (
+ ACCESS_TOKEN_USERNAME,
+ OAUTH_TOKEN_USERNAME,
+ APP_SPECIFIC_TOKEN_USERNAME,
+)
from data import model
from util.names import parse_robot_username
@@ -16,70 +19,116 @@ logger = logging.getLogger(__name__)
class CredentialKind(Enum):
- user = 'user'
- robot = 'robot'
- token = ACCESS_TOKEN_USERNAME
- oauth_token = OAUTH_TOKEN_USERNAME
- app_specific_token = APP_SPECIFIC_TOKEN_USERNAME
+ user = "user"
+ robot = "robot"
+ token = ACCESS_TOKEN_USERNAME
+ oauth_token = OAUTH_TOKEN_USERNAME
+ app_specific_token = APP_SPECIFIC_TOKEN_USERNAME
def validate_credentials(auth_username, auth_password_or_token):
- """ Validates a pair of auth username and password/token credentials. """
- # Check for access tokens.
- if auth_username == ACCESS_TOKEN_USERNAME:
- logger.debug('Found credentials for access token')
- try:
- token = model.token.load_token_data(auth_password_or_token)
- logger.debug('Successfully validated credentials for access token %s', token.id)
- return ValidateResult(AuthKind.credentials, token=token), CredentialKind.token
- except model.DataModelException:
- logger.warning('Failed to validate credentials for access token %s', auth_password_or_token)
- return (ValidateResult(AuthKind.credentials, error_message='Invalid access token'),
- CredentialKind.token)
+ """ Validates a pair of auth username and password/token credentials. """
+ # Check for access tokens.
+ if auth_username == ACCESS_TOKEN_USERNAME:
+ logger.debug("Found credentials for access token")
+ try:
+ token = model.token.load_token_data(auth_password_or_token)
+ logger.debug(
+ "Successfully validated credentials for access token %s", token.id
+ )
+ return (
+ ValidateResult(AuthKind.credentials, token=token),
+ CredentialKind.token,
+ )
+ except model.DataModelException:
+ logger.warning(
+ "Failed to validate credentials for access token %s",
+ auth_password_or_token,
+ )
+ return (
+ ValidateResult(
+ AuthKind.credentials, error_message="Invalid access token"
+ ),
+ CredentialKind.token,
+ )
- # Check for App Specific tokens.
- if features.APP_SPECIFIC_TOKENS and auth_username == APP_SPECIFIC_TOKEN_USERNAME:
- logger.debug('Found credentials for app specific auth token')
- token = model.appspecifictoken.access_valid_token(auth_password_or_token)
- if token is None:
- logger.debug('Failed to validate credentials for app specific token: %s',
- auth_password_or_token)
- return (ValidateResult(AuthKind.credentials, error_message='Invalid token'),
- CredentialKind.app_specific_token)
+ # Check for App Specific tokens.
+ if features.APP_SPECIFIC_TOKENS and auth_username == APP_SPECIFIC_TOKEN_USERNAME:
+ logger.debug("Found credentials for app specific auth token")
+ token = model.appspecifictoken.access_valid_token(auth_password_or_token)
+ if token is None:
+ logger.debug(
+ "Failed to validate credentials for app specific token: %s",
+ auth_password_or_token,
+ )
+ return (
+ ValidateResult(AuthKind.credentials, error_message="Invalid token"),
+ CredentialKind.app_specific_token,
+ )
- if not token.user.enabled:
- logger.debug('Tried to use an app specific token for a disabled user: %s',
- token.uuid)
- return (ValidateResult(AuthKind.credentials,
- error_message='This user has been disabled. Please contact your administrator.'),
- CredentialKind.app_specific_token)
+ if not token.user.enabled:
+ logger.debug(
+ "Tried to use an app specific token for a disabled user: %s", token.uuid
+ )
+ return (
+ ValidateResult(
+ AuthKind.credentials,
+ error_message="This user has been disabled. Please contact your administrator.",
+ ),
+ CredentialKind.app_specific_token,
+ )
- logger.debug('Successfully validated credentials for app specific token %s', token.id)
- return (ValidateResult(AuthKind.credentials, appspecifictoken=token),
- CredentialKind.app_specific_token)
+ logger.debug(
+ "Successfully validated credentials for app specific token %s", token.id
+ )
+ return (
+ ValidateResult(AuthKind.credentials, appspecifictoken=token),
+ CredentialKind.app_specific_token,
+ )
- # Check for OAuth tokens.
- if auth_username == OAUTH_TOKEN_USERNAME:
- return validate_oauth_token(auth_password_or_token), CredentialKind.oauth_token
+ # Check for OAuth tokens.
+ if auth_username == OAUTH_TOKEN_USERNAME:
+ return validate_oauth_token(auth_password_or_token), CredentialKind.oauth_token
- # Check for robots and users.
- is_robot = parse_robot_username(auth_username)
- if is_robot:
- logger.debug('Found credentials header for robot %s', auth_username)
- try:
- robot = model.user.verify_robot(auth_username, auth_password_or_token)
- logger.debug('Successfully validated credentials for robot %s', auth_username)
- return ValidateResult(AuthKind.credentials, robot=robot), CredentialKind.robot
- except model.InvalidRobotException as ire:
- logger.warning('Failed to validate credentials for robot %s: %s', auth_username, ire)
- return ValidateResult(AuthKind.credentials, error_message=str(ire)), CredentialKind.robot
+ # Check for robots and users.
+ is_robot = parse_robot_username(auth_username)
+ if is_robot:
+ logger.debug("Found credentials header for robot %s", auth_username)
+ try:
+ robot = model.user.verify_robot(auth_username, auth_password_or_token)
+ logger.debug(
+ "Successfully validated credentials for robot %s", auth_username
+ )
+ return (
+ ValidateResult(AuthKind.credentials, robot=robot),
+ CredentialKind.robot,
+ )
+ except model.InvalidRobotException as ire:
+ logger.warning(
+ "Failed to validate credentials for robot %s: %s", auth_username, ire
+ )
+ return (
+ ValidateResult(AuthKind.credentials, error_message=str(ire)),
+ CredentialKind.robot,
+ )
- # Otherwise, treat as a standard user.
- (authenticated, err) = authentication.verify_and_link_user(auth_username, auth_password_or_token,
- basic_auth=True)
- if authenticated:
- logger.debug('Successfully validated credentials for user %s', authenticated.username)
- return ValidateResult(AuthKind.credentials, user=authenticated), CredentialKind.user
- else:
- logger.warning('Failed to validate credentials for user %s: %s', auth_username, err)
- return ValidateResult(AuthKind.credentials, error_message=err), CredentialKind.user
+ # Otherwise, treat as a standard user.
+ (authenticated, err) = authentication.verify_and_link_user(
+ auth_username, auth_password_or_token, basic_auth=True
+ )
+ if authenticated:
+ logger.debug(
+ "Successfully validated credentials for user %s", authenticated.username
+ )
+ return (
+ ValidateResult(AuthKind.credentials, user=authenticated),
+ CredentialKind.user,
+ )
+ else:
+ logger.warning(
+ "Failed to validate credentials for user %s: %s", auth_username, err
+ )
+ return (
+ ValidateResult(AuthKind.credentials, error_message=err),
+ CredentialKind.user,
+ )
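
Note: validate_credentials above branches on sentinel usernames before falling back to robot and ordinary-user checks. A condensed sketch of just that classification step follows; the constants match credential_consts.py, but the "+" check is a simplification of parse_robot_username, and the real function also loads and verifies the matching token, robot or user.

ACCESS_TOKEN_USERNAME = "$token"
OAUTH_TOKEN_USERNAME = "$oauthtoken"
APP_SPECIFIC_TOKEN_USERNAME = "$app"


def classify_credentials(auth_username):
    """ Return which kind of credential a username/password pair would be treated as. """
    if auth_username == ACCESS_TOKEN_USERNAME:
        return "token"
    if auth_username == APP_SPECIFIC_TOKEN_USERNAME:
        return "app_specific_token"
    if auth_username == OAUTH_TOKEN_USERNAME:
        return "oauth_token"
    if "+" in auth_username:   # robot accounts are written namespace+shortname
        return "robot"
    return "user"


assert classify_credentials("$oauthtoken") == "oauth_token"
assert classify_credentials("acme+builder") == "robot"
assert classify_credentials("alice") == "user"
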
diff --git a/auth/decorators.py b/auth/decorators.py
index 5fc966140..6e5a0cf05 100644
--- a/auth/decorators.py
+++ b/auth/decorators.py
@@ -14,83 +14,101 @@ from util.http import abort
logger = logging.getLogger(__name__)
+
def _auth_decorator(pass_result=False, handlers=None):
- """ Builds an auth decorator that runs the given handlers and, if any return successfully,
+ """ Builds an auth decorator that runs the given handlers and, if any return successfully,
sets up the auth context. The wrapped function will be invoked *regardless of success or
failure of the auth handler(s)*
"""
- def processor(func):
- @wraps(func)
- def wrapper(*args, **kwargs):
- auth_header = request.headers.get('authorization', '')
- result = None
- for handler in handlers:
- result = handler(auth_header)
- # If the handler was missing the necessary information, skip it and try the next one.
- if result.missing:
- continue
+ def processor(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ auth_header = request.headers.get("authorization", "")
+ result = None
- # Check for a valid result.
- if result.auth_valid:
- logger.debug('Found valid auth result: %s', result.tuple())
+ for handler in handlers:
+ result = handler(auth_header)
+ # If the handler was missing the necessary information, skip it and try the next one.
+ if result.missing:
+ continue
- # Set the various pieces of the auth context.
- result.apply_to_context()
+ # Check for a valid result.
+ if result.auth_valid:
+ logger.debug("Found valid auth result: %s", result.tuple())
- # Log the metric.
- metric_queue.authentication_count.Inc(labelvalues=[result.kind, True])
- break
+ # Set the various pieces of the auth context.
+ result.apply_to_context()
- # Otherwise, report the error.
- if result.error_message is not None:
- # Log the failure.
- metric_queue.authentication_count.Inc(labelvalues=[result.kind, False])
- break
+ # Log the metric.
+ metric_queue.authentication_count.Inc(
+ labelvalues=[result.kind, True]
+ )
+ break
- if pass_result:
- kwargs['auth_result'] = result
+ # Otherwise, report the error.
+ if result.error_message is not None:
+ # Log the failure.
+ metric_queue.authentication_count.Inc(
+ labelvalues=[result.kind, False]
+ )
+ break
- return func(*args, **kwargs)
- return wrapper
- return processor
+ if pass_result:
+ kwargs["auth_result"] = result
+
+ return func(*args, **kwargs)
+
+ return wrapper
+
+ return processor
-process_oauth = _auth_decorator(handlers=[validate_bearer_auth, validate_session_cookie])
+process_oauth = _auth_decorator(
+ handlers=[validate_bearer_auth, validate_session_cookie]
+)
process_auth = _auth_decorator(handlers=[validate_signed_grant, validate_basic_auth])
-process_auth_or_cookie = _auth_decorator(handlers=[validate_basic_auth, validate_session_cookie])
+process_auth_or_cookie = _auth_decorator(
+ handlers=[validate_basic_auth, validate_session_cookie]
+)
process_basic_auth = _auth_decorator(handlers=[validate_basic_auth], pass_result=True)
process_basic_auth_no_pass = _auth_decorator(handlers=[validate_basic_auth])
def require_session_login(func):
- """ Decorates a function and ensures that a valid session cookie exists or a 401 is raised. If
+ """ Decorates a function and ensures that a valid session cookie exists or a 401 is raised. If
a valid session cookie does exist, the authenticated user and identity are also set.
"""
- @wraps(func)
- def wrapper(*args, **kwargs):
- result = validate_session_cookie()
- if result.has_nonrobot_user:
- result.apply_to_context()
- metric_queue.authentication_count.Inc(labelvalues=[result.kind, True])
- return func(*args, **kwargs)
- elif not result.missing:
- metric_queue.authentication_count.Inc(labelvalues=[result.kind, False])
- abort(401, message='Method requires login and no valid login could be loaded.')
- return wrapper
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ result = validate_session_cookie()
+ if result.has_nonrobot_user:
+ result.apply_to_context()
+ metric_queue.authentication_count.Inc(labelvalues=[result.kind, True])
+ return func(*args, **kwargs)
+ elif not result.missing:
+ metric_queue.authentication_count.Inc(labelvalues=[result.kind, False])
+
+ abort(401, message="Method requires login and no valid login could be loaded.")
+
+ return wrapper
def extract_namespace_repo_from_session(func):
- """ Extracts the namespace and repository name from the current session (which must exist)
+ """ Extracts the namespace and repository name from the current session (which must exist)
and passes them into the decorated function as the first and second arguments. If the
      session doesn't exist or does not contain these arguments, a 400 error is raised.
"""
- @wraps(func)
- def wrapper(*args, **kwargs):
- if 'namespace' not in session or 'repository' not in session:
- logger.error('Unable to load namespace or repository from session: %s', session)
- abort(400, message='Missing namespace in request')
- return func(session['namespace'], session['repository'], *args, **kwargs)
- return wrapper
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ if "namespace" not in session or "repository" not in session:
+ logger.error(
+ "Unable to load namespace or repository from session: %s", session
+ )
+ abort(400, message="Missing namespace in request")
+
+ return func(session["namespace"], session["repository"], *args, **kwargs)
+
+ return wrapper
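
Note: the _auth_decorator factory above walks an ordered list of handlers, skipping any that report "missing" and stopping at the first definitive result (valid or errored); the wrapped view then runs regardless of the auth outcome. A minimal sketch of that pattern without Flask, metrics or real handlers:

from functools import wraps


class Result(object):
    # Stub result: real handlers return a ValidateResult.
    def __init__(self, missing=False, valid=False):
        self.missing, self.valid = missing, valid


def auth_decorator(handlers):
    def processor(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for handler in handlers:
                result = handler()
                if result.missing:
                    continue      # handler had nothing to act on; try the next one
                break             # valid or errored: either way, stop here
            return func(*args, **kwargs)   # the view runs regardless of auth outcome
        return wrapper
    return processor


@auth_decorator(handlers=[lambda: Result(missing=True), lambda: Result(valid=True)])
def view():
    return "ok"


assert view() == "ok"
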
diff --git a/auth/oauth.py b/auth/oauth.py
index aaea92831..b9f3ca7bd 100644
--- a/auth/oauth.py
+++ b/auth/oauth.py
@@ -8,41 +8,47 @@ from data import model
logger = logging.getLogger(__name__)
+
def validate_bearer_auth(auth_header):
- """ Validates an OAuth token found inside a basic auth `Bearer` token, returning whether it
+ """ Validates an OAuth token found inside a basic auth `Bearer` token, returning whether it
points to a valid OAuth token.
"""
- if not auth_header:
- return ValidateResult(AuthKind.oauth, missing=True)
+ if not auth_header:
+ return ValidateResult(AuthKind.oauth, missing=True)
- normalized = [part.strip() for part in auth_header.split(' ') if part]
- if normalized[0].lower() != 'bearer' or len(normalized) != 2:
- logger.debug('Got invalid bearer token format: %s', auth_header)
- return ValidateResult(AuthKind.oauth, missing=True)
+ normalized = [part.strip() for part in auth_header.split(" ") if part]
+ if normalized[0].lower() != "bearer" or len(normalized) != 2:
+ logger.debug("Got invalid bearer token format: %s", auth_header)
+ return ValidateResult(AuthKind.oauth, missing=True)
- (_, oauth_token) = normalized
- return validate_oauth_token(oauth_token)
+ (_, oauth_token) = normalized
+ return validate_oauth_token(oauth_token)
def validate_oauth_token(token):
- """ Validates the specified OAuth token, returning whether it points to a valid OAuth token.
+ """ Validates the specified OAuth token, returning whether it points to a valid OAuth token.
"""
- validated = model.oauth.validate_access_token(token)
- if not validated:
- logger.warning('OAuth access token could not be validated: %s', token)
- return ValidateResult(AuthKind.oauth,
- error_message='OAuth access token could not be validated')
+ validated = model.oauth.validate_access_token(token)
+ if not validated:
+ logger.warning("OAuth access token could not be validated: %s", token)
+ return ValidateResult(
+ AuthKind.oauth, error_message="OAuth access token could not be validated"
+ )
- if validated.expires_at <= datetime.utcnow():
- logger.warning('OAuth access with an expired token: %s', token)
- return ValidateResult(AuthKind.oauth, error_message='OAuth access token has expired')
+ if validated.expires_at <= datetime.utcnow():
+ logger.warning("OAuth access with an expired token: %s", token)
+ return ValidateResult(
+ AuthKind.oauth, error_message="OAuth access token has expired"
+ )
- # Don't allow disabled users to login.
- if not validated.authorized_user.enabled:
- return ValidateResult(AuthKind.oauth,
- error_message='Granter of the oauth access token is disabled')
+ # Don't allow disabled users to login.
+ if not validated.authorized_user.enabled:
+ return ValidateResult(
+ AuthKind.oauth,
+ error_message="Granter of the oauth access token is disabled",
+ )
- # We have a valid token
- scope_set = scopes_from_scope_string(validated.scope)
- logger.debug('Successfully validated oauth access token with scope: %s', scope_set)
- return ValidateResult(AuthKind.oauth, oauthtoken=validated)
+ # We have a valid token
+ scope_set = scopes_from_scope_string(validated.scope)
+ logger.debug("Successfully validated oauth access token with scope: %s", scope_set)
+ return ValidateResult(AuthKind.oauth, oauthtoken=validated)
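
Note: validate_bearer_auth first has to split an Authorization header of the form "Bearer <token>" before the token itself is checked against the database. A self-contained sketch of just that parsing step (the function name here is hypothetical):

def parse_bearer_token(auth_header):
    """ Return the token from an 'Authorization: Bearer <token>' header, or None. """
    if not auth_header:
        return None
    parts = [part.strip() for part in auth_header.split(" ") if part]
    if len(parts) != 2 or parts[0].lower() != "bearer":
        return None
    return parts[1]


assert parse_bearer_token("Bearer abc123") == "abc123"
assert parse_bearer_token("Basic dXNlcjpwYXNz") is None
assert parse_bearer_token("") is None
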
diff --git a/auth/permissions.py b/auth/permissions.py
index c967aa046..10419acbc 100644
--- a/auth/permissions.py
+++ b/auth/permissions.py
@@ -14,351 +14,399 @@ from data import model
logger = logging.getLogger(__name__)
-_ResourceNeed = namedtuple('resource', ['type', 'namespace', 'name', 'role'])
-_RepositoryNeed = partial(_ResourceNeed, 'repository')
-_NamespaceWideNeed = namedtuple('namespacewide', ['type', 'namespace', 'role'])
-_OrganizationNeed = partial(_NamespaceWideNeed, 'organization')
-_OrganizationRepoNeed = partial(_NamespaceWideNeed, 'organizationrepo')
-_TeamTypeNeed = namedtuple('teamwideneed', ['type', 'orgname', 'teamname', 'role'])
-_TeamNeed = partial(_TeamTypeNeed, 'orgteam')
-_UserTypeNeed = namedtuple('userspecificneed', ['type', 'username', 'role'])
-_UserNeed = partial(_UserTypeNeed, 'user')
-_SuperUserNeed = partial(namedtuple('superuserneed', ['type']), 'superuser')
+_ResourceNeed = namedtuple("resource", ["type", "namespace", "name", "role"])
+_RepositoryNeed = partial(_ResourceNeed, "repository")
+_NamespaceWideNeed = namedtuple("namespacewide", ["type", "namespace", "role"])
+_OrganizationNeed = partial(_NamespaceWideNeed, "organization")
+_OrganizationRepoNeed = partial(_NamespaceWideNeed, "organizationrepo")
+_TeamTypeNeed = namedtuple("teamwideneed", ["type", "orgname", "teamname", "role"])
+_TeamNeed = partial(_TeamTypeNeed, "orgteam")
+_UserTypeNeed = namedtuple("userspecificneed", ["type", "username", "role"])
+_UserNeed = partial(_UserTypeNeed, "user")
+_SuperUserNeed = partial(namedtuple("superuserneed", ["type"]), "superuser")
-REPO_ROLES = [None, 'read', 'write', 'admin']
-TEAM_ROLES = [None, 'member', 'creator', 'admin']
-USER_ROLES = [None, 'read', 'admin']
+REPO_ROLES = [None, "read", "write", "admin"]
+TEAM_ROLES = [None, "member", "creator", "admin"]
+USER_ROLES = [None, "read", "admin"]
-TEAM_ORGWIDE_REPO_ROLES = {
- 'admin': 'admin',
- 'creator': None,
- 'member': None,
-}
+TEAM_ORGWIDE_REPO_ROLES = {"admin": "admin", "creator": None, "member": None}
SCOPE_MAX_REPO_ROLES = defaultdict(lambda: None)
-SCOPE_MAX_REPO_ROLES.update({
- scopes.READ_REPO: 'read',
- scopes.WRITE_REPO: 'write',
- scopes.ADMIN_REPO: 'admin',
- scopes.DIRECT_LOGIN: 'admin',
-})
+SCOPE_MAX_REPO_ROLES.update(
+ {
+ scopes.READ_REPO: "read",
+ scopes.WRITE_REPO: "write",
+ scopes.ADMIN_REPO: "admin",
+ scopes.DIRECT_LOGIN: "admin",
+ }
+)
SCOPE_MAX_TEAM_ROLES = defaultdict(lambda: None)
-SCOPE_MAX_TEAM_ROLES.update({
- scopes.CREATE_REPO: 'creator',
- scopes.DIRECT_LOGIN: 'admin',
- scopes.ORG_ADMIN: 'admin',
-})
+SCOPE_MAX_TEAM_ROLES.update(
+ {
+ scopes.CREATE_REPO: "creator",
+ scopes.DIRECT_LOGIN: "admin",
+ scopes.ORG_ADMIN: "admin",
+ }
+)
SCOPE_MAX_USER_ROLES = defaultdict(lambda: None)
-SCOPE_MAX_USER_ROLES.update({
- scopes.READ_USER: 'read',
- scopes.DIRECT_LOGIN: 'admin',
- scopes.ADMIN_USER: 'admin',
-})
+SCOPE_MAX_USER_ROLES.update(
+ {scopes.READ_USER: "read", scopes.DIRECT_LOGIN: "admin", scopes.ADMIN_USER: "admin"}
+)
+
def repository_read_grant(namespace, repository):
- return _RepositoryNeed(namespace, repository, 'read')
+ return _RepositoryNeed(namespace, repository, "read")
def repository_write_grant(namespace, repository):
- return _RepositoryNeed(namespace, repository, 'write')
+ return _RepositoryNeed(namespace, repository, "write")
def repository_admin_grant(namespace, repository):
- return _RepositoryNeed(namespace, repository, 'admin')
+ return _RepositoryNeed(namespace, repository, "admin")
class QuayDeferredPermissionUser(Identity):
- def __init__(self, uuid, auth_type, auth_scopes, user=None):
- super(QuayDeferredPermissionUser, self).__init__(uuid, auth_type)
+ def __init__(self, uuid, auth_type, auth_scopes, user=None):
+ super(QuayDeferredPermissionUser, self).__init__(uuid, auth_type)
- self._namespace_wide_loaded = set()
- self._repositories_loaded = set()
- self._personal_loaded = False
+ self._namespace_wide_loaded = set()
+ self._repositories_loaded = set()
+ self._personal_loaded = False
- self._scope_set = auth_scopes
- self._user_object = user
+ self._scope_set = auth_scopes
+ self._user_object = user
- @staticmethod
- def for_id(uuid, auth_scopes=None):
- auth_scopes = auth_scopes if auth_scopes is not None else {scopes.DIRECT_LOGIN}
- return QuayDeferredPermissionUser(uuid, 'user_uuid', auth_scopes)
+ @staticmethod
+ def for_id(uuid, auth_scopes=None):
+ auth_scopes = auth_scopes if auth_scopes is not None else {scopes.DIRECT_LOGIN}
+ return QuayDeferredPermissionUser(uuid, "user_uuid", auth_scopes)
- @staticmethod
- def for_user(user, auth_scopes=None):
- auth_scopes = auth_scopes if auth_scopes is not None else {scopes.DIRECT_LOGIN}
- return QuayDeferredPermissionUser(user.uuid, 'user_uuid', auth_scopes, user=user)
+ @staticmethod
+ def for_user(user, auth_scopes=None):
+ auth_scopes = auth_scopes if auth_scopes is not None else {scopes.DIRECT_LOGIN}
+ return QuayDeferredPermissionUser(
+ user.uuid, "user_uuid", auth_scopes, user=user
+ )
- def _translate_role_for_scopes(self, cardinality, max_roles, role):
- if self._scope_set is None:
- return role
+ def _translate_role_for_scopes(self, cardinality, max_roles, role):
+ if self._scope_set is None:
+ return role
- max_for_scopes = max({cardinality.index(max_roles[scope]) for scope in self._scope_set})
+ max_for_scopes = max(
+ {cardinality.index(max_roles[scope]) for scope in self._scope_set}
+ )
- if max_for_scopes < cardinality.index(role):
- logger.debug('Translated permission %s -> %s', role, cardinality[max_for_scopes])
- return cardinality[max_for_scopes]
- else:
- return role
+ if max_for_scopes < cardinality.index(role):
+ logger.debug(
+ "Translated permission %s -> %s", role, cardinality[max_for_scopes]
+ )
+ return cardinality[max_for_scopes]
+ else:
+ return role
- def _team_role_for_scopes(self, role):
- return self._translate_role_for_scopes(TEAM_ROLES, SCOPE_MAX_TEAM_ROLES, role)
+ def _team_role_for_scopes(self, role):
+ return self._translate_role_for_scopes(TEAM_ROLES, SCOPE_MAX_TEAM_ROLES, role)
- def _repo_role_for_scopes(self, role):
- return self._translate_role_for_scopes(REPO_ROLES, SCOPE_MAX_REPO_ROLES, role)
+ def _repo_role_for_scopes(self, role):
+ return self._translate_role_for_scopes(REPO_ROLES, SCOPE_MAX_REPO_ROLES, role)
- def _user_role_for_scopes(self, role):
- return self._translate_role_for_scopes(USER_ROLES, SCOPE_MAX_USER_ROLES, role)
+ def _user_role_for_scopes(self, role):
+ return self._translate_role_for_scopes(USER_ROLES, SCOPE_MAX_USER_ROLES, role)
- def _populate_user_provides(self, user_object):
- """ Populates the provides that naturally apply to a user, such as being the admin of
+ def _populate_user_provides(self, user_object):
+ """ Populates the provides that naturally apply to a user, such as being the admin of
their own namespace.
"""
- # Add the user specific permissions, only for non-oauth permission
- user_grant = _UserNeed(user_object.username, self._user_role_for_scopes('admin'))
- logger.debug('User permission: {0}'.format(user_grant))
- self.provides.add(user_grant)
+ # Add the user specific permissions, only for non-oauth permission
+ user_grant = _UserNeed(
+ user_object.username, self._user_role_for_scopes("admin")
+ )
+ logger.debug("User permission: {0}".format(user_grant))
+ self.provides.add(user_grant)
- # Every user is the admin of their own 'org'
- user_namespace = _OrganizationNeed(user_object.username, self._team_role_for_scopes('admin'))
- logger.debug('User namespace permission: {0}'.format(user_namespace))
- self.provides.add(user_namespace)
+ # Every user is the admin of their own 'org'
+ user_namespace = _OrganizationNeed(
+ user_object.username, self._team_role_for_scopes("admin")
+ )
+ logger.debug("User namespace permission: {0}".format(user_namespace))
+ self.provides.add(user_namespace)
- # Org repo roles can differ for scopes
- user_repos = _OrganizationRepoNeed(user_object.username, self._repo_role_for_scopes('admin'))
- logger.debug('User namespace repo permission: {0}'.format(user_repos))
- self.provides.add(user_repos)
+ # Org repo roles can differ for scopes
+ user_repos = _OrganizationRepoNeed(
+ user_object.username, self._repo_role_for_scopes("admin")
+ )
+ logger.debug("User namespace repo permission: {0}".format(user_repos))
+ self.provides.add(user_repos)
- if ((scopes.SUPERUSER in self._scope_set or scopes.DIRECT_LOGIN in self._scope_set) and
- superusers.is_superuser(user_object.username)):
- logger.debug('Adding superuser to user: %s', user_object.username)
- self.provides.add(_SuperUserNeed())
+ if (
+ scopes.SUPERUSER in self._scope_set
+ or scopes.DIRECT_LOGIN in self._scope_set
+ ) and superusers.is_superuser(user_object.username):
+ logger.debug("Adding superuser to user: %s", user_object.username)
+ self.provides.add(_SuperUserNeed())
- def _populate_namespace_wide_provides(self, user_object, namespace_filter):
- """ Populates the namespace-wide provides for a particular user under a particular namespace.
+ def _populate_namespace_wide_provides(self, user_object, namespace_filter):
+ """ Populates the namespace-wide provides for a particular user under a particular namespace.
This method does *not* add any provides for specific repositories.
"""
- for team in model.permission.get_org_wide_permissions(user_object, org_filter=namespace_filter):
- team_org_grant = _OrganizationNeed(team.organization.username,
- self._team_role_for_scopes(team.role.name))
- logger.debug('Organization team added permission: {0}'.format(team_org_grant))
- self.provides.add(team_org_grant)
+ for team in model.permission.get_org_wide_permissions(
+ user_object, org_filter=namespace_filter
+ ):
+ team_org_grant = _OrganizationNeed(
+ team.organization.username, self._team_role_for_scopes(team.role.name)
+ )
+ logger.debug(
+ "Organization team added permission: {0}".format(team_org_grant)
+ )
+ self.provides.add(team_org_grant)
- team_repo_role = TEAM_ORGWIDE_REPO_ROLES[team.role.name]
- org_repo_grant = _OrganizationRepoNeed(team.organization.username,
- self._repo_role_for_scopes(team_repo_role))
- logger.debug('Organization team added repo permission: {0}'.format(org_repo_grant))
- self.provides.add(org_repo_grant)
+ team_repo_role = TEAM_ORGWIDE_REPO_ROLES[team.role.name]
+ org_repo_grant = _OrganizationRepoNeed(
+ team.organization.username, self._repo_role_for_scopes(team_repo_role)
+ )
+ logger.debug(
+ "Organization team added repo permission: {0}".format(org_repo_grant)
+ )
+ self.provides.add(org_repo_grant)
- team_grant = _TeamNeed(team.organization.username, team.name,
- self._team_role_for_scopes(team.role.name))
- logger.debug('Team added permission: {0}'.format(team_grant))
- self.provides.add(team_grant)
+ team_grant = _TeamNeed(
+ team.organization.username,
+ team.name,
+ self._team_role_for_scopes(team.role.name),
+ )
+ logger.debug("Team added permission: {0}".format(team_grant))
+ self.provides.add(team_grant)
- def _populate_repository_provides(self, user_object, namespace_filter, repository_name):
- """ Populates the repository-specific provides for a particular user and repository. """
+ def _populate_repository_provides(
+ self, user_object, namespace_filter, repository_name
+ ):
+ """ Populates the repository-specific provides for a particular user and repository. """
- if namespace_filter and repository_name:
- permissions = model.permission.get_user_repository_permissions(user_object, namespace_filter,
- repository_name)
- else:
- permissions = model.permission.get_all_user_repository_permissions(user_object)
+ if namespace_filter and repository_name:
+ permissions = model.permission.get_user_repository_permissions(
+ user_object, namespace_filter, repository_name
+ )
+ else:
+ permissions = model.permission.get_all_user_repository_permissions(
+ user_object
+ )
- for perm in permissions:
- repo_grant = _RepositoryNeed(perm.repository.namespace_user.username, perm.repository.name,
- self._repo_role_for_scopes(perm.role.name))
- logger.debug('User added permission: {0}'.format(repo_grant))
- self.provides.add(repo_grant)
+ for perm in permissions:
+ repo_grant = _RepositoryNeed(
+ perm.repository.namespace_user.username,
+ perm.repository.name,
+ self._repo_role_for_scopes(perm.role.name),
+ )
+ logger.debug("User added permission: {0}".format(repo_grant))
+ self.provides.add(repo_grant)
- def can(self, permission):
- logger.debug('Loading user permissions after deferring for: %s', self.id)
- user_object = self._user_object or model.user.get_user_by_uuid(self.id)
- if user_object is None:
- return super(QuayDeferredPermissionUser, self).can(permission)
+ def can(self, permission):
+ logger.debug("Loading user permissions after deferring for: %s", self.id)
+ user_object = self._user_object or model.user.get_user_by_uuid(self.id)
+ if user_object is None:
+ return super(QuayDeferredPermissionUser, self).can(permission)
- # Add the user-specific provides.
- if not self._personal_loaded:
- self._populate_user_provides(user_object)
- self._personal_loaded = True
+ # Add the user-specific provides.
+ if not self._personal_loaded:
+ self._populate_user_provides(user_object)
+ self._personal_loaded = True
- # If we now have permission, no need to load any more permissions.
- if super(QuayDeferredPermissionUser, self).can(permission):
- return super(QuayDeferredPermissionUser, self).can(permission)
+ # If we now have permission, no need to load any more permissions.
+ if super(QuayDeferredPermissionUser, self).can(permission):
+ return super(QuayDeferredPermissionUser, self).can(permission)
- # Check for namespace and/or repository permissions.
- perm_namespace = permission.namespace
- perm_repo_name = permission.repo_name
- perm_repository = None
+ # Check for namespace and/or repository permissions.
+ perm_namespace = permission.namespace
+ perm_repo_name = permission.repo_name
+ perm_repository = None
- if perm_namespace and perm_repo_name:
- perm_repository = '%s/%s' % (perm_namespace, perm_repo_name)
+ if perm_namespace and perm_repo_name:
+ perm_repository = "%s/%s" % (perm_namespace, perm_repo_name)
- if not perm_namespace and not perm_repo_name:
- # Nothing more to load, so just check directly.
- return super(QuayDeferredPermissionUser, self).can(permission)
+ if not perm_namespace and not perm_repo_name:
+ # Nothing more to load, so just check directly.
+ return super(QuayDeferredPermissionUser, self).can(permission)
- # Lazy-load the repository-specific permissions.
- if perm_repository and perm_repository not in self._repositories_loaded:
- self._populate_repository_provides(user_object, perm_namespace, perm_repo_name)
- self._repositories_loaded.add(perm_repository)
+ # Lazy-load the repository-specific permissions.
+ if perm_repository and perm_repository not in self._repositories_loaded:
+ self._populate_repository_provides(
+ user_object, perm_namespace, perm_repo_name
+ )
+ self._repositories_loaded.add(perm_repository)
+
+ # If we now have permission, no need to load any more permissions.
+ if super(QuayDeferredPermissionUser, self).can(permission):
+ return super(QuayDeferredPermissionUser, self).can(permission)
+
+ # Lazy-load the namespace-wide-only permissions.
+ if perm_namespace and perm_namespace not in self._namespace_wide_loaded:
+ self._populate_namespace_wide_provides(user_object, perm_namespace)
+ self._namespace_wide_loaded.add(perm_namespace)
- # If we now have permission, no need to load any more permissions.
- if super(QuayDeferredPermissionUser, self).can(permission):
return super(QuayDeferredPermissionUser, self).can(permission)
- # Lazy-load the namespace-wide-only permissions.
- if perm_namespace and perm_namespace not in self._namespace_wide_loaded:
- self._populate_namespace_wide_provides(user_object, perm_namespace)
- self._namespace_wide_loaded.add(perm_namespace)
-
- return super(QuayDeferredPermissionUser, self).can(permission)
-
class QuayPermission(Permission):
- """ Base for all permissions in Quay. """
- namespace = None
- repo_name = None
+ """ Base for all permissions in Quay. """
+
+ namespace = None
+ repo_name = None
class ModifyRepositoryPermission(QuayPermission):
- def __init__(self, namespace, name):
- admin_need = _RepositoryNeed(namespace, name, 'admin')
- write_need = _RepositoryNeed(namespace, name, 'write')
- org_admin_need = _OrganizationRepoNeed(namespace, 'admin')
- org_write_need = _OrganizationRepoNeed(namespace, 'write')
+ def __init__(self, namespace, name):
+ admin_need = _RepositoryNeed(namespace, name, "admin")
+ write_need = _RepositoryNeed(namespace, name, "write")
+ org_admin_need = _OrganizationRepoNeed(namespace, "admin")
+ org_write_need = _OrganizationRepoNeed(namespace, "write")
- self.namespace = namespace
- self.repo_name = name
+ self.namespace = namespace
+ self.repo_name = name
- super(ModifyRepositoryPermission, self).__init__(admin_need, write_need, org_admin_need,
- org_write_need)
+ super(ModifyRepositoryPermission, self).__init__(
+ admin_need, write_need, org_admin_need, org_write_need
+ )
class ReadRepositoryPermission(QuayPermission):
- def __init__(self, namespace, name):
- admin_need = _RepositoryNeed(namespace, name, 'admin')
- write_need = _RepositoryNeed(namespace, name, 'write')
- read_need = _RepositoryNeed(namespace, name, 'read')
- org_admin_need = _OrganizationRepoNeed(namespace, 'admin')
- org_write_need = _OrganizationRepoNeed(namespace, 'write')
- org_read_need = _OrganizationRepoNeed(namespace, 'read')
+ def __init__(self, namespace, name):
+ admin_need = _RepositoryNeed(namespace, name, "admin")
+ write_need = _RepositoryNeed(namespace, name, "write")
+ read_need = _RepositoryNeed(namespace, name, "read")
+ org_admin_need = _OrganizationRepoNeed(namespace, "admin")
+ org_write_need = _OrganizationRepoNeed(namespace, "write")
+ org_read_need = _OrganizationRepoNeed(namespace, "read")
- self.namespace = namespace
- self.repo_name = name
+ self.namespace = namespace
+ self.repo_name = name
- super(ReadRepositoryPermission, self).__init__(admin_need, write_need, read_need,
- org_admin_need, org_read_need, org_write_need)
+ super(ReadRepositoryPermission, self).__init__(
+ admin_need,
+ write_need,
+ read_need,
+ org_admin_need,
+ org_read_need,
+ org_write_need,
+ )
class AdministerRepositoryPermission(QuayPermission):
- def __init__(self, namespace, name):
- admin_need = _RepositoryNeed(namespace, name, 'admin')
- org_admin_need = _OrganizationRepoNeed(namespace, 'admin')
+ def __init__(self, namespace, name):
+ admin_need = _RepositoryNeed(namespace, name, "admin")
+ org_admin_need = _OrganizationRepoNeed(namespace, "admin")
- self.namespace = namespace
- self.repo_name = name
+ self.namespace = namespace
+ self.repo_name = name
- super(AdministerRepositoryPermission, self).__init__(admin_need,
- org_admin_need)
+ super(AdministerRepositoryPermission, self).__init__(admin_need, org_admin_need)
class CreateRepositoryPermission(QuayPermission):
- def __init__(self, namespace):
- admin_org = _OrganizationNeed(namespace, 'admin')
- create_repo_org = _OrganizationNeed(namespace, 'creator')
+ def __init__(self, namespace):
+ admin_org = _OrganizationNeed(namespace, "admin")
+ create_repo_org = _OrganizationNeed(namespace, "creator")
- self.namespace = namespace
+ self.namespace = namespace
+
+ super(CreateRepositoryPermission, self).__init__(admin_org, create_repo_org)
- super(CreateRepositoryPermission, self).__init__(admin_org,
- create_repo_org)
class SuperUserPermission(QuayPermission):
- def __init__(self):
- need = _SuperUserNeed()
- super(SuperUserPermission, self).__init__(need)
+ def __init__(self):
+ need = _SuperUserNeed()
+ super(SuperUserPermission, self).__init__(need)
class UserAdminPermission(QuayPermission):
- def __init__(self, username):
- user_admin = _UserNeed(username, 'admin')
- super(UserAdminPermission, self).__init__(user_admin)
+ def __init__(self, username):
+ user_admin = _UserNeed(username, "admin")
+ super(UserAdminPermission, self).__init__(user_admin)
class UserReadPermission(QuayPermission):
- def __init__(self, username):
- user_admin = _UserNeed(username, 'admin')
- user_read = _UserNeed(username, 'read')
- super(UserReadPermission, self).__init__(user_read, user_admin)
+ def __init__(self, username):
+ user_admin = _UserNeed(username, "admin")
+ user_read = _UserNeed(username, "read")
+ super(UserReadPermission, self).__init__(user_read, user_admin)
class AdministerOrganizationPermission(QuayPermission):
- def __init__(self, org_name):
- admin_org = _OrganizationNeed(org_name, 'admin')
+ def __init__(self, org_name):
+ admin_org = _OrganizationNeed(org_name, "admin")
- self.namespace = org_name
+ self.namespace = org_name
- super(AdministerOrganizationPermission, self).__init__(admin_org)
+ super(AdministerOrganizationPermission, self).__init__(admin_org)
class OrganizationMemberPermission(QuayPermission):
- def __init__(self, org_name):
- admin_org = _OrganizationNeed(org_name, 'admin')
- repo_creator_org = _OrganizationNeed(org_name, 'creator')
- org_member = _OrganizationNeed(org_name, 'member')
+ def __init__(self, org_name):
+ admin_org = _OrganizationNeed(org_name, "admin")
+ repo_creator_org = _OrganizationNeed(org_name, "creator")
+ org_member = _OrganizationNeed(org_name, "member")
- self.namespace = org_name
+ self.namespace = org_name
- super(OrganizationMemberPermission, self).__init__(admin_org, org_member,
- repo_creator_org)
+ super(OrganizationMemberPermission, self).__init__(
+ admin_org, org_member, repo_creator_org
+ )
class ViewTeamPermission(QuayPermission):
- def __init__(self, org_name, team_name):
- team_admin = _TeamNeed(org_name, team_name, 'admin')
- team_creator = _TeamNeed(org_name, team_name, 'creator')
- team_member = _TeamNeed(org_name, team_name, 'member')
- admin_org = _OrganizationNeed(org_name, 'admin')
+ def __init__(self, org_name, team_name):
+ team_admin = _TeamNeed(org_name, team_name, "admin")
+ team_creator = _TeamNeed(org_name, team_name, "creator")
+ team_member = _TeamNeed(org_name, team_name, "member")
+ admin_org = _OrganizationNeed(org_name, "admin")
- self.namespace = org_name
+ self.namespace = org_name
- super(ViewTeamPermission, self).__init__(team_admin, team_creator,
- team_member, admin_org)
+ super(ViewTeamPermission, self).__init__(
+ team_admin, team_creator, team_member, admin_org
+ )
class AlwaysFailPermission(QuayPermission):
- def can(self):
- return False
+ def can(self):
+ return False
@identity_loaded.connect_via(app)
def on_identity_loaded(sender, identity):
- logger.debug('Identity loaded: %s' % identity)
- # We have verified an identity, load in all of the permissions
+ logger.debug("Identity loaded: %s" % identity)
+ # We have verified an identity, load in all of the permissions
- if isinstance(identity, QuayDeferredPermissionUser):
- logger.debug('Deferring permissions for user with uuid: %s', identity.id)
+ if isinstance(identity, QuayDeferredPermissionUser):
+ logger.debug("Deferring permissions for user with uuid: %s", identity.id)
- elif identity.auth_type == 'user_uuid':
- logger.debug('Switching username permission to deferred object with uuid: %s', identity.id)
- switch_to_deferred = QuayDeferredPermissionUser.for_id(identity.id)
- identity_changed.send(app, identity=switch_to_deferred)
+ elif identity.auth_type == "user_uuid":
+ logger.debug(
+ "Switching username permission to deferred object with uuid: %s",
+ identity.id,
+ )
+ switch_to_deferred = QuayDeferredPermissionUser.for_id(identity.id)
+ identity_changed.send(app, identity=switch_to_deferred)
- elif identity.auth_type == 'token':
- logger.debug('Loading permissions for token: %s', identity.id)
- token_data = model.token.load_token_data(identity.id)
+ elif identity.auth_type == "token":
+ logger.debug("Loading permissions for token: %s", identity.id)
+ token_data = model.token.load_token_data(identity.id)
- repo_grant = _RepositoryNeed(token_data.repository.namespace_user.username,
- token_data.repository.name,
- token_data.role.name)
- logger.debug('Delegate token added permission: %s', repo_grant)
- identity.provides.add(repo_grant)
+ repo_grant = _RepositoryNeed(
+ token_data.repository.namespace_user.username,
+ token_data.repository.name,
+ token_data.role.name,
+ )
+ logger.debug("Delegate token added permission: %s", repo_grant)
+ identity.provides.add(repo_grant)
- elif identity.auth_type == 'signed_grant' or identity.auth_type == 'signed_jwt':
- logger.debug('Loaded %s identity for: %s', identity.auth_type, identity.id)
+ elif identity.auth_type == "signed_grant" or identity.auth_type == "signed_jwt":
+ logger.debug("Loaded %s identity for: %s", identity.auth_type, identity.id)
- else:
- logger.error('Unknown identity auth type: %s', identity.auth_type)
+ else:
+ logger.error("Unknown identity auth type: %s", identity.auth_type)
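
Note: the key piece of QuayDeferredPermissionUser is _translate_role_for_scopes, which clamps a requested role to the strongest role that any granted OAuth scope permits. A small sketch of that clamping, with plain strings standing in for the scope objects and the repository role set above:

REPO_ROLES = [None, "read", "write", "admin"]          # ordered weakest to strongest
SCOPE_MAX = {"repo:read": "read", "repo:write": "write", "repo:admin": "admin"}


def cap_role(granted_scopes, requested_role):
    # Unknown scopes map to None (index 0, no access), mirroring the defaultdicts above.
    ceiling = max(REPO_ROLES.index(SCOPE_MAX.get(scope)) for scope in granted_scopes)
    requested = REPO_ROLES.index(requested_role)
    return REPO_ROLES[min(ceiling, requested)]


# A token holding only repo:read can never act as a repository admin.
assert cap_role({"repo:read"}, "admin") == "read"
assert cap_role({"repo:read", "repo:write"}, "write") == "write"
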
diff --git a/auth/registry_jwt_auth.py b/auth/registry_jwt_auth.py
index 75be63d73..135e49a94 100644
--- a/auth/registry_jwt_auth.py
+++ b/auth/registry_jwt_auth.py
@@ -9,156 +9,166 @@ from flask_principal import identity_changed, Identity
from app import app, get_app_url, instance_keys, metric_queue
from auth.auth_context import set_authenticated_context
from auth.auth_context_type import SignedAuthContext
-from auth.permissions import repository_read_grant, repository_write_grant, repository_admin_grant
+from auth.permissions import (
+ repository_read_grant,
+ repository_write_grant,
+ repository_admin_grant,
+)
from util.http import abort
from util.names import parse_namespace_repository
-from util.security.registry_jwt import (ANONYMOUS_SUB, decode_bearer_header,
- InvalidBearerTokenException)
+from util.security.registry_jwt import (
+ ANONYMOUS_SUB,
+ decode_bearer_header,
+ InvalidBearerTokenException,
+)
logger = logging.getLogger(__name__)
ACCESS_SCHEMA = {
- 'type': 'array',
- 'description': 'List of access granted to the subject',
- 'items': {
- 'type': 'object',
- 'required': [
- 'type',
- 'name',
- 'actions',
- ],
- 'properties': {
- 'type': {
- 'type': 'string',
- 'description': 'We only allow repository permissions',
- 'enum': [
- 'repository',
- ],
- },
- 'name': {
- 'type': 'string',
- 'description': 'The name of the repository for which we are receiving access'
- },
- 'actions': {
- 'type': 'array',
- 'description': 'List of specific verbs which can be performed against repository',
- 'items': {
- 'type': 'string',
- 'enum': [
- 'push',
- 'pull',
- '*',
- ],
+ "type": "array",
+ "description": "List of access granted to the subject",
+ "items": {
+ "type": "object",
+ "required": ["type", "name", "actions"],
+ "properties": {
+ "type": {
+ "type": "string",
+ "description": "We only allow repository permissions",
+ "enum": ["repository"],
+ },
+ "name": {
+ "type": "string",
+ "description": "The name of the repository for which we are receiving access",
+ },
+ "actions": {
+ "type": "array",
+ "description": "List of specific verbs which can be performed against repository",
+ "items": {"type": "string", "enum": ["push", "pull", "*"]},
+ },
},
- },
},
- },
}
class InvalidJWTException(Exception):
- pass
+ pass
def get_auth_headers(repository=None, scopes=None):
- """ Returns a dictionary of headers for auth responses. """
- headers = {}
- realm_auth_path = url_for('v2.generate_registry_jwt')
- authenticate = 'Bearer realm="{0}{1}",service="{2}"'.format(get_app_url(),
- realm_auth_path,
- app.config['SERVER_HOSTNAME'])
- if repository:
- scopes_string = "repository:{0}".format(repository)
- if scopes:
- scopes_string += ':' + ','.join(scopes)
+ """ Returns a dictionary of headers for auth responses. """
+ headers = {}
+ realm_auth_path = url_for("v2.generate_registry_jwt")
+ authenticate = 'Bearer realm="{0}{1}",service="{2}"'.format(
+ get_app_url(), realm_auth_path, app.config["SERVER_HOSTNAME"]
+ )
+ if repository:
+ scopes_string = "repository:{0}".format(repository)
+ if scopes:
+ scopes_string += ":" + ",".join(scopes)
- authenticate += ',scope="{0}"'.format(scopes_string)
+ authenticate += ',scope="{0}"'.format(scopes_string)
- headers['WWW-Authenticate'] = authenticate
- headers['Docker-Distribution-API-Version'] = 'registry/2.0'
- return headers
+ headers["WWW-Authenticate"] = authenticate
+ headers["Docker-Distribution-API-Version"] = "registry/2.0"
+ return headers
def identity_from_bearer_token(bearer_header):
- """ Process a bearer header and return the loaded identity, or raise InvalidJWTException if an
+ """ Process a bearer header and return the loaded identity, or raise InvalidJWTException if an
identity could not be loaded. Expects tokens and grants in the format of the Docker registry
v2 auth spec: https://docs.docker.com/registry/spec/auth/token/
"""
- logger.debug('Validating auth header: %s', bearer_header)
+ logger.debug("Validating auth header: %s", bearer_header)
- try:
- payload = decode_bearer_header(bearer_header, instance_keys, app.config,
- metric_queue=metric_queue)
- except InvalidBearerTokenException as bte:
- logger.exception('Invalid bearer token: %s', bte)
- raise InvalidJWTException(bte)
-
- loaded_identity = Identity(payload['sub'], 'signed_jwt')
-
- # Process the grants from the payload
- if 'access' in payload:
try:
- validate(payload['access'], ACCESS_SCHEMA)
- except ValidationError:
- logger.exception('We should not be minting invalid credentials')
- raise InvalidJWTException('Token contained invalid or malformed access grants')
+ payload = decode_bearer_header(
+ bearer_header, instance_keys, app.config, metric_queue=metric_queue
+ )
+ except InvalidBearerTokenException as bte:
+ logger.exception("Invalid bearer token: %s", bte)
+ raise InvalidJWTException(bte)
- lib_namespace = app.config['LIBRARY_NAMESPACE']
- for grant in payload['access']:
- namespace, repo_name = parse_namespace_repository(grant['name'], lib_namespace)
+ loaded_identity = Identity(payload["sub"], "signed_jwt")
- if '*' in grant['actions']:
- loaded_identity.provides.add(repository_admin_grant(namespace, repo_name))
- elif 'push' in grant['actions']:
- loaded_identity.provides.add(repository_write_grant(namespace, repo_name))
- elif 'pull' in grant['actions']:
- loaded_identity.provides.add(repository_read_grant(namespace, repo_name))
+ # Process the grants from the payload
+ if "access" in payload:
+ try:
+ validate(payload["access"], ACCESS_SCHEMA)
+ except ValidationError:
+ logger.exception("We should not be minting invalid credentials")
+ raise InvalidJWTException(
+ "Token contained invalid or malformed access grants"
+ )
- default_context = {
- 'kind': 'anonymous'
- }
+ lib_namespace = app.config["LIBRARY_NAMESPACE"]
+ for grant in payload["access"]:
+ namespace, repo_name = parse_namespace_repository(
+ grant["name"], lib_namespace
+ )
- if payload['sub'] != ANONYMOUS_SUB:
- default_context = {
- 'kind': 'user',
- 'user': payload['sub'],
- }
+ if "*" in grant["actions"]:
+ loaded_identity.provides.add(
+ repository_admin_grant(namespace, repo_name)
+ )
+ elif "push" in grant["actions"]:
+ loaded_identity.provides.add(
+ repository_write_grant(namespace, repo_name)
+ )
+ elif "pull" in grant["actions"]:
+ loaded_identity.provides.add(
+ repository_read_grant(namespace, repo_name)
+ )
- return loaded_identity, payload.get('context', default_context)
+ default_context = {"kind": "anonymous"}
+
+ if payload["sub"] != ANONYMOUS_SUB:
+ default_context = {"kind": "user", "user": payload["sub"]}
+
+ return loaded_identity, payload.get("context", default_context)
def process_registry_jwt_auth(scopes=None):
- """ Processes the registry JWT auth token found in the authorization header. If none found,
+ """ Processes the registry JWT auth token found in the authorization header. If none found,
no error is returned. If an invalid token is found, raises a 401.
"""
- def inner(func):
- @wraps(func)
- def wrapper(*args, **kwargs):
- logger.debug('Called with params: %s, %s', args, kwargs)
- auth = request.headers.get('authorization', '').strip()
- if auth:
- try:
- extracted_identity, context_dict = identity_from_bearer_token(auth)
- identity_changed.send(app, identity=extracted_identity)
- logger.debug('Identity changed to %s', extracted_identity.id)
- auth_context = SignedAuthContext.build_from_signed_dict(context_dict)
- if auth_context is not None:
- logger.debug('Auth context set to %s', auth_context.signed_data)
- set_authenticated_context(auth_context)
+ def inner(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ logger.debug("Called with params: %s, %s", args, kwargs)
+ auth = request.headers.get("authorization", "").strip()
+ if auth:
+ try:
+ extracted_identity, context_dict = identity_from_bearer_token(auth)
+ identity_changed.send(app, identity=extracted_identity)
+ logger.debug("Identity changed to %s", extracted_identity.id)
- except InvalidJWTException as ije:
- repository = None
- if 'namespace_name' in kwargs and 'repo_name' in kwargs:
- repository = kwargs['namespace_name'] + '/' + kwargs['repo_name']
+ auth_context = SignedAuthContext.build_from_signed_dict(
+ context_dict
+ )
+ if auth_context is not None:
+ logger.debug("Auth context set to %s", auth_context.signed_data)
+ set_authenticated_context(auth_context)
- abort(401, message=ije.message, headers=get_auth_headers(repository=repository,
- scopes=scopes))
- else:
- logger.debug('No auth header.')
+ except InvalidJWTException as ije:
+ repository = None
+ if "namespace_name" in kwargs and "repo_name" in kwargs:
+ repository = (
+ kwargs["namespace_name"] + "/" + kwargs["repo_name"]
+ )
- return func(*args, **kwargs)
- return wrapper
- return inner
+ abort(
+ 401,
+ message=ije.message,
+ headers=get_auth_headers(repository=repository, scopes=scopes),
+ )
+ else:
+ logger.debug("No auth header.")
+
+ return func(*args, **kwargs)
+
+ return wrapper
+
+ return inner
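For readers following the reformatted decorator above, a minimal usage sketch may help. It is illustrative only: the blueprint, route, and scope value are assumptions rather than part of this diff, the module path auth/registry_jwt_auth.py is assumed, and get_authenticated_context is assumed to be the read accessor paired with the set_authenticated_context call used above.

from flask import Blueprint, jsonify

from auth.auth_context import get_authenticated_context
from auth.registry_jwt_auth import process_registry_jwt_auth

v2_example = Blueprint("v2_example", __name__)


@v2_example.route("/v2/<namespace_name>/<repo_name>/tags/list")
@process_registry_jwt_auth(scopes=["pull"])
def list_tags(namespace_name, repo_name):
    # On a valid bearer token the decorator has already sent identity_changed
    # and stored the SignedAuthContext; on an invalid one it aborted with a 401
    # whose WWW-Authenticate challenge comes from get_auth_headers.
    context = get_authenticated_context()
    anonymous = context is None or context.is_anonymous
    return jsonify(
        {"namespace": namespace_name, "repository": repo_name, "anonymous": anonymous}
    )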
diff --git a/auth/scopes.py b/auth/scopes.py
index dbbb0ae1c..c16a3dbf3 100644
--- a/auth/scopes.py
+++ b/auth/scopes.py
@@ -2,145 +2,194 @@ from collections import namedtuple
import features
import re
-Scope = namedtuple('scope', ['scope', 'icon', 'dangerous', 'title', 'description'])
+Scope = namedtuple("scope", ["scope", "icon", "dangerous", "title", "description"])
-READ_REPO = Scope(scope='repo:read',
- icon='fa-hdd-o',
- dangerous=False,
- title='View all visible repositories',
- description=('This application will be able to view and pull all repositories '
- 'visible to the granting user or robot account'))
+READ_REPO = Scope(
+ scope="repo:read",
+ icon="fa-hdd-o",
+ dangerous=False,
+ title="View all visible repositories",
+ description=(
+ "This application will be able to view and pull all repositories "
+ "visible to the granting user or robot account"
+ ),
+)
-WRITE_REPO = Scope(scope='repo:write',
- icon='fa-hdd-o',
- dangerous=False,
- title='Read/Write to any accessible repositories',
- description=('This application will be able to view, push and pull to all '
- 'repositories to which the granting user or robot account has '
- 'write access'))
+WRITE_REPO = Scope(
+ scope="repo:write",
+ icon="fa-hdd-o",
+ dangerous=False,
+ title="Read/Write to any accessible repositories",
+ description=(
+ "This application will be able to view, push and pull to all "
+ "repositories to which the granting user or robot account has "
+ "write access"
+ ),
+)
-ADMIN_REPO = Scope(scope='repo:admin',
- icon='fa-hdd-o',
- dangerous=False,
- title='Administer Repositories',
- description=('This application will have administrator access to all '
- 'repositories to which the granting user or robot account has '
- 'access'))
+ADMIN_REPO = Scope(
+ scope="repo:admin",
+ icon="fa-hdd-o",
+ dangerous=False,
+ title="Administer Repositories",
+ description=(
+ "This application will have administrator access to all "
+ "repositories to which the granting user or robot account has "
+ "access"
+ ),
+)
-CREATE_REPO = Scope(scope='repo:create',
- icon='fa-plus',
- dangerous=False,
- title='Create Repositories',
- description=('This application will be able to create repositories in to any '
- 'namespaces that the granting user or robot account is allowed '
- 'to create repositories'))
+CREATE_REPO = Scope(
+ scope="repo:create",
+ icon="fa-plus",
+ dangerous=False,
+ title="Create Repositories",
+ description=(
+ "This application will be able to create repositories in to any "
+ "namespaces that the granting user or robot account is allowed "
+ "to create repositories"
+ ),
+)
-READ_USER = Scope(scope= 'user:read',
- icon='fa-user',
- dangerous=False,
- title='Read User Information',
- description=('This application will be able to read user information such as '
- 'username and email address.'))
+READ_USER = Scope(
+ scope="user:read",
+ icon="fa-user",
+ dangerous=False,
+ title="Read User Information",
+ description=(
+ "This application will be able to read user information such as "
+ "username and email address."
+ ),
+)
-ADMIN_USER = Scope(scope= 'user:admin',
- icon='fa-gear',
- dangerous=True,
- title='Administer User',
- description=('This application will be able to administer your account '
- 'including creating robots and granting them permissions '
- 'to your repositories. You should have absolute trust in the '
- 'requesting application before granting this permission.'))
+ADMIN_USER = Scope(
+ scope="user:admin",
+ icon="fa-gear",
+ dangerous=True,
+ title="Administer User",
+ description=(
+ "This application will be able to administer your account "
+ "including creating robots and granting them permissions "
+ "to your repositories. You should have absolute trust in the "
+ "requesting application before granting this permission."
+ ),
+)
-ORG_ADMIN = Scope(scope='org:admin',
- icon='fa-gear',
- dangerous=True,
- title='Administer Organization',
- description=('This application will be able to administer your organizations '
- 'including creating robots, creating teams, adjusting team '
- 'membership, and changing billing settings. You should have '
- 'absolute trust in the requesting application before granting this '
- 'permission.'))
+ORG_ADMIN = Scope(
+ scope="org:admin",
+ icon="fa-gear",
+ dangerous=True,
+ title="Administer Organization",
+ description=(
+ "This application will be able to administer your organizations "
+ "including creating robots, creating teams, adjusting team "
+ "membership, and changing billing settings. You should have "
+ "absolute trust in the requesting application before granting this "
+ "permission."
+ ),
+)
-DIRECT_LOGIN = Scope(scope='direct_user_login',
- icon='fa-exclamation-triangle',
- dangerous=True,
- title='Full Access',
- description=('This scope should not be available to OAuth applications. '
- 'Never approve a request for this scope!'))
+DIRECT_LOGIN = Scope(
+ scope="direct_user_login",
+ icon="fa-exclamation-triangle",
+ dangerous=True,
+ title="Full Access",
+ description=(
+ "This scope should not be available to OAuth applications. "
+ "Never approve a request for this scope!"
+ ),
+)
-SUPERUSER = Scope(scope='super:user',
- icon='fa-street-view',
- dangerous=True,
- title='Super User Access',
- description=('This application will be able to administer your installation '
- 'including managing users, managing organizations and other '
- 'features found in the superuser panel. You should have '
- 'absolute trust in the requesting application before granting this '
- 'permission.'))
+SUPERUSER = Scope(
+ scope="super:user",
+ icon="fa-street-view",
+ dangerous=True,
+ title="Super User Access",
+ description=(
+ "This application will be able to administer your installation "
+ "including managing users, managing organizations and other "
+ "features found in the superuser panel. You should have "
+ "absolute trust in the requesting application before granting this "
+ "permission."
+ ),
+)
-ALL_SCOPES = {scope.scope: scope for scope in (READ_REPO, WRITE_REPO, ADMIN_REPO, CREATE_REPO,
- READ_USER, ORG_ADMIN, SUPERUSER, ADMIN_USER)}
+ALL_SCOPES = {
+ scope.scope: scope
+ for scope in (
+ READ_REPO,
+ WRITE_REPO,
+ ADMIN_REPO,
+ CREATE_REPO,
+ READ_USER,
+ ORG_ADMIN,
+ SUPERUSER,
+ ADMIN_USER,
+ )
+}
IMPLIED_SCOPES = {
- ADMIN_REPO: {ADMIN_REPO, WRITE_REPO, READ_REPO},
- WRITE_REPO: {WRITE_REPO, READ_REPO},
- READ_REPO: {READ_REPO},
- CREATE_REPO: {CREATE_REPO},
- READ_USER: {READ_USER},
- ORG_ADMIN: {ORG_ADMIN},
- SUPERUSER: {SUPERUSER},
- ADMIN_USER: {ADMIN_USER},
- None: set(),
+ ADMIN_REPO: {ADMIN_REPO, WRITE_REPO, READ_REPO},
+ WRITE_REPO: {WRITE_REPO, READ_REPO},
+ READ_REPO: {READ_REPO},
+ CREATE_REPO: {CREATE_REPO},
+ READ_USER: {READ_USER},
+ ORG_ADMIN: {ORG_ADMIN},
+ SUPERUSER: {SUPERUSER},
+ ADMIN_USER: {ADMIN_USER},
+ None: set(),
}
def app_scopes(app_config):
- scopes_from_config = dict(ALL_SCOPES)
- if not app_config.get('FEATURE_SUPER_USERS', False):
- del scopes_from_config[SUPERUSER.scope]
- return scopes_from_config
+ scopes_from_config = dict(ALL_SCOPES)
+ if not app_config.get("FEATURE_SUPER_USERS", False):
+ del scopes_from_config[SUPERUSER.scope]
+ return scopes_from_config
def scopes_from_scope_string(scopes):
- if not scopes:
- scopes = ''
+ if not scopes:
+ scopes = ""
- # Note: The scopes string should be space seperated according to the spec:
- # https://tools.ietf.org/html/rfc6749#section-3.3
- # However, we also support commas for backwards compatibility with existing callers to our code.
- scope_set = {ALL_SCOPES.get(scope, None) for scope in re.split(' |,', scopes)}
- return scope_set if not None in scope_set else set()
+ # Note: The scopes string should be space separated according to the spec:
+ # https://tools.ietf.org/html/rfc6749#section-3.3
+ # However, we also support commas for backwards compatibility with existing callers of our code.
+ scope_set = {ALL_SCOPES.get(scope, None) for scope in re.split(" |,", scopes)}
+ return scope_set if None not in scope_set else set()
def validate_scope_string(scopes):
- decoded = scopes_from_scope_string(scopes)
- return len(decoded) > 0
+ decoded = scopes_from_scope_string(scopes)
+ return len(decoded) > 0
def is_subset_string(full_string, expected_string):
- """ Returns true if the scopes found in expected_string are also found
+ """ Returns true if the scopes found in expected_string are also found
in full_string.
"""
- full_scopes = scopes_from_scope_string(full_string)
- if not full_scopes:
- return False
+ full_scopes = scopes_from_scope_string(full_string)
+ if not full_scopes:
+ return False
- full_implied_scopes = set.union(*[IMPLIED_SCOPES[scope] for scope in full_scopes])
- expected_scopes = scopes_from_scope_string(expected_string)
- return expected_scopes.issubset(full_implied_scopes)
+ full_implied_scopes = set.union(*[IMPLIED_SCOPES[scope] for scope in full_scopes])
+ expected_scopes = scopes_from_scope_string(expected_string)
+ return expected_scopes.issubset(full_implied_scopes)
def get_scope_information(scopes_string):
- scopes = scopes_from_scope_string(scopes_string)
- scope_info = []
- for scope in scopes:
- scope_info.append({
- 'title': scope.title,
- 'scope': scope.scope,
- 'description': scope.description,
- 'icon': scope.icon,
- 'dangerous': scope.dangerous,
- })
+ scopes = scopes_from_scope_string(scopes_string)
+ scope_info = []
+ for scope in scopes:
+ scope_info.append(
+ {
+ "title": scope.title,
+ "scope": scope.scope,
+ "description": scope.description,
+ "icon": scope.icon,
+ "dangerous": scope.dangerous,
+ }
+ )
- return scope_info
+ return scope_info
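The scope helpers above are easiest to read with a concrete example. The scope strings below are made-up inputs; the expected behaviour follows directly from ALL_SCOPES, IMPLIED_SCOPES, scopes_from_scope_string, and is_subset_string as reformatted in this hunk.

from auth import scopes

# Space- and comma-separated strings are both accepted.
granted = scopes.scopes_from_scope_string("repo:admin user:read")
assert scopes.ADMIN_REPO in granted and scopes.READ_USER in granted

# Any unknown scope invalidates the whole set.
assert scopes.scopes_from_scope_string("repo:admin bogus:scope") == set()

# repo:admin implies repo:write and repo:read via IMPLIED_SCOPES, so a grant of
# "repo:admin" satisfies a caller that only expects "repo:read" (but not the
# other way around).
assert scopes.is_subset_string("repo:admin", "repo:read")
assert not scopes.is_subset_string("repo:read", "repo:write")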
diff --git a/auth/signedgrant.py b/auth/signedgrant.py
index b8169114d..4063115a0 100644
--- a/auth/signedgrant.py
+++ b/auth/signedgrant.py
@@ -8,48 +8,49 @@ from auth.validateresult import AuthKind, ValidateResult
logger = logging.getLogger(__name__)
# The prefix for all signatures of signed granted.
-SIGNATURE_PREFIX = 'sigv2='
+SIGNATURE_PREFIX = "sigv2="
+
def generate_signed_token(grants, user_context):
- """ Generates a signed session token with the given grants and user context. """
- ser = SecureCookieSessionInterface().get_signing_serializer(app)
- data_to_sign = {
- 'grants': grants,
- 'user_context': user_context,
- }
+ """ Generates a signed session token with the given grants and user context. """
+ ser = SecureCookieSessionInterface().get_signing_serializer(app)
+ data_to_sign = {"grants": grants, "user_context": user_context}
- encrypted = ser.dumps(data_to_sign)
- return '{0}{1}'.format(SIGNATURE_PREFIX, encrypted)
+ encrypted = ser.dumps(data_to_sign)
+ return "{0}{1}".format(SIGNATURE_PREFIX, encrypted)
def validate_signed_grant(auth_header):
- """ Validates a signed grant as found inside an auth header and returns whether it points to
+ """ Validates a signed grant as found inside an auth header and returns whether it points to
a valid grant.
"""
- if not auth_header:
- return ValidateResult(AuthKind.signed_grant, missing=True)
+ if not auth_header:
+ return ValidateResult(AuthKind.signed_grant, missing=True)
- # Try to parse the token from the header.
- normalized = [part.strip() for part in auth_header.split(' ') if part]
- if normalized[0].lower() != 'token' or len(normalized) != 2:
- logger.debug('Not a token: %s', auth_header)
- return ValidateResult(AuthKind.signed_grant, missing=True)
+ # Try to parse the token from the header.
+ normalized = [part.strip() for part in auth_header.split(" ") if part]
+ if normalized[0].lower() != "token" or len(normalized) != 2:
+ logger.debug("Not a token: %s", auth_header)
+ return ValidateResult(AuthKind.signed_grant, missing=True)
- # Check that it starts with the expected prefix.
- if not normalized[1].startswith(SIGNATURE_PREFIX):
- logger.debug('Not a signed grant token: %s', auth_header)
- return ValidateResult(AuthKind.signed_grant, missing=True)
+ # Check that it starts with the expected prefix.
+ if not normalized[1].startswith(SIGNATURE_PREFIX):
+ logger.debug("Not a signed grant token: %s", auth_header)
+ return ValidateResult(AuthKind.signed_grant, missing=True)
- # Decrypt the grant.
- encrypted = normalized[1][len(SIGNATURE_PREFIX):]
- ser = SecureCookieSessionInterface().get_signing_serializer(app)
+ # Decrypt the grant.
+ encrypted = normalized[1][len(SIGNATURE_PREFIX) :]
+ ser = SecureCookieSessionInterface().get_signing_serializer(app)
- try:
- token_data = ser.loads(encrypted, max_age=app.config['SIGNED_GRANT_EXPIRATION_SEC'])
- except BadSignature:
- logger.warning('Signed grant could not be validated: %s', encrypted)
- return ValidateResult(AuthKind.signed_grant,
- error_message='Signed grant could not be validated')
+ try:
+ token_data = ser.loads(
+ encrypted, max_age=app.config["SIGNED_GRANT_EXPIRATION_SEC"]
+ )
+ except BadSignature:
+ logger.warning("Signed grant could not be validated: %s", encrypted)
+ return ValidateResult(
+ AuthKind.signed_grant, error_message="Signed grant could not be validated"
+ )
- logger.debug('Successfully validated signed grant with data: %s', token_data)
- return ValidateResult(AuthKind.signed_grant, signed_data=token_data)
+ logger.debug("Successfully validated signed grant with data: %s", token_data)
+ return ValidateResult(AuthKind.signed_grant, signed_data=token_data)
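A short, hypothetical round trip through the signed-grant helpers above; it assumes an initialized app config (SECRET_KEY and SIGNED_GRANT_EXPIRATION_SEC set), and the grant payload is an example value only.

from auth.signedgrant import generate_signed_token, validate_signed_grant

signed = generate_signed_token(
    grants={"repository": "devtable/simple", "permission": "write"},
    user_context="devtable+dtrobot",
)

# The grant comes back as an Authorization header of the form
# "token sigv2=<signed payload>", which validate_signed_grant parses above.
result = validate_signed_grant("token " + signed)
assert not result.missing and result.error_message is None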
diff --git a/auth/test/test_auth_context_type.py b/auth/test/test_auth_context_type.py
index 7778d7f90..0b4e8227a 100644
--- a/auth/test/test_auth_context_type.py
+++ b/auth/test/test_auth_context_type.py
@@ -1,51 +1,65 @@
import pytest
-from auth.auth_context_type import SignedAuthContext, ValidatedAuthContext, ContextEntityKind
+from auth.auth_context_type import (
+ SignedAuthContext,
+ ValidatedAuthContext,
+ ContextEntityKind,
+)
from data import model, database
from test.fixtures import *
+
def get_oauth_token(_):
- return database.OAuthAccessToken.get()
+ return database.OAuthAccessToken.get()
-@pytest.mark.parametrize('kind, entity_reference, loader', [
- (ContextEntityKind.anonymous, None, None),
- (ContextEntityKind.appspecifictoken, '%s%s' % ('a' * 60, 'b' * 60),
- model.appspecifictoken.access_valid_token),
- (ContextEntityKind.oauthtoken, None, get_oauth_token),
- (ContextEntityKind.robot, 'devtable+dtrobot', model.user.lookup_robot),
- (ContextEntityKind.user, 'devtable', model.user.get_user),
-])
-@pytest.mark.parametrize('v1_dict_format', [
- (True),
- (False),
-])
-def test_signed_auth_context(kind, entity_reference, loader, v1_dict_format, initialized_db):
- if kind == ContextEntityKind.anonymous:
- validated = ValidatedAuthContext()
- assert validated.is_anonymous
- else:
- ref = loader(entity_reference)
- validated = ValidatedAuthContext(**{kind.value: ref})
- assert not validated.is_anonymous
+@pytest.mark.parametrize(
+ "kind, entity_reference, loader",
+ [
+ (ContextEntityKind.anonymous, None, None),
+ (
+ ContextEntityKind.appspecifictoken,
+ "%s%s" % ("a" * 60, "b" * 60),
+ model.appspecifictoken.access_valid_token,
+ ),
+ (ContextEntityKind.oauthtoken, None, get_oauth_token),
+ (ContextEntityKind.robot, "devtable+dtrobot", model.user.lookup_robot),
+ (ContextEntityKind.user, "devtable", model.user.get_user),
+ ],
+)
+@pytest.mark.parametrize("v1_dict_format", [(True), (False)])
+def test_signed_auth_context(
+ kind, entity_reference, loader, v1_dict_format, initialized_db
+):
+ if kind == ContextEntityKind.anonymous:
+ validated = ValidatedAuthContext()
+ assert validated.is_anonymous
+ else:
+ ref = loader(entity_reference)
+ validated = ValidatedAuthContext(**{kind.value: ref})
+ assert not validated.is_anonymous
- assert validated.entity_kind == kind
- assert validated.unique_key
+ assert validated.entity_kind == kind
+ assert validated.unique_key
- signed = SignedAuthContext.build_from_signed_dict(validated.to_signed_dict(),
- v1_dict_format=v1_dict_format)
+ signed = SignedAuthContext.build_from_signed_dict(
+ validated.to_signed_dict(), v1_dict_format=v1_dict_format
+ )
- if not v1_dict_format:
- # Under legacy V1 format, we don't track the app specific token, merely its associated user.
- assert signed.entity_kind == kind
- assert signed.description == validated.description
- assert signed.credential_username == validated.credential_username
- assert signed.analytics_id_and_public_metadata() == validated.analytics_id_and_public_metadata()
- assert signed.unique_key == validated.unique_key
+ if not v1_dict_format:
+ # Under legacy V1 format, we don't track the app specific token, merely its associated user.
+ assert signed.entity_kind == kind
+ assert signed.description == validated.description
+ assert signed.credential_username == validated.credential_username
+ assert (
+ signed.analytics_id_and_public_metadata()
+ == validated.analytics_id_and_public_metadata()
+ )
+ assert signed.unique_key == validated.unique_key
- assert signed.is_anonymous == validated.is_anonymous
- assert signed.authed_user == validated.authed_user
- assert signed.has_nonrobot_user == validated.has_nonrobot_user
+ assert signed.is_anonymous == validated.is_anonymous
+ assert signed.authed_user == validated.authed_user
+ assert signed.has_nonrobot_user == validated.has_nonrobot_user
- assert signed.to_signed_dict() == validated.to_signed_dict()
+ assert signed.to_signed_dict() == validated.to_signed_dict()
diff --git a/auth/test/test_basic.py b/auth/test/test_basic.py
index 24279b4b2..c7ecdc09c 100644
--- a/auth/test/test_basic.py
+++ b/auth/test/test_basic.py
@@ -5,8 +5,11 @@ import pytest
from base64 import b64encode
from auth.basic import validate_basic_auth
-from auth.credentials import (ACCESS_TOKEN_USERNAME, OAUTH_TOKEN_USERNAME,
- APP_SPECIFIC_TOKEN_USERNAME)
+from auth.credentials import (
+ ACCESS_TOKEN_USERNAME,
+ OAUTH_TOKEN_USERNAME,
+ APP_SPECIFIC_TOKEN_USERNAME,
+)
from auth.validateresult import AuthKind, ValidateResult
from data import model
@@ -14,85 +17,120 @@ from test.fixtures import *
def _token(username, password):
- assert isinstance(username, basestring)
- assert isinstance(password, basestring)
- return 'basic ' + b64encode('%s:%s' % (username, password))
+ assert isinstance(username, basestring)
+ assert isinstance(password, basestring)
+ return "basic " + b64encode("%s:%s" % (username, password))
-@pytest.mark.parametrize('token, expected_result', [
- ('', ValidateResult(AuthKind.basic, missing=True)),
- ('someinvalidtoken', ValidateResult(AuthKind.basic, missing=True)),
- ('somefoobartoken', ValidateResult(AuthKind.basic, missing=True)),
- ('basic ', ValidateResult(AuthKind.basic, missing=True)),
- ('basic some token', ValidateResult(AuthKind.basic, missing=True)),
- ('basic sometoken', ValidateResult(AuthKind.basic, missing=True)),
- (_token(APP_SPECIFIC_TOKEN_USERNAME, 'invalid'), ValidateResult(AuthKind.basic,
- error_message='Invalid token')),
- (_token(ACCESS_TOKEN_USERNAME, 'invalid'), ValidateResult(AuthKind.basic,
- error_message='Invalid access token')),
- (_token(OAUTH_TOKEN_USERNAME, 'invalid'),
- ValidateResult(AuthKind.basic, error_message='OAuth access token could not be validated')),
- (_token('devtable', 'invalid'), ValidateResult(AuthKind.basic,
- error_message='Invalid Username or Password')),
- (_token('devtable+somebot', 'invalid'), ValidateResult(
- AuthKind.basic, error_message='Could not find robot with username: devtable+somebot')),
- (_token('disabled', 'password'), ValidateResult(
- AuthKind.basic,
- error_message='This user has been disabled. Please contact your administrator.')),])
+@pytest.mark.parametrize(
+ "token, expected_result",
+ [
+ ("", ValidateResult(AuthKind.basic, missing=True)),
+ ("someinvalidtoken", ValidateResult(AuthKind.basic, missing=True)),
+ ("somefoobartoken", ValidateResult(AuthKind.basic, missing=True)),
+ ("basic ", ValidateResult(AuthKind.basic, missing=True)),
+ ("basic some token", ValidateResult(AuthKind.basic, missing=True)),
+ ("basic sometoken", ValidateResult(AuthKind.basic, missing=True)),
+ (
+ _token(APP_SPECIFIC_TOKEN_USERNAME, "invalid"),
+ ValidateResult(AuthKind.basic, error_message="Invalid token"),
+ ),
+ (
+ _token(ACCESS_TOKEN_USERNAME, "invalid"),
+ ValidateResult(AuthKind.basic, error_message="Invalid access token"),
+ ),
+ (
+ _token(OAUTH_TOKEN_USERNAME, "invalid"),
+ ValidateResult(
+ AuthKind.basic,
+ error_message="OAuth access token could not be validated",
+ ),
+ ),
+ (
+ _token("devtable", "invalid"),
+ ValidateResult(
+ AuthKind.basic, error_message="Invalid Username or Password"
+ ),
+ ),
+ (
+ _token("devtable+somebot", "invalid"),
+ ValidateResult(
+ AuthKind.basic,
+ error_message="Could not find robot with username: devtable+somebot",
+ ),
+ ),
+ (
+ _token("disabled", "password"),
+ ValidateResult(
+ AuthKind.basic,
+ error_message="This user has been disabled. Please contact your administrator.",
+ ),
+ ),
+ ],
+)
def test_validate_basic_auth_token(token, expected_result, app):
- result = validate_basic_auth(token)
- assert result == expected_result
+ result = validate_basic_auth(token)
+ assert result == expected_result
def test_valid_user(app):
- token = _token('devtable', 'password')
- result = validate_basic_auth(token)
- assert result == ValidateResult(AuthKind.basic, user=model.user.get_user('devtable'))
+ token = _token("devtable", "password")
+ result = validate_basic_auth(token)
+ assert result == ValidateResult(
+ AuthKind.basic, user=model.user.get_user("devtable")
+ )
def test_valid_robot(app):
- robot, password = model.user.create_robot('somerobot', model.user.get_user('devtable'))
- token = _token(robot.username, password)
- result = validate_basic_auth(token)
- assert result == ValidateResult(AuthKind.basic, robot=robot)
+ robot, password = model.user.create_robot(
+ "somerobot", model.user.get_user("devtable")
+ )
+ token = _token(robot.username, password)
+ result = validate_basic_auth(token)
+ assert result == ValidateResult(AuthKind.basic, robot=robot)
def test_valid_token(app):
- access_token = model.token.create_delegate_token('devtable', 'simple', 'sometoken')
- token = _token(ACCESS_TOKEN_USERNAME, access_token.get_code())
- result = validate_basic_auth(token)
- assert result == ValidateResult(AuthKind.basic, token=access_token)
+ access_token = model.token.create_delegate_token("devtable", "simple", "sometoken")
+ token = _token(ACCESS_TOKEN_USERNAME, access_token.get_code())
+ result = validate_basic_auth(token)
+ assert result == ValidateResult(AuthKind.basic, token=access_token)
def test_valid_oauth(app):
- user = model.user.get_user('devtable')
- app = model.oauth.list_applications_for_org(model.user.get_user_or_org('buynlarge'))[0]
- oauth_token, code = model.oauth.create_access_token_for_testing(user, app.client_id, 'repo:read')
- token = _token(OAUTH_TOKEN_USERNAME, code)
- result = validate_basic_auth(token)
- assert result == ValidateResult(AuthKind.basic, oauthtoken=oauth_token)
+ user = model.user.get_user("devtable")
+ app = model.oauth.list_applications_for_org(
+ model.user.get_user_or_org("buynlarge")
+ )[0]
+ oauth_token, code = model.oauth.create_access_token_for_testing(
+ user, app.client_id, "repo:read"
+ )
+ token = _token(OAUTH_TOKEN_USERNAME, code)
+ result = validate_basic_auth(token)
+ assert result == ValidateResult(AuthKind.basic, oauthtoken=oauth_token)
def test_valid_app_specific_token(app):
- user = model.user.get_user('devtable')
- app_specific_token = model.appspecifictoken.create_token(user, 'some token')
- full_token = model.appspecifictoken.get_full_token_string(app_specific_token)
- token = _token(APP_SPECIFIC_TOKEN_USERNAME, full_token)
- result = validate_basic_auth(token)
- assert result == ValidateResult(AuthKind.basic, appspecifictoken=app_specific_token)
+ user = model.user.get_user("devtable")
+ app_specific_token = model.appspecifictoken.create_token(user, "some token")
+ full_token = model.appspecifictoken.get_full_token_string(app_specific_token)
+ token = _token(APP_SPECIFIC_TOKEN_USERNAME, full_token)
+ result = validate_basic_auth(token)
+ assert result == ValidateResult(AuthKind.basic, appspecifictoken=app_specific_token)
def test_invalid_unicode(app):
- token = '\xebOH'
- header = 'basic ' + b64encode(token)
- result = validate_basic_auth(header)
- assert result == ValidateResult(AuthKind.basic, missing=True)
+ token = "\xebOH"
+ header = "basic " + b64encode(token)
+ result = validate_basic_auth(header)
+ assert result == ValidateResult(AuthKind.basic, missing=True)
def test_invalid_unicode_2(app):
- token = '“4JPCOLIVMAY32Q3XGVPHC4CBF8SKII5FWNYMASOFDIVSXTC5I5NBU”'
- header = 'basic ' + b64encode('devtable+somerobot:%s' % token)
- result = validate_basic_auth(header)
- assert result == ValidateResult(
- AuthKind.basic,
- error_message='Could not find robot with username: devtable+somerobot and supplied password.')
+ token = "“4JPCOLIVMAY32Q3XGVPHC4CBF8SKII5FWNYMASOFDIVSXTC5I5NBU”"
+ header = "basic " + b64encode("devtable+somerobot:%s" % token)
+ result = validate_basic_auth(header)
+ assert result == ValidateResult(
+ AuthKind.basic,
+ error_message="Could not find robot with username: devtable+somerobot and supplied password.",
+ )
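For context on the _token helper above: it builds a standard HTTP Basic credential, i.e. "basic " followed by base64("username:password"). A rough Python 3 equivalent is sketched below for illustration only; the test code itself still targets Python 2, as the basestring checks show.

from base64 import b64encode


def basic_auth_header(username, password):
    # base64-encode "username:password" and prepend the "basic " scheme.
    encoded = b64encode("{0}:{1}".format(username, password).encode("utf-8"))
    return "basic " + encoded.decode("ascii")


assert basic_auth_header("devtable", "password").startswith("basic ")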
diff --git a/auth/test/test_cookie.py b/auth/test/test_cookie.py
index 8c212d709..b9e69b571 100644
--- a/auth/test/test_cookie.py
+++ b/auth/test/test_cookie.py
@@ -9,58 +9,58 @@ from test.fixtures import *
def test_anonymous_cookie(app):
- assert validate_session_cookie().missing
+ assert validate_session_cookie().missing
def test_invalidformatted_cookie(app):
- # "Login" with a non-UUID reference.
- someuser = model.user.get_user('devtable')
- login_user(LoginWrappedDBUser('somenonuuid', someuser))
+ # "Login" with a non-UUID reference.
+ someuser = model.user.get_user("devtable")
+ login_user(LoginWrappedDBUser("somenonuuid", someuser))
- # Ensure we get an invalid session cookie format error.
- result = validate_session_cookie()
- assert result.authed_user is None
- assert result.context.identity is None
- assert not result.has_nonrobot_user
- assert result.error_message == 'Invalid session cookie format'
+ # Ensure we get an invalid session cookie format error.
+ result = validate_session_cookie()
+ assert result.authed_user is None
+ assert result.context.identity is None
+ assert not result.has_nonrobot_user
+ assert result.error_message == "Invalid session cookie format"
def test_disabled_user(app):
- # "Login" with a disabled user.
- someuser = model.user.get_user('disabled')
- login_user(LoginWrappedDBUser(someuser.uuid, someuser))
+ # "Login" with a disabled user.
+ someuser = model.user.get_user("disabled")
+ login_user(LoginWrappedDBUser(someuser.uuid, someuser))
- # Ensure we get an invalid session cookie format error.
- result = validate_session_cookie()
- assert result.authed_user is None
- assert result.context.identity is None
- assert not result.has_nonrobot_user
- assert result.error_message == 'User account is disabled'
+ # Ensure we get an invalid session cookie format error.
+ result = validate_session_cookie()
+ assert result.authed_user is None
+ assert result.context.identity is None
+ assert not result.has_nonrobot_user
+ assert result.error_message == "User account is disabled"
def test_valid_user(app):
- # Login with a valid user.
- someuser = model.user.get_user('devtable')
- login_user(LoginWrappedDBUser(someuser.uuid, someuser))
+ # Login with a valid user.
+ someuser = model.user.get_user("devtable")
+ login_user(LoginWrappedDBUser(someuser.uuid, someuser))
- result = validate_session_cookie()
- assert result.authed_user == someuser
- assert result.context.identity is not None
- assert result.has_nonrobot_user
- assert result.error_message is None
+ result = validate_session_cookie()
+ assert result.authed_user == someuser
+ assert result.context.identity is not None
+ assert result.has_nonrobot_user
+ assert result.error_message is None
def test_valid_organization(app):
- # "Login" with a valid organization.
- someorg = model.user.get_namespace_user('buynlarge')
- someorg.uuid = str(uuid.uuid4())
- someorg.verified = True
- someorg.save()
+ # "Login" with a valid organization.
+ someorg = model.user.get_namespace_user("buynlarge")
+ someorg.uuid = str(uuid.uuid4())
+ someorg.verified = True
+ someorg.save()
- login_user(LoginWrappedDBUser(someorg.uuid, someorg))
+ login_user(LoginWrappedDBUser(someorg.uuid, someorg))
- result = validate_session_cookie()
- assert result.authed_user is None
- assert result.context.identity is None
- assert not result.has_nonrobot_user
- assert result.error_message == 'Cannot login to organization'
+ result = validate_session_cookie()
+ assert result.authed_user is None
+ assert result.context.identity is None
+ assert not result.has_nonrobot_user
+ assert result.error_message == "Cannot login to organization"
diff --git a/auth/test/test_credentials.py b/auth/test/test_credentials.py
index 4e55c470c..08e5a39c1 100644
--- a/auth/test/test_credentials.py
+++ b/auth/test/test_credentials.py
@@ -1,147 +1,184 @@
# -*- coding: utf-8 -*-
from auth.credentials import validate_credentials, CredentialKind
-from auth.credential_consts import (ACCESS_TOKEN_USERNAME, OAUTH_TOKEN_USERNAME,
- APP_SPECIFIC_TOKEN_USERNAME)
+from auth.credential_consts import (
+ ACCESS_TOKEN_USERNAME,
+ OAUTH_TOKEN_USERNAME,
+ APP_SPECIFIC_TOKEN_USERNAME,
+)
from auth.validateresult import AuthKind, ValidateResult
from data import model
from test.fixtures import *
+
def test_valid_user(app):
- result, kind = validate_credentials('devtable', 'password')
- assert kind == CredentialKind.user
- assert result == ValidateResult(AuthKind.credentials, user=model.user.get_user('devtable'))
+ result, kind = validate_credentials("devtable", "password")
+ assert kind == CredentialKind.user
+ assert result == ValidateResult(
+ AuthKind.credentials, user=model.user.get_user("devtable")
+ )
+
def test_valid_robot(app):
- robot, password = model.user.create_robot('somerobot', model.user.get_user('devtable'))
- result, kind = validate_credentials(robot.username, password)
- assert kind == CredentialKind.robot
- assert result == ValidateResult(AuthKind.credentials, robot=robot)
+ robot, password = model.user.create_robot(
+ "somerobot", model.user.get_user("devtable")
+ )
+ result, kind = validate_credentials(robot.username, password)
+ assert kind == CredentialKind.robot
+ assert result == ValidateResult(AuthKind.credentials, robot=robot)
+
def test_valid_robot_for_disabled_user(app):
- user = model.user.get_user('devtable')
- user.enabled = False
- user.save()
+ user = model.user.get_user("devtable")
+ user.enabled = False
+ user.save()
- robot, password = model.user.create_robot('somerobot', user)
- result, kind = validate_credentials(robot.username, password)
- assert kind == CredentialKind.robot
+ robot, password = model.user.create_robot("somerobot", user)
+ result, kind = validate_credentials(robot.username, password)
+ assert kind == CredentialKind.robot
+
+ err = "This user has been disabled. Please contact your administrator."
+ assert result == ValidateResult(AuthKind.credentials, error_message=err)
- err = 'This user has been disabled. Please contact your administrator.'
- assert result == ValidateResult(AuthKind.credentials, error_message=err)
def test_valid_token(app):
- access_token = model.token.create_delegate_token('devtable', 'simple', 'sometoken')
- result, kind = validate_credentials(ACCESS_TOKEN_USERNAME, access_token.get_code())
- assert kind == CredentialKind.token
- assert result == ValidateResult(AuthKind.credentials, token=access_token)
+ access_token = model.token.create_delegate_token("devtable", "simple", "sometoken")
+ result, kind = validate_credentials(ACCESS_TOKEN_USERNAME, access_token.get_code())
+ assert kind == CredentialKind.token
+ assert result == ValidateResult(AuthKind.credentials, token=access_token)
+
def test_valid_oauth(app):
- user = model.user.get_user('devtable')
- app = model.oauth.list_applications_for_org(model.user.get_user_or_org('buynlarge'))[0]
- oauth_token, code = model.oauth.create_access_token_for_testing(user, app.client_id, 'repo:read')
- result, kind = validate_credentials(OAUTH_TOKEN_USERNAME, code)
- assert kind == CredentialKind.oauth_token
- assert result == ValidateResult(AuthKind.oauth, oauthtoken=oauth_token)
+ user = model.user.get_user("devtable")
+ app = model.oauth.list_applications_for_org(
+ model.user.get_user_or_org("buynlarge")
+ )[0]
+ oauth_token, code = model.oauth.create_access_token_for_testing(
+ user, app.client_id, "repo:read"
+ )
+ result, kind = validate_credentials(OAUTH_TOKEN_USERNAME, code)
+ assert kind == CredentialKind.oauth_token
+ assert result == ValidateResult(AuthKind.oauth, oauthtoken=oauth_token)
+
def test_invalid_user(app):
- result, kind = validate_credentials('devtable', 'somepassword')
- assert kind == CredentialKind.user
- assert result == ValidateResult(AuthKind.credentials,
- error_message='Invalid Username or Password')
+ result, kind = validate_credentials("devtable", "somepassword")
+ assert kind == CredentialKind.user
+ assert result == ValidateResult(
+ AuthKind.credentials, error_message="Invalid Username or Password"
+ )
+
def test_valid_app_specific_token(app):
- user = model.user.get_user('devtable')
- app_specific_token = model.appspecifictoken.create_token(user, 'some token')
- full_token = model.appspecifictoken.get_full_token_string(app_specific_token)
- result, kind = validate_credentials(APP_SPECIFIC_TOKEN_USERNAME, full_token)
- assert kind == CredentialKind.app_specific_token
- assert result == ValidateResult(AuthKind.credentials, appspecifictoken=app_specific_token)
+ user = model.user.get_user("devtable")
+ app_specific_token = model.appspecifictoken.create_token(user, "some token")
+ full_token = model.appspecifictoken.get_full_token_string(app_specific_token)
+ result, kind = validate_credentials(APP_SPECIFIC_TOKEN_USERNAME, full_token)
+ assert kind == CredentialKind.app_specific_token
+ assert result == ValidateResult(
+ AuthKind.credentials, appspecifictoken=app_specific_token
+ )
+
def test_valid_app_specific_token_for_disabled_user(app):
- user = model.user.get_user('devtable')
- user.enabled = False
- user.save()
+ user = model.user.get_user("devtable")
+ user.enabled = False
+ user.save()
- app_specific_token = model.appspecifictoken.create_token(user, 'some token')
- full_token = model.appspecifictoken.get_full_token_string(app_specific_token)
- result, kind = validate_credentials(APP_SPECIFIC_TOKEN_USERNAME, full_token)
- assert kind == CredentialKind.app_specific_token
+ app_specific_token = model.appspecifictoken.create_token(user, "some token")
+ full_token = model.appspecifictoken.get_full_token_string(app_specific_token)
+ result, kind = validate_credentials(APP_SPECIFIC_TOKEN_USERNAME, full_token)
+ assert kind == CredentialKind.app_specific_token
- err = 'This user has been disabled. Please contact your administrator.'
- assert result == ValidateResult(AuthKind.credentials, error_message=err)
+ err = "This user has been disabled. Please contact your administrator."
+ assert result == ValidateResult(AuthKind.credentials, error_message=err)
+
+
+def test_invalid_app_specific_token(app):
+ result, kind = validate_credentials(APP_SPECIFIC_TOKEN_USERNAME, "somecode")
+ assert kind == CredentialKind.app_specific_token
+ assert result == ValidateResult(AuthKind.credentials, error_message="Invalid token")
-def test_invalid_app_specific_token(app):
- result, kind = validate_credentials(APP_SPECIFIC_TOKEN_USERNAME, 'somecode')
- assert kind == CredentialKind.app_specific_token
- assert result == ValidateResult(AuthKind.credentials, error_message='Invalid token')
def test_invalid_app_specific_token_code(app):
- user = model.user.get_user('devtable')
- app_specific_token = model.appspecifictoken.create_token(user, 'some token')
- full_token = app_specific_token.token_name + 'something'
- result, kind = validate_credentials(APP_SPECIFIC_TOKEN_USERNAME, full_token)
- assert kind == CredentialKind.app_specific_token
- assert result == ValidateResult(AuthKind.credentials, error_message='Invalid token')
+ user = model.user.get_user("devtable")
+ app_specific_token = model.appspecifictoken.create_token(user, "some token")
+ full_token = app_specific_token.token_name + "something"
+ result, kind = validate_credentials(APP_SPECIFIC_TOKEN_USERNAME, full_token)
+ assert kind == CredentialKind.app_specific_token
+ assert result == ValidateResult(AuthKind.credentials, error_message="Invalid token")
-def test_unicode(app):
- result, kind = validate_credentials('someusername', 'some₪code')
- assert kind == CredentialKind.user
- assert not result.auth_valid
- assert result == ValidateResult(AuthKind.credentials,
- error_message='Invalid Username or Password')
-def test_unicode_robot(app):
- robot, _ = model.user.create_robot('somerobot', model.user.get_user('devtable'))
- result, kind = validate_credentials(robot.username, 'some₪code')
+def test_unicode(app):
+ result, kind = validate_credentials("someusername", "some₪code")
+ assert kind == CredentialKind.user
+ assert not result.auth_valid
+ assert result == ValidateResult(
+ AuthKind.credentials, error_message="Invalid Username or Password"
+ )
- assert kind == CredentialKind.robot
- assert not result.auth_valid
- msg = 'Could not find robot with username: devtable+somerobot and supplied password.'
- assert result == ValidateResult(AuthKind.credentials, error_message=msg)
+def test_unicode_robot(app):
+ robot, _ = model.user.create_robot("somerobot", model.user.get_user("devtable"))
+ result, kind = validate_credentials(robot.username, "some₪code")
+
+ assert kind == CredentialKind.robot
+ assert not result.auth_valid
+
+ msg = (
+ "Could not find robot with username: devtable+somerobot and supplied password."
+ )
+ assert result == ValidateResult(AuthKind.credentials, error_message=msg)
+
def test_invalid_user(app):
- result, kind = validate_credentials('someinvaliduser', 'password')
- assert kind == CredentialKind.user
- assert not result.authed_user
- assert not result.auth_valid
+ result, kind = validate_credentials("someinvaliduser", "password")
+ assert kind == CredentialKind.user
+ assert not result.authed_user
+ assert not result.auth_valid
+
def test_invalid_user_password(app):
- result, kind = validate_credentials('devtable', 'somepassword')
- assert kind == CredentialKind.user
- assert not result.authed_user
- assert not result.auth_valid
+ result, kind = validate_credentials("devtable", "somepassword")
+ assert kind == CredentialKind.user
+ assert not result.authed_user
+ assert not result.auth_valid
+
def test_invalid_robot(app):
- result, kind = validate_credentials('devtable+doesnotexist', 'password')
- assert kind == CredentialKind.robot
- assert not result.authed_user
- assert not result.auth_valid
+ result, kind = validate_credentials("devtable+doesnotexist", "password")
+ assert kind == CredentialKind.robot
+ assert not result.authed_user
+ assert not result.auth_valid
+
def test_invalid_robot_token(app):
- robot, _ = model.user.create_robot('somerobot', model.user.get_user('devtable'))
- result, kind = validate_credentials(robot.username, 'invalidpassword')
- assert kind == CredentialKind.robot
- assert not result.authed_user
- assert not result.auth_valid
+ robot, _ = model.user.create_robot("somerobot", model.user.get_user("devtable"))
+ result, kind = validate_credentials(robot.username, "invalidpassword")
+ assert kind == CredentialKind.robot
+ assert not result.authed_user
+ assert not result.auth_valid
+
def test_invalid_unicode_robot(app):
- token = '“4JPCOLIVMAY32Q3XGVPHC4CBF8SKII5FWNYMASOFDIVSXTC5I5NBU”'
- result, kind = validate_credentials('devtable+somerobot', token)
- assert kind == CredentialKind.robot
- assert not result.auth_valid
- msg = 'Could not find robot with username: devtable+somerobot'
- assert result == ValidateResult(AuthKind.credentials, error_message=msg)
+ token = "“4JPCOLIVMAY32Q3XGVPHC4CBF8SKII5FWNYMASOFDIVSXTC5I5NBU”"
+ result, kind = validate_credentials("devtable+somerobot", token)
+ assert kind == CredentialKind.robot
+ assert not result.auth_valid
+ msg = "Could not find robot with username: devtable+somerobot"
+ assert result == ValidateResult(AuthKind.credentials, error_message=msg)
+
def test_invalid_unicode_robot_2(app):
- user = model.user.get_user('devtable')
- robot, password = model.user.create_robot('somerobot', user)
+ user = model.user.get_user("devtable")
+ robot, password = model.user.create_robot("somerobot", user)
- token = '“4JPCOLIVMAY32Q3XGVPHC4CBF8SKII5FWNYMASOFDIVSXTC5I5NBU”'
- result, kind = validate_credentials('devtable+somerobot', token)
- assert kind == CredentialKind.robot
- assert not result.auth_valid
- msg = 'Could not find robot with username: devtable+somerobot and supplied password.'
- assert result == ValidateResult(AuthKind.credentials, error_message=msg)
+ token = "“4JPCOLIVMAY32Q3XGVPHC4CBF8SKII5FWNYMASOFDIVSXTC5I5NBU”"
+ result, kind = validate_credentials("devtable+somerobot", token)
+ assert kind == CredentialKind.robot
+ assert not result.auth_valid
+ msg = (
+ "Could not find robot with username: devtable+somerobot and supplied password."
+ )
+ assert result == ValidateResult(AuthKind.credentials, error_message=msg)
diff --git a/auth/test/test_decorators.py b/auth/test/test_decorators.py
index b0477f7bd..87b4d2ae9 100644
--- a/auth/test/test_decorators.py
+++ b/auth/test/test_decorators.py
@@ -7,99 +7,102 @@ from werkzeug.exceptions import HTTPException
from app import LoginWrappedDBUser
from auth.auth_context import get_authenticated_user
from auth.decorators import (
- extract_namespace_repo_from_session, require_session_login, process_auth_or_cookie)
+ extract_namespace_repo_from_session,
+ require_session_login,
+ process_auth_or_cookie,
+)
from data import model
from test.fixtures import *
def test_extract_namespace_repo_from_session_missing(app):
- def emptyfunc():
- pass
+ def emptyfunc():
+ pass
- session.clear()
- with pytest.raises(HTTPException):
- extract_namespace_repo_from_session(emptyfunc)()
+ session.clear()
+ with pytest.raises(HTTPException):
+ extract_namespace_repo_from_session(emptyfunc)()
def test_extract_namespace_repo_from_session_present(app):
- encountered = []
+ encountered = []
- def somefunc(namespace, repository):
- encountered.append(namespace)
- encountered.append(repository)
+ def somefunc(namespace, repository):
+ encountered.append(namespace)
+ encountered.append(repository)
- # Add the namespace and repository to the session.
- session.clear()
- session['namespace'] = 'foo'
- session['repository'] = 'bar'
+ # Add the namespace and repository to the session.
+ session.clear()
+ session["namespace"] = "foo"
+ session["repository"] = "bar"
- # Call the decorated method.
- extract_namespace_repo_from_session(somefunc)()
+ # Call the decorated method.
+ extract_namespace_repo_from_session(somefunc)()
- assert encountered[0] == 'foo'
- assert encountered[1] == 'bar'
+ assert encountered[0] == "foo"
+ assert encountered[1] == "bar"
def test_require_session_login_missing(app):
- def emptyfunc():
- pass
+ def emptyfunc():
+ pass
- with pytest.raises(HTTPException):
- require_session_login(emptyfunc)()
+ with pytest.raises(HTTPException):
+ require_session_login(emptyfunc)()
def test_require_session_login_valid_user(app):
- def emptyfunc():
- pass
+ def emptyfunc():
+ pass
- # Login as a valid user.
- someuser = model.user.get_user('devtable')
- login_user(LoginWrappedDBUser(someuser.uuid, someuser))
+ # Login as a valid user.
+ someuser = model.user.get_user("devtable")
+ login_user(LoginWrappedDBUser(someuser.uuid, someuser))
- # Call the function.
- require_session_login(emptyfunc)()
+ # Call the function.
+ require_session_login(emptyfunc)()
- # Ensure the authenticated user was updated.
- assert get_authenticated_user() == someuser
+ # Ensure the authenticated user was updated.
+ assert get_authenticated_user() == someuser
def test_require_session_login_invalid_user(app):
- def emptyfunc():
- pass
+ def emptyfunc():
+ pass
- # "Login" as a disabled user.
- someuser = model.user.get_user('disabled')
- login_user(LoginWrappedDBUser(someuser.uuid, someuser))
+ # "Login" as a disabled user.
+ someuser = model.user.get_user("disabled")
+ login_user(LoginWrappedDBUser(someuser.uuid, someuser))
- # Call the function.
- with pytest.raises(HTTPException):
- require_session_login(emptyfunc)()
+ # Call the function.
+ with pytest.raises(HTTPException):
+ require_session_login(emptyfunc)()
- # Ensure the authenticated user was not updated.
- assert get_authenticated_user() is None
+ # Ensure the authenticated user was not updated.
+ assert get_authenticated_user() is None
def test_process_auth_or_cookie_invalid_user(app):
- def emptyfunc():
- pass
+ def emptyfunc():
+ pass
- # Call the function.
- process_auth_or_cookie(emptyfunc)()
+ # Call the function.
+ process_auth_or_cookie(emptyfunc)()
- # Ensure the authenticated user was not updated.
- assert get_authenticated_user() is None
+ # Ensure the authenticated user was not updated.
+ assert get_authenticated_user() is None
def test_process_auth_or_cookie_valid_user(app):
- def emptyfunc():
- pass
+ def emptyfunc():
+ pass
- # Login as a valid user.
- someuser = model.user.get_user('devtable')
- login_user(LoginWrappedDBUser(someuser.uuid, someuser))
+ # Login as a valid user.
+ someuser = model.user.get_user("devtable")
+ login_user(LoginWrappedDBUser(someuser.uuid, someuser))
- # Call the function.
- process_auth_or_cookie(emptyfunc)()
+ # Call the function.
+ process_auth_or_cookie(emptyfunc)()
- # Ensure the authenticated user was updated.
- assert get_authenticated_user() == someuser
+ # Ensure the authenticated user was updated.
+ assert get_authenticated_user() == someuser
diff --git a/auth/test/test_oauth.py b/auth/test/test_oauth.py
index f678f2604..1453e878a 100644
--- a/auth/test/test_oauth.py
+++ b/auth/test/test_oauth.py
@@ -6,50 +6,63 @@ from data import model
from test.fixtures import *
-@pytest.mark.parametrize('header, expected_result', [
- ('', ValidateResult(AuthKind.oauth, missing=True)),
- ('somerandomtoken', ValidateResult(AuthKind.oauth, missing=True)),
- ('bearer some random token', ValidateResult(AuthKind.oauth, missing=True)),
- ('bearer invalidtoken',
- ValidateResult(AuthKind.oauth, error_message='OAuth access token could not be validated')),])
+@pytest.mark.parametrize(
+ "header, expected_result",
+ [
+ ("", ValidateResult(AuthKind.oauth, missing=True)),
+ ("somerandomtoken", ValidateResult(AuthKind.oauth, missing=True)),
+ ("bearer some random token", ValidateResult(AuthKind.oauth, missing=True)),
+ (
+ "bearer invalidtoken",
+ ValidateResult(
+ AuthKind.oauth,
+ error_message="OAuth access token could not be validated",
+ ),
+ ),
+ ],
+)
def test_bearer(header, expected_result, app):
- assert validate_bearer_auth(header) == expected_result
+ assert validate_bearer_auth(header) == expected_result
def test_valid_oauth(app):
- user = model.user.get_user('devtable')
- app = model.oauth.list_applications_for_org(model.user.get_user_or_org('buynlarge'))[0]
- token_string = '%s%s' % ('a' * 20, 'b' * 20)
- oauth_token, _ = model.oauth.create_access_token_for_testing(user, app.client_id, 'repo:read',
- access_token=token_string)
- result = validate_bearer_auth('bearer ' + token_string)
- assert result.context.oauthtoken == oauth_token
- assert result.authed_user == user
- assert result.auth_valid
+ user = model.user.get_user("devtable")
+ app = model.oauth.list_applications_for_org(
+ model.user.get_user_or_org("buynlarge")
+ )[0]
+ token_string = "%s%s" % ("a" * 20, "b" * 20)
+ oauth_token, _ = model.oauth.create_access_token_for_testing(
+ user, app.client_id, "repo:read", access_token=token_string
+ )
+ result = validate_bearer_auth("bearer " + token_string)
+ assert result.context.oauthtoken == oauth_token
+ assert result.authed_user == user
+ assert result.auth_valid
def test_disabled_user_oauth(app):
- user = model.user.get_user('disabled')
- token_string = '%s%s' % ('a' * 20, 'b' * 20)
- oauth_token, _ = model.oauth.create_access_token_for_testing(user, 'deadbeef', 'repo:admin',
- access_token=token_string)
+ user = model.user.get_user("disabled")
+ token_string = "%s%s" % ("a" * 20, "b" * 20)
+ oauth_token, _ = model.oauth.create_access_token_for_testing(
+ user, "deadbeef", "repo:admin", access_token=token_string
+ )
- result = validate_bearer_auth('bearer ' + token_string)
- assert result.context.oauthtoken is None
- assert result.authed_user is None
- assert not result.auth_valid
- assert result.error_message == 'Granter of the oauth access token is disabled'
+ result = validate_bearer_auth("bearer " + token_string)
+ assert result.context.oauthtoken is None
+ assert result.authed_user is None
+ assert not result.auth_valid
+ assert result.error_message == "Granter of the oauth access token is disabled"
def test_expired_token(app):
- user = model.user.get_user('devtable')
- token_string = '%s%s' % ('a' * 20, 'b' * 20)
- oauth_token, _ = model.oauth.create_access_token_for_testing(user, 'deadbeef', 'repo:admin',
- access_token=token_string,
- expires_in=-1000)
+ user = model.user.get_user("devtable")
+ token_string = "%s%s" % ("a" * 20, "b" * 20)
+ oauth_token, _ = model.oauth.create_access_token_for_testing(
+ user, "deadbeef", "repo:admin", access_token=token_string, expires_in=-1000
+ )
- result = validate_bearer_auth('bearer ' + token_string)
- assert result.context.oauthtoken is None
- assert result.authed_user is None
- assert not result.auth_valid
- assert result.error_message == 'OAuth access token has expired'
+ result = validate_bearer_auth("bearer " + token_string)
+ assert result.context.oauthtoken is None
+ assert result.authed_user is None
+ assert not result.auth_valid
+ assert result.error_message == "OAuth access token has expired"
diff --git a/auth/test/test_permissions.py b/auth/test/test_permissions.py
index f2849934d..e97a11f1d 100644
--- a/auth/test/test_permissions.py
+++ b/auth/test/test_permissions.py
@@ -6,32 +6,33 @@ from data import model
from test.fixtures import *
-SUPER_USERNAME = 'devtable'
-UNSUPER_USERNAME = 'freshuser'
+SUPER_USERNAME = "devtable"
+UNSUPER_USERNAME = "freshuser"
+
@pytest.fixture()
def superuser(initialized_db):
- return model.user.get_user(SUPER_USERNAME)
+ return model.user.get_user(SUPER_USERNAME)
@pytest.fixture()
def normie(initialized_db):
- return model.user.get_user(UNSUPER_USERNAME)
+ return model.user.get_user(UNSUPER_USERNAME)
def test_superuser_matrix(superuser, normie):
- test_cases = [
- (superuser, {scopes.SUPERUSER}, True),
- (superuser, {scopes.DIRECT_LOGIN}, True),
- (superuser, {scopes.READ_USER, scopes.SUPERUSER}, True),
- (superuser, {scopes.READ_USER}, False),
- (normie, {scopes.SUPERUSER}, False),
- (normie, {scopes.DIRECT_LOGIN}, False),
- (normie, {scopes.READ_USER, scopes.SUPERUSER}, False),
- (normie, {scopes.READ_USER}, False),
- ]
+ test_cases = [
+ (superuser, {scopes.SUPERUSER}, True),
+ (superuser, {scopes.DIRECT_LOGIN}, True),
+ (superuser, {scopes.READ_USER, scopes.SUPERUSER}, True),
+ (superuser, {scopes.READ_USER}, False),
+ (normie, {scopes.SUPERUSER}, False),
+ (normie, {scopes.DIRECT_LOGIN}, False),
+ (normie, {scopes.READ_USER, scopes.SUPERUSER}, False),
+ (normie, {scopes.READ_USER}, False),
+ ]
- for user_obj, scope_set, expected in test_cases:
- perm_user = QuayDeferredPermissionUser.for_user(user_obj, scope_set)
- has_su = perm_user.can(SuperUserPermission())
- assert has_su == expected
+ for user_obj, scope_set, expected in test_cases:
+ perm_user = QuayDeferredPermissionUser.for_user(user_obj, scope_set)
+ has_su = perm_user.can(SuperUserPermission())
+ assert has_su == expected
diff --git a/auth/test/test_registry_jwt.py b/auth/test/test_registry_jwt.py
index fc6548d74..ffcb9fca7 100644
--- a/auth/test/test_registry_jwt.py
+++ b/auth/test/test_registry_jwt.py
@@ -14,190 +14,226 @@ from initdb import setup_database_for_testing, finished_database_for_testing
from util.morecollections import AttrDict
from util.security.registry_jwt import ANONYMOUS_SUB, build_context_and_subject
-TEST_AUDIENCE = app.config['SERVER_HOSTNAME']
-TEST_USER = AttrDict({'username': 'joeuser', 'uuid': 'foobar', 'enabled': True})
+TEST_AUDIENCE = app.config["SERVER_HOSTNAME"]
+TEST_USER = AttrDict({"username": "joeuser", "uuid": "foobar", "enabled": True})
MAX_SIGNED_S = 3660
TOKEN_VALIDITY_LIFETIME_S = 60 * 60 # 1 hour
-ANONYMOUS_SUB = '(anonymous)'
-SERVICE_NAME = 'quay'
+ANONYMOUS_SUB = "(anonymous)"
+SERVICE_NAME = "quay"
# This import has to come below any references to "app".
from test.fixtures import *
-def _access(typ='repository', name='somens/somerepo', actions=None):
- actions = [] if actions is None else actions
- return [{
- 'type': typ,
- 'name': name,
- 'actions': actions,
- }]
+def _access(typ="repository", name="somens/somerepo", actions=None):
+ actions = [] if actions is None else actions
+ return [{"type": typ, "name": name, "actions": actions}]
def _delete_field(token_data, field_name):
- token_data.pop(field_name)
- return token_data
+ token_data.pop(field_name)
+ return token_data
-def _token_data(access=[], context=None, audience=TEST_AUDIENCE, user=TEST_USER, iat=None,
- exp=None, nbf=None, iss=None, subject=None):
- if subject is None:
- _, subject = build_context_and_subject(ValidatedAuthContext(user=user))
- return {
- 'iss': iss or instance_keys.service_name,
- 'aud': audience,
- 'nbf': nbf if nbf is not None else int(time.time()),
- 'iat': iat if iat is not None else int(time.time()),
- 'exp': exp if exp is not None else int(time.time() + TOKEN_VALIDITY_LIFETIME_S),
- 'sub': subject,
- 'access': access,
- 'context': context,
- }
+def _token_data(
+ access=[],
+ context=None,
+ audience=TEST_AUDIENCE,
+ user=TEST_USER,
+ iat=None,
+ exp=None,
+ nbf=None,
+ iss=None,
+ subject=None,
+):
+ if subject is None:
+ _, subject = build_context_and_subject(ValidatedAuthContext(user=user))
+ return {
+ "iss": iss or instance_keys.service_name,
+ "aud": audience,
+ "nbf": nbf if nbf is not None else int(time.time()),
+ "iat": iat if iat is not None else int(time.time()),
+ "exp": exp if exp is not None else int(time.time() + TOKEN_VALIDITY_LIFETIME_S),
+ "sub": subject,
+ "access": access,
+ "context": context,
+ }
def _token(token_data, key_id=None, private_key=None, skip_header=False, alg=None):
- key_id = key_id or instance_keys.local_key_id
- private_key = private_key or instance_keys.local_private_key
+ key_id = key_id or instance_keys.local_key_id
+ private_key = private_key or instance_keys.local_private_key
- if alg == "none":
- private_key = None
+ if alg == "none":
+ private_key = None
- token_headers = {'kid': key_id}
+ token_headers = {"kid": key_id}
- if skip_header:
- token_headers = {}
+ if skip_header:
+ token_headers = {}
- token_data = jwt.encode(token_data, private_key, alg or 'RS256', headers=token_headers)
- return 'Bearer {0}'.format(token_data)
+ token_data = jwt.encode(
+ token_data, private_key, alg or "RS256", headers=token_headers
+ )
+ return "Bearer {0}".format(token_data)
def _parse_token(token):
- return identity_from_bearer_token(token)[0]
+ return identity_from_bearer_token(token)[0]
def test_accepted_token(initialized_db):
- token = _token(_token_data())
- identity = _parse_token(token)
- assert identity.id == TEST_USER.username, 'should be %s, but was %s' % (TEST_USER.username,
- identity.id)
- assert len(identity.provides) == 0
+ token = _token(_token_data())
+ identity = _parse_token(token)
+ assert identity.id == TEST_USER.username, "should be %s, but was %s" % (
+ TEST_USER.username,
+ identity.id,
+ )
+ assert len(identity.provides) == 0
- anon_token = _token(_token_data(user=None))
- anon_identity = _parse_token(anon_token)
- assert anon_identity.id == ANONYMOUS_SUB, 'should be %s, but was %s' % (ANONYMOUS_SUB,
- anon_identity.id)
- assert len(identity.provides) == 0
+ anon_token = _token(_token_data(user=None))
+ anon_identity = _parse_token(anon_token)
+ assert anon_identity.id == ANONYMOUS_SUB, "should be %s, but was %s" % (
+ ANONYMOUS_SUB,
+ anon_identity.id,
+ )
+ assert len(identity.provides) == 0
-@pytest.mark.parametrize('access', [
- (_access(actions=['pull', 'push'])),
- (_access(actions=['pull', '*'])),
- (_access(actions=['*', 'push'])),
- (_access(actions=['*'])),
- (_access(actions=['pull', '*', 'push'])),])
+@pytest.mark.parametrize(
+ "access",
+ [
+ (_access(actions=["pull", "push"])),
+ (_access(actions=["pull", "*"])),
+ (_access(actions=["*", "push"])),
+ (_access(actions=["*"])),
+ (_access(actions=["pull", "*", "push"])),
+ ],
+)
def test_token_with_access(access, initialized_db):
- token = _token(_token_data(access=access))
- identity = _parse_token(token)
- assert identity.id == TEST_USER.username, 'should be %s, but was %s' % (TEST_USER.username,
- identity.id)
- assert len(identity.provides) == 1
+ token = _token(_token_data(access=access))
+ identity = _parse_token(token)
+ assert identity.id == TEST_USER.username, "should be %s, but was %s" % (
+ TEST_USER.username,
+ identity.id,
+ )
+ assert len(identity.provides) == 1
- role = list(identity.provides)[0][3]
- if "*" in access[0]['actions']:
- assert role == 'admin'
- elif "push" in access[0]['actions']:
- assert role == 'write'
- elif "pull" in access[0]['actions']:
- assert role == 'read'
+ role = list(identity.provides)[0][3]
+ if "*" in access[0]["actions"]:
+ assert role == "admin"
+ elif "push" in access[0]["actions"]:
+ assert role == "write"
+ elif "pull" in access[0]["actions"]:
+ assert role == "read"
-@pytest.mark.parametrize('token', [
- pytest.param(_token(
- _token_data(access=[{
- 'toipe': 'repository',
- 'namesies': 'somens/somerepo',
- 'akshuns': ['pull', 'push', '*']}])), id='bad access'),
- pytest.param(_token(_token_data(audience='someotherapp')), id='bad aud'),
- pytest.param(_token(_delete_field(_token_data(), 'aud')), id='no aud'),
- pytest.param(_token(_token_data(nbf=int(time.time()) + 600)), id='future nbf'),
- pytest.param(_token(_delete_field(_token_data(), 'nbf')), id='no nbf'),
- pytest.param(_token(_token_data(iat=int(time.time()) + 600)), id='future iat'),
- pytest.param(_token(_delete_field(_token_data(), 'iat')), id='no iat'),
- pytest.param(_token(_token_data(exp=int(time.time()) + MAX_SIGNED_S * 2)), id='exp too long'),
- pytest.param(_token(_token_data(exp=int(time.time()) - 60)), id='expired'),
- pytest.param(_token(_delete_field(_token_data(), 'exp')), id='no exp'),
- pytest.param(_token(_delete_field(_token_data(), 'sub')), id='no sub'),
- pytest.param(_token(_token_data(iss='badissuer')), id='bad iss'),
- pytest.param(_token(_delete_field(_token_data(), 'iss')), id='no iss'),
- pytest.param(_token(_token_data(), skip_header=True), id='no header'),
- pytest.param(_token(_token_data(), key_id='someunknownkey'), id='bad key'),
- pytest.param(_token(_token_data(), key_id='kid7'), id='bad key :: kid7'),
- pytest.param(_token(_token_data(), alg='none', private_key=None), id='none alg'),
- pytest.param('some random token', id='random token'),
- pytest.param('Bearer: sometokenhere', id='extra bearer'),
- pytest.param('\nBearer: dGVzdA', id='leading newline'),
-])
+@pytest.mark.parametrize(
+ "token",
+ [
+ pytest.param(
+ _token(
+ _token_data(
+ access=[
+ {
+ "toipe": "repository",
+ "namesies": "somens/somerepo",
+ "akshuns": ["pull", "push", "*"],
+ }
+ ]
+ )
+ ),
+ id="bad access",
+ ),
+ pytest.param(_token(_token_data(audience="someotherapp")), id="bad aud"),
+ pytest.param(_token(_delete_field(_token_data(), "aud")), id="no aud"),
+ pytest.param(_token(_token_data(nbf=int(time.time()) + 600)), id="future nbf"),
+ pytest.param(_token(_delete_field(_token_data(), "nbf")), id="no nbf"),
+ pytest.param(_token(_token_data(iat=int(time.time()) + 600)), id="future iat"),
+ pytest.param(_token(_delete_field(_token_data(), "iat")), id="no iat"),
+ pytest.param(
+ _token(_token_data(exp=int(time.time()) + MAX_SIGNED_S * 2)),
+ id="exp too long",
+ ),
+ pytest.param(_token(_token_data(exp=int(time.time()) - 60)), id="expired"),
+ pytest.param(_token(_delete_field(_token_data(), "exp")), id="no exp"),
+ pytest.param(_token(_delete_field(_token_data(), "sub")), id="no sub"),
+ pytest.param(_token(_token_data(iss="badissuer")), id="bad iss"),
+ pytest.param(_token(_delete_field(_token_data(), "iss")), id="no iss"),
+ pytest.param(_token(_token_data(), skip_header=True), id="no header"),
+ pytest.param(_token(_token_data(), key_id="someunknownkey"), id="bad key"),
+ pytest.param(_token(_token_data(), key_id="kid7"), id="bad key :: kid7"),
+ pytest.param(
+ _token(_token_data(), alg="none", private_key=None), id="none alg"
+ ),
+ pytest.param("some random token", id="random token"),
+ pytest.param("Bearer: sometokenhere", id="extra bearer"),
+ pytest.param("\nBearer: dGVzdA", id="leading newline"),
+ ],
+)
def test_invalid_jwt(token, initialized_db):
- with pytest.raises(InvalidJWTException):
- _parse_token(token)
+ with pytest.raises(InvalidJWTException):
+ _parse_token(token)
def test_mixing_keys_e2e(initialized_db):
- token_data = _token_data()
+ token_data = _token_data()
- # Create a new key for testing.
- p, key = model.service_keys.generate_service_key(instance_keys.service_name, None, kid='newkey',
- name='newkey', metadata={})
- private_key = p.exportKey('PEM')
+ # Create a new key for testing.
+ p, key = model.service_keys.generate_service_key(
+ instance_keys.service_name, None, kid="newkey", name="newkey", metadata={}
+ )
+ private_key = p.exportKey("PEM")
- # Test first with the new valid, but unapproved key.
- unapproved_key_token = _token(token_data, key_id='newkey', private_key=private_key)
- with pytest.raises(InvalidJWTException):
- _parse_token(unapproved_key_token)
+ # Test first with the new valid, but unapproved key.
+ unapproved_key_token = _token(token_data, key_id="newkey", private_key=private_key)
+ with pytest.raises(InvalidJWTException):
+ _parse_token(unapproved_key_token)
- # Approve the key and try again.
- admin_user = model.user.get_user('devtable')
- model.service_keys.approve_service_key(key.kid, ServiceKeyApprovalType.SUPERUSER, approver=admin_user)
+ # Approve the key and try again.
+ admin_user = model.user.get_user("devtable")
+ model.service_keys.approve_service_key(
+ key.kid, ServiceKeyApprovalType.SUPERUSER, approver=admin_user
+ )
- valid_token = _token(token_data, key_id='newkey', private_key=private_key)
+ valid_token = _token(token_data, key_id="newkey", private_key=private_key)
- identity = _parse_token(valid_token)
- assert identity.id == TEST_USER.username
- assert len(identity.provides) == 0
+ identity = _parse_token(valid_token)
+ assert identity.id == TEST_USER.username
+ assert len(identity.provides) == 0
- # Try using a different private key with the existing key ID.
- bad_private_token = _token(token_data, key_id='newkey',
- private_key=instance_keys.local_private_key)
- with pytest.raises(InvalidJWTException):
- _parse_token(bad_private_token)
+ # Try using a different private key with the existing key ID.
+ bad_private_token = _token(
+ token_data, key_id="newkey", private_key=instance_keys.local_private_key
+ )
+ with pytest.raises(InvalidJWTException):
+ _parse_token(bad_private_token)
- # Try using a different key ID with the existing private key.
- kid_mismatch_token = _token(token_data, key_id=instance_keys.local_key_id,
- private_key=private_key)
- with pytest.raises(InvalidJWTException):
- _parse_token(kid_mismatch_token)
+ # Try using a different key ID with the existing private key.
+ kid_mismatch_token = _token(
+ token_data, key_id=instance_keys.local_key_id, private_key=private_key
+ )
+ with pytest.raises(InvalidJWTException):
+ _parse_token(kid_mismatch_token)
- # Delete the new key.
- key.delete_instance(recursive=True)
+ # Delete the new key.
+ key.delete_instance(recursive=True)
- # Ensure it still works (via the cache.)
- deleted_key_token = _token(token_data, key_id='newkey', private_key=private_key)
- identity = _parse_token(deleted_key_token)
- assert identity.id == TEST_USER.username
- assert len(identity.provides) == 0
+ # Ensure it still works (via the cache.)
+ deleted_key_token = _token(token_data, key_id="newkey", private_key=private_key)
+ identity = _parse_token(deleted_key_token)
+ assert identity.id == TEST_USER.username
+ assert len(identity.provides) == 0
- # Break the cache.
- instance_keys.clear_cache()
+ # Break the cache.
+ instance_keys.clear_cache()
- # Ensure the key no longer works.
- with pytest.raises(InvalidJWTException):
- _parse_token(deleted_key_token)
+ # Ensure the key no longer works.
+ with pytest.raises(InvalidJWTException):
+ _parse_token(deleted_key_token)
-@pytest.mark.parametrize('token', [
- u'someunicodetoken✡',
- u'\xc9\xad\xbd',
-])
+@pytest.mark.parametrize("token", [u"someunicodetoken✡", u"\xc9\xad\xbd"])
def test_unicode_token(token):
- with pytest.raises(InvalidJWTException):
- _parse_token(token)
+ with pytest.raises(InvalidJWTException):
+ _parse_token(token)
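For context on what these registry-JWT tests construct, here is a comparable token round-trip with PyJWT. It is only an illustration, not part of the patch: it signs with a made-up symmetric secret and HS256 so it runs standalone, whereas the code under test signs with the instance's RSA service key and additionally validates issuer, key ID, and the access claim as exercised above.

import time
import jwt  # PyJWT

secret = "not-a-real-key"  # hypothetical; the real tokens are RS256-signed with a service key
now = int(time.time())
claims = {
    "iss": "quay",
    "aud": "quay.example.com",
    "sub": "joeuser",
    "iat": now,
    "nbf": now,
    "exp": now + 3600,
    "access": [{"type": "repository", "name": "somens/somerepo", "actions": ["pull"]}],
}
token = jwt.encode(claims, secret, algorithm="HS256")
decoded = jwt.decode(token, secret, algorithms=["HS256"], audience="quay.example.com")
assert decoded["sub"] == "joeuser"
assert decoded["access"][0]["actions"] == ["pull"]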
diff --git a/auth/test/test_scopes.py b/auth/test/test_scopes.py
index b71140136..a5aa883ea 100644
--- a/auth/test/test_scopes.py
+++ b/auth/test/test_scopes.py
@@ -1,50 +1,55 @@
import pytest
from auth.scopes import (
- scopes_from_scope_string, validate_scope_string, ALL_SCOPES, is_subset_string)
+ scopes_from_scope_string,
+ validate_scope_string,
+ ALL_SCOPES,
+ is_subset_string,
+)
@pytest.mark.parametrize(
- 'scopes_string, expected',
- [
- # Valid single scopes.
- ('repo:read', ['repo:read']),
- ('repo:admin', ['repo:admin']),
-
- # Invalid scopes.
- ('not:valid', []),
- ('repo:admins', []),
-
- # Valid scope strings.
- ('repo:read repo:admin', ['repo:read', 'repo:admin']),
- ('repo:read,repo:admin', ['repo:read', 'repo:admin']),
- ('repo:read,repo:admin repo:write', ['repo:read', 'repo:admin', 'repo:write']),
-
- # Partially invalid scopes.
- ('repo:read,not:valid', []),
- ('repo:read repo:admins', []),
-
- # Invalid scope strings.
- ('repo:read|repo:admin', []),
-
- # Mixture of delimiters.
- ('repo:read, repo:admin', []),])
+ "scopes_string, expected",
+ [
+ # Valid single scopes.
+ ("repo:read", ["repo:read"]),
+ ("repo:admin", ["repo:admin"]),
+ # Invalid scopes.
+ ("not:valid", []),
+ ("repo:admins", []),
+ # Valid scope strings.
+ ("repo:read repo:admin", ["repo:read", "repo:admin"]),
+ ("repo:read,repo:admin", ["repo:read", "repo:admin"]),
+ ("repo:read,repo:admin repo:write", ["repo:read", "repo:admin", "repo:write"]),
+ # Partially invalid scopes.
+ ("repo:read,not:valid", []),
+ ("repo:read repo:admins", []),
+ # Invalid scope strings.
+ ("repo:read|repo:admin", []),
+ # Mixture of delimiters.
+ ("repo:read, repo:admin", []),
+ ],
+)
def test_parsing(scopes_string, expected):
- expected_scope_set = {ALL_SCOPES[scope_name] for scope_name in expected}
- parsed_scope_set = scopes_from_scope_string(scopes_string)
- assert parsed_scope_set == expected_scope_set
- assert validate_scope_string(scopes_string) == bool(expected)
+ expected_scope_set = {ALL_SCOPES[scope_name] for scope_name in expected}
+ parsed_scope_set = scopes_from_scope_string(scopes_string)
+ assert parsed_scope_set == expected_scope_set
+ assert validate_scope_string(scopes_string) == bool(expected)
-@pytest.mark.parametrize('superset, subset, result', [
- ('repo:read', 'repo:read', True),
- ('repo:read repo:admin', 'repo:read', True),
- ('repo:read,repo:admin', 'repo:read', True),
- ('repo:read,repo:admin', 'repo:admin', True),
- ('repo:read,repo:admin', 'repo:admin repo:read', True),
- ('', 'repo:read', False),
- ('unknown:tag', 'repo:read', False),
- ('repo:read unknown:tag', 'repo:read', False),
- ('repo:read,unknown:tag', 'repo:read', False),])
+@pytest.mark.parametrize(
+ "superset, subset, result",
+ [
+ ("repo:read", "repo:read", True),
+ ("repo:read repo:admin", "repo:read", True),
+ ("repo:read,repo:admin", "repo:read", True),
+ ("repo:read,repo:admin", "repo:admin", True),
+ ("repo:read,repo:admin", "repo:admin repo:read", True),
+ ("", "repo:read", False),
+ ("unknown:tag", "repo:read", False),
+ ("repo:read unknown:tag", "repo:read", False),
+ ("repo:read,unknown:tag", "repo:read", False),
+ ],
+)
def test_subset_string(superset, subset, result):
- assert is_subset_string(superset, subset) == result
+ assert is_subset_string(superset, subset) == result
diff --git a/auth/test/test_signedgrant.py b/auth/test/test_signedgrant.py
index e200f0bf1..5575a032d 100644
--- a/auth/test/test_signedgrant.py
+++ b/auth/test/test_signedgrant.py
@@ -1,32 +1,47 @@
import pytest
-from auth.signedgrant import validate_signed_grant, generate_signed_token, SIGNATURE_PREFIX
+from auth.signedgrant import (
+ validate_signed_grant,
+ generate_signed_token,
+ SIGNATURE_PREFIX,
+)
from auth.validateresult import AuthKind, ValidateResult
-@pytest.mark.parametrize('header, expected_result', [
- pytest.param('', ValidateResult(AuthKind.signed_grant, missing=True), id='Missing'),
- pytest.param('somerandomtoken', ValidateResult(AuthKind.signed_grant, missing=True),
- id='Invalid header'),
- pytest.param('token somerandomtoken', ValidateResult(AuthKind.signed_grant, missing=True),
- id='Random Token'),
- pytest.param('token ' + SIGNATURE_PREFIX + 'foo',
- ValidateResult(AuthKind.signed_grant,
- error_message='Signed grant could not be validated'),
- id='Invalid token'),
-])
+@pytest.mark.parametrize(
+ "header, expected_result",
+ [
+ pytest.param(
+ "", ValidateResult(AuthKind.signed_grant, missing=True), id="Missing"
+ ),
+ pytest.param(
+ "somerandomtoken",
+ ValidateResult(AuthKind.signed_grant, missing=True),
+ id="Invalid header",
+ ),
+ pytest.param(
+ "token somerandomtoken",
+ ValidateResult(AuthKind.signed_grant, missing=True),
+ id="Random Token",
+ ),
+ pytest.param(
+ "token " + SIGNATURE_PREFIX + "foo",
+ ValidateResult(
+ AuthKind.signed_grant,
+ error_message="Signed grant could not be validated",
+ ),
+ id="Invalid token",
+ ),
+ ],
+)
def test_token(header, expected_result):
- assert validate_signed_grant(header) == expected_result
+ assert validate_signed_grant(header) == expected_result
def test_valid_grant():
- header = 'token ' + generate_signed_token({'a': 'b'}, {'c': 'd'})
- expected = ValidateResult(AuthKind.signed_grant, signed_data={
- 'grants': {
- 'a': 'b',
- },
- 'user_context': {
- 'c': 'd'
- },
- })
- assert validate_signed_grant(header) == expected
+ header = "token " + generate_signed_token({"a": "b"}, {"c": "d"})
+ expected = ValidateResult(
+ AuthKind.signed_grant,
+ signed_data={"grants": {"a": "b"}, "user_context": {"c": "d"}},
+ )
+ assert validate_signed_grant(header) == expected
diff --git a/auth/test/test_validateresult.py b/auth/test/test_validateresult.py
index 90875da76..bc514e843 100644
--- a/auth/test/test_validateresult.py
+++ b/auth/test/test_validateresult.py
@@ -6,58 +6,68 @@ from data import model
from data.database import AppSpecificAuthToken
from test.fixtures import *
+
def get_user():
- return model.user.get_user('devtable')
+ return model.user.get_user("devtable")
+
def get_app_specific_token():
- return AppSpecificAuthToken.get()
+ return AppSpecificAuthToken.get()
+
def get_robot():
- robot, _ = model.user.create_robot('somebot', get_user())
- return robot
+ robot, _ = model.user.create_robot("somebot", get_user())
+ return robot
+
def get_token():
- return model.token.create_delegate_token('devtable', 'simple', 'sometoken')
+ return model.token.create_delegate_token("devtable", "simple", "sometoken")
+
def get_oauthtoken():
- user = model.user.get_user('devtable')
- return list(model.oauth.list_access_tokens_for_user(user))[0]
+ user = model.user.get_user("devtable")
+ return list(model.oauth.list_access_tokens_for_user(user))[0]
+
def get_signeddata():
- return {'grants': {'a': 'b'}, 'user_context': {'c': 'd'}}
+ return {"grants": {"a": "b"}, "user_context": {"c": "d"}}
-@pytest.mark.parametrize('get_entity,entity_kind', [
- (get_user, 'user'),
- (get_robot, 'robot'),
- (get_token, 'token'),
- (get_oauthtoken, 'oauthtoken'),
- (get_signeddata, 'signed_data'),
- (get_app_specific_token, 'appspecifictoken'),
-])
+
+@pytest.mark.parametrize(
+ "get_entity,entity_kind",
+ [
+ (get_user, "user"),
+ (get_robot, "robot"),
+ (get_token, "token"),
+ (get_oauthtoken, "oauthtoken"),
+ (get_signeddata, "signed_data"),
+ (get_app_specific_token, "appspecifictoken"),
+ ],
+)
def test_apply_context(get_entity, entity_kind, app):
- assert get_authenticated_context() is None
+ assert get_authenticated_context() is None
- entity = get_entity()
- args = {}
- args[entity_kind] = entity
+ entity = get_entity()
+ args = {}
+ args[entity_kind] = entity
- result = ValidateResult(AuthKind.basic, **args)
- result.apply_to_context()
+ result = ValidateResult(AuthKind.basic, **args)
+ result.apply_to_context()
- expected_user = entity if entity_kind == 'user' or entity_kind == 'robot' else None
- if entity_kind == 'oauthtoken':
- expected_user = entity.authorized_user
+ expected_user = entity if entity_kind == "user" or entity_kind == "robot" else None
+ if entity_kind == "oauthtoken":
+ expected_user = entity.authorized_user
- if entity_kind == 'appspecifictoken':
- expected_user = entity.user
+ if entity_kind == "appspecifictoken":
+ expected_user = entity.user
- expected_token = entity if entity_kind == 'token' else None
- expected_oauth = entity if entity_kind == 'oauthtoken' else None
- expected_appspecifictoken = entity if entity_kind == 'appspecifictoken' else None
- expected_grant = entity if entity_kind == 'signed_data' else None
+ expected_token = entity if entity_kind == "token" else None
+ expected_oauth = entity if entity_kind == "oauthtoken" else None
+ expected_appspecifictoken = entity if entity_kind == "appspecifictoken" else None
+ expected_grant = entity if entity_kind == "signed_data" else None
- assert get_authenticated_context().authed_user == expected_user
- assert get_authenticated_context().token == expected_token
- assert get_authenticated_context().oauthtoken == expected_oauth
- assert get_authenticated_context().appspecifictoken == expected_appspecifictoken
- assert get_authenticated_context().signed_data == expected_grant
+ assert get_authenticated_context().authed_user == expected_user
+ assert get_authenticated_context().token == expected_token
+ assert get_authenticated_context().oauthtoken == expected_oauth
+ assert get_authenticated_context().appspecifictoken == expected_appspecifictoken
+ assert get_authenticated_context().signed_data == expected_grant
diff --git a/auth/validateresult.py b/auth/validateresult.py
index 3235104e0..09cc09b11 100644
--- a/auth/validateresult.py
+++ b/auth/validateresult.py
@@ -3,54 +3,76 @@ from auth.auth_context_type import ValidatedAuthContext, ContextEntityKind
class AuthKind(Enum):
- cookie = 'cookie'
- basic = 'basic'
- oauth = 'oauth'
- signed_grant = 'signed_grant'
- credentials = 'credentials'
+ cookie = "cookie"
+ basic = "basic"
+ oauth = "oauth"
+ signed_grant = "signed_grant"
+ credentials = "credentials"
class ValidateResult(object):
- """ A result of validating auth in one form or another. """
- def __init__(self, kind, missing=False, user=None, token=None, oauthtoken=None,
- robot=None, appspecifictoken=None, signed_data=None, error_message=None):
- self.kind = kind
- self.missing = missing
- self.error_message = error_message
- self.context = ValidatedAuthContext(user=user, token=token, oauthtoken=oauthtoken, robot=robot,
- appspecifictoken=appspecifictoken, signed_data=signed_data)
+ """ A result of validating auth in one form or another. """
- def tuple(self):
- return (self.kind, self.missing, self.error_message, self.context.tuple())
+ def __init__(
+ self,
+ kind,
+ missing=False,
+ user=None,
+ token=None,
+ oauthtoken=None,
+ robot=None,
+ appspecifictoken=None,
+ signed_data=None,
+ error_message=None,
+ ):
+ self.kind = kind
+ self.missing = missing
+ self.error_message = error_message
+ self.context = ValidatedAuthContext(
+ user=user,
+ token=token,
+ oauthtoken=oauthtoken,
+ robot=robot,
+ appspecifictoken=appspecifictoken,
+ signed_data=signed_data,
+ )
- def __eq__(self, other):
- return self.tuple() == other.tuple()
+ def tuple(self):
+ return (self.kind, self.missing, self.error_message, self.context.tuple())
- def apply_to_context(self):
- """ Applies this auth result to the auth context and Flask-Principal. """
- self.context.apply_to_request_context()
+ def __eq__(self, other):
+ return self.tuple() == other.tuple()
- def with_kind(self, kind):
- """ Returns a copy of this result, but with the kind replaced. """
- result = ValidateResult(kind, missing=self.missing, error_message=self.error_message)
- result.context = self.context
- return result
+ def apply_to_context(self):
+ """ Applies this auth result to the auth context and Flask-Principal. """
+ self.context.apply_to_request_context()
- def __repr__(self):
- return 'ValidateResult: %s (missing: %s, error: %s)' % (self.kind, self.missing,
- self.error_message)
+ def with_kind(self, kind):
+ """ Returns a copy of this result, but with the kind replaced. """
+ result = ValidateResult(
+ kind, missing=self.missing, error_message=self.error_message
+ )
+ result.context = self.context
+ return result
- @property
- def authed_user(self):
- """ Returns the authenticated user, whether directly, or via an OAuth token. """
- return self.context.authed_user
+ def __repr__(self):
+ return "ValidateResult: %s (missing: %s, error: %s)" % (
+ self.kind,
+ self.missing,
+ self.error_message,
+ )
- @property
- def has_nonrobot_user(self):
- """ Returns whether a user (not a robot) was authenticated successfully. """
- return self.context.has_nonrobot_user
+ @property
+ def authed_user(self):
+ """ Returns the authenticated user, whether directly, or via an OAuth token. """
+ return self.context.authed_user
- @property
- def auth_valid(self):
- """ Returns whether authentication successfully occurred. """
- return self.context.entity_kind != ContextEntityKind.anonymous
+ @property
+ def has_nonrobot_user(self):
+ """ Returns whether a user (not a robot) was authenticated successfully. """
+ return self.context.has_nonrobot_user
+
+ @property
+ def auth_valid(self):
+ """ Returns whether authentication successfully occurred. """
+ return self.context.entity_kind != ContextEntityKind.anonymous
diff --git a/avatars/avatars.py b/avatars/avatars.py
index 737b51191..67969eee7 100644
--- a/avatars/avatars.py
+++ b/avatars/avatars.py
@@ -6,110 +6,133 @@ from requests.exceptions import RequestException
logger = logging.getLogger(__name__)
+
class Avatar(object):
- def __init__(self, app=None):
- self.app = app
- self.state = self._init_app(app)
+ def __init__(self, app=None):
+ self.app = app
+ self.state = self._init_app(app)
- def _init_app(self, app):
- return AVATAR_CLASSES[app.config.get('AVATAR_KIND', 'Gravatar')](
- app.config['PREFERRED_URL_SCHEME'], app.config['AVATAR_COLORS'], app.config['HTTPCLIENT'])
+ def _init_app(self, app):
+ return AVATAR_CLASSES[app.config.get("AVATAR_KIND", "Gravatar")](
+ app.config["PREFERRED_URL_SCHEME"],
+ app.config["AVATAR_COLORS"],
+ app.config["HTTPCLIENT"],
+ )
- def __getattr__(self, name):
- return getattr(self.state, name, None)
+ def __getattr__(self, name):
+ return getattr(self.state, name, None)
class BaseAvatar(object):
- """ Base class for all avatar implementations. """
- def __init__(self, preferred_url_scheme, colors, http_client):
- self.preferred_url_scheme = preferred_url_scheme
- self.colors = colors
- self.http_client = http_client
+ """ Base class for all avatar implementations. """
- def get_mail_html(self, name, email_or_id, size=16, kind='user'):
- """ Returns the full HTML and CSS for viewing the avatar of the given name and email address,
+ def __init__(self, preferred_url_scheme, colors, http_client):
+ self.preferred_url_scheme = preferred_url_scheme
+ self.colors = colors
+ self.http_client = http_client
+
+ def get_mail_html(self, name, email_or_id, size=16, kind="user"):
+ """ Returns the full HTML and CSS for viewing the avatar of the given name and email address,
with an optional size.
"""
- data = self.get_data(name, email_or_id, kind)
- url = self._get_url(data['hash'], size) if kind != 'team' else None
- font_size = size - 6
+ data = self.get_data(name, email_or_id, kind)
+ url = self._get_url(data["hash"], size) if kind != "team" else None
+ font_size = size - 6
- if url is not None:
- # Try to load the gravatar. If we get a non-404 response, then we use it in place of
- # the CSS avatar.
- try:
- response = self.http_client.get(url, timeout=5)
- if response.status_code == 200:
- return """
""" % (url, size, size, kind)
- except RequestException:
- logger.exception('Could not retrieve avatar for user %s', name)
+ if url is not None:
+ # Try to load the gravatar. If we get a non-404 response, then we use it in place of
+ # the CSS avatar.
+ try:
+ response = self.http_client.get(url, timeout=5)
+ if response.status_code == 200:
+ return """
""" % (
+ url,
+ size,
+ size,
+ kind,
+ )
+ except RequestException:
+ logger.exception("Could not retrieve avatar for user %s", name)
- radius = '50%' if kind == 'team' else '0%'
- letter = 'Ω' if kind == 'team' and data['name'] == 'owners' else data['name'].upper()[0]
+ radius = "50%" if kind == "team" else "0%"
+ letter = (
+ "Ω"
+ if kind == "team" and data["name"] == "owners"
+ else data["name"].upper()[0]
+ )
- return """
+ return """
%s
-""" % (size, size, data['color'], font_size, size, radius, letter)
+""" % (
+ size,
+ size,
+ data["color"],
+ font_size,
+ size,
+ radius,
+ letter,
+ )
- def get_data_for_user(self, user):
- return self.get_data(user.username, user.email, 'robot' if user.robot else 'user')
+ def get_data_for_user(self, user):
+ return self.get_data(
+ user.username, user.email, "robot" if user.robot else "user"
+ )
- def get_data_for_team(self, team):
- return self.get_data(team.name, team.name, 'team')
+ def get_data_for_team(self, team):
+ return self.get_data(team.name, team.name, "team")
- def get_data_for_org(self, org):
- return self.get_data(org.username, org.email, 'org')
+ def get_data_for_org(self, org):
+ return self.get_data(org.username, org.email, "org")
- def get_data_for_external_user(self, external_user):
- return self.get_data(external_user.username, external_user.email, 'user')
+ def get_data_for_external_user(self, external_user):
+ return self.get_data(external_user.username, external_user.email, "user")
- def get_data(self, name, email_or_id, kind='user'):
- """ Computes and returns the full data block for the avatar:
+ def get_data(self, name, email_or_id, kind="user"):
+ """ Computes and returns the full data block for the avatar:
{
'name': name,
'hash': The gravatar hash, if any.
'color': The color for the avatar
}
"""
- colors = self.colors
+ colors = self.colors
- # Note: email_or_id may be None if gotten from external auth when email is disabled,
- # so use the username in that case.
- username_email_or_id = email_or_id or name
- hash_value = hashlib.md5(username_email_or_id.strip().lower()).hexdigest()
+        # Note: email_or_id may be None when it comes from external auth with email disabled,
+ # so use the username in that case.
+ username_email_or_id = email_or_id or name
+ hash_value = hashlib.md5(username_email_or_id.strip().lower()).hexdigest()
- byte_count = int(math.ceil(math.log(len(colors), 16)))
- byte_data = hash_value[0:byte_count]
- hash_color = colors[int(byte_data, 16) % len(colors)]
+ byte_count = int(math.ceil(math.log(len(colors), 16)))
+ byte_data = hash_value[0:byte_count]
+ hash_color = colors[int(byte_data, 16) % len(colors)]
- return {
- 'name': name,
- 'hash': hash_value,
- 'color': hash_color,
- 'kind': kind
- }
+ return {"name": name, "hash": hash_value, "color": hash_color, "kind": kind}
- def _get_url(self, hash_value, size):
- """ Returns the URL for displaying the overlay avatar. """
- return None
+ def _get_url(self, hash_value, size):
+ """ Returns the URL for displaying the overlay avatar. """
+ return None
class GravatarAvatar(BaseAvatar):
- """ Avatar system that uses gravatar for generating avatars. """
- def _get_url(self, hash_value, size=16):
- return '%s://www.gravatar.com/avatar/%s?d=404&size=%s' % (self.preferred_url_scheme,
- hash_value, size)
+ """ Avatar system that uses gravatar for generating avatars. """
+
+ def _get_url(self, hash_value, size=16):
+ return "%s://www.gravatar.com/avatar/%s?d=404&size=%s" % (
+ self.preferred_url_scheme,
+ hash_value,
+ size,
+ )
+
class LocalAvatar(BaseAvatar):
- """ Avatar system that uses the local system for generating avatars. """
- pass
+ """ Avatar system that uses the local system for generating avatars. """
-AVATAR_CLASSES = {
- 'gravatar': GravatarAvatar,
- 'local': LocalAvatar
-}
+ pass
+
+
+AVATAR_CLASSES = {"gravatar": GravatarAvatar, "local": LocalAvatar}
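The hashing logic in BaseAvatar.get_data above is the part most obscured by the re-indentation: the MD5 of the lowercased email (or name) both feeds the gravatar URL and deterministically picks a fallback color. Below is a standalone sketch of that color selection, with a hypothetical name, email, and color list, and with .encode() added because this snippet targets Python 3 while the patched code is Python 2.

import hashlib
import math

def pick_avatar_color(name, email_or_id, colors):
    # Fall back to the name when no email/id is available, as get_data does.
    source = (email_or_id or name).strip().lower()
    hash_value = hashlib.md5(source.encode("utf-8")).hexdigest()
    # Take just enough leading hex digits to cover the color list, then wrap with modulo.
    digit_count = int(math.ceil(math.log(len(colors), 16)))
    return colors[int(hash_value[:digit_count], 16) % len(colors)]

print(pick_avatar_color("joeuser", "joeuser@example.com", ["#4caf50", "#2196f3", "#ff9800"]))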
diff --git a/boot.py b/boot.py
index 228fb2987..d9c906ab5 100755
--- a/boot.py
+++ b/boot.py
@@ -16,7 +16,7 @@ from data.model.release import set_region_release
from data.model.service_keys import get_service_key
from util.config.database import sync_database_with_config
from util.generatepresharedkey import generate_key
-from _init import CONF_DIR
+from _init import CONF_DIR
logger = logging.getLogger(__name__)
@@ -24,108 +24,117 @@ logger = logging.getLogger(__name__)
@lru_cache(maxsize=1)
def get_audience():
- audience = app.config.get('JWTPROXY_AUDIENCE')
+ audience = app.config.get("JWTPROXY_AUDIENCE")
- if audience:
- return audience
+ if audience:
+ return audience
- scheme = app.config.get('PREFERRED_URL_SCHEME')
- hostname = app.config.get('SERVER_HOSTNAME')
+ scheme = app.config.get("PREFERRED_URL_SCHEME")
+ hostname = app.config.get("SERVER_HOSTNAME")
- # hostname includes port, use that
- if ':' in hostname:
- return urlunparse((scheme, hostname, '', '', '', ''))
+ # hostname includes port, use that
+ if ":" in hostname:
+ return urlunparse((scheme, hostname, "", "", "", ""))
- # no port, guess based on scheme
- if scheme == 'https':
- port = '443'
- else:
- port = '80'
+ # no port, guess based on scheme
+ if scheme == "https":
+ port = "443"
+ else:
+ port = "80"
- return urlunparse((scheme, hostname + ':' + port, '', '', '', ''))
+ return urlunparse((scheme, hostname + ":" + port, "", "", "", ""))
def _verify_service_key():
- try:
- with open(app.config['INSTANCE_SERVICE_KEY_KID_LOCATION']) as f:
- quay_key_id = f.read()
-
try:
- get_service_key(quay_key_id, approved_only=False)
- assert os.path.exists(app.config['INSTANCE_SERVICE_KEY_LOCATION'])
- return quay_key_id
- except ServiceKeyDoesNotExist:
- logger.exception('Could not find non-expired existing service key %s; creating a new one',
- quay_key_id)
- return None
+ with open(app.config["INSTANCE_SERVICE_KEY_KID_LOCATION"]) as f:
+ quay_key_id = f.read()
- # Found a valid service key, so exiting.
- except IOError:
- logger.exception('Could not load existing service key; creating a new one')
- return None
+ try:
+ get_service_key(quay_key_id, approved_only=False)
+ assert os.path.exists(app.config["INSTANCE_SERVICE_KEY_LOCATION"])
+ return quay_key_id
+ except ServiceKeyDoesNotExist:
+ logger.exception(
+ "Could not find non-expired existing service key %s; creating a new one",
+ quay_key_id,
+ )
+ return None
+
+ # Found a valid service key, so exiting.
+ except IOError:
+ logger.exception("Could not load existing service key; creating a new one")
+ return None
def setup_jwt_proxy():
- """
+ """
Creates a service key for quay to use in the jwtproxy and generates the JWT proxy configuration.
"""
- if os.path.exists(os.path.join(CONF_DIR, 'jwtproxy_conf.yaml')):
- # Proxy is already setup. Make sure the service key is still valid.
- quay_key_id = _verify_service_key()
- if quay_key_id is not None:
- return
+ if os.path.exists(os.path.join(CONF_DIR, "jwtproxy_conf.yaml")):
+        # Proxy is already set up. Make sure the service key is still valid.
+ quay_key_id = _verify_service_key()
+ if quay_key_id is not None:
+ return
- # Ensure we have an existing key if in read-only mode.
- if app.config.get('REGISTRY_STATE', 'normal') == 'readonly':
- quay_key_id = _verify_service_key()
- if quay_key_id is None:
- raise Exception('No valid service key found for read-only registry.')
- else:
- # Generate the key for this Quay instance to use.
- minutes_until_expiration = app.config.get('INSTANCE_SERVICE_KEY_EXPIRATION', 120)
- expiration = datetime.now() + timedelta(minutes=minutes_until_expiration)
- quay_key, quay_key_id = generate_key(app.config['INSTANCE_SERVICE_KEY_SERVICE'],
- get_audience(), expiration_date=expiration)
+ # Ensure we have an existing key if in read-only mode.
+ if app.config.get("REGISTRY_STATE", "normal") == "readonly":
+ quay_key_id = _verify_service_key()
+ if quay_key_id is None:
+ raise Exception("No valid service key found for read-only registry.")
+ else:
+ # Generate the key for this Quay instance to use.
+ minutes_until_expiration = app.config.get(
+ "INSTANCE_SERVICE_KEY_EXPIRATION", 120
+ )
+ expiration = datetime.now() + timedelta(minutes=minutes_until_expiration)
+ quay_key, quay_key_id = generate_key(
+ app.config["INSTANCE_SERVICE_KEY_SERVICE"],
+ get_audience(),
+ expiration_date=expiration,
+ )
- with open(app.config['INSTANCE_SERVICE_KEY_KID_LOCATION'], mode='w') as f:
- f.truncate(0)
- f.write(quay_key_id)
+ with open(app.config["INSTANCE_SERVICE_KEY_KID_LOCATION"], mode="w") as f:
+ f.truncate(0)
+ f.write(quay_key_id)
- with open(app.config['INSTANCE_SERVICE_KEY_LOCATION'], mode='w') as f:
- f.truncate(0)
- f.write(quay_key.exportKey())
+ with open(app.config["INSTANCE_SERVICE_KEY_LOCATION"], mode="w") as f:
+ f.truncate(0)
+ f.write(quay_key.exportKey())
- # Generate the JWT proxy configuration.
- audience = get_audience()
- registry = audience + '/keys'
- security_issuer = app.config.get('SECURITY_SCANNER_ISSUER_NAME', 'security_scanner')
+ # Generate the JWT proxy configuration.
+ audience = get_audience()
+ registry = audience + "/keys"
+ security_issuer = app.config.get("SECURITY_SCANNER_ISSUER_NAME", "security_scanner")
- with open(os.path.join(CONF_DIR, 'jwtproxy_conf.yaml.jnj')) as f:
- template = Template(f.read())
- rendered = template.render(
- conf_dir=CONF_DIR,
- audience=audience,
- registry=registry,
- key_id=quay_key_id,
- security_issuer=security_issuer,
- service_key_location=app.config['INSTANCE_SERVICE_KEY_LOCATION'],
- )
+ with open(os.path.join(CONF_DIR, "jwtproxy_conf.yaml.jnj")) as f:
+ template = Template(f.read())
+ rendered = template.render(
+ conf_dir=CONF_DIR,
+ audience=audience,
+ registry=registry,
+ key_id=quay_key_id,
+ security_issuer=security_issuer,
+ service_key_location=app.config["INSTANCE_SERVICE_KEY_LOCATION"],
+ )
- with open(os.path.join(CONF_DIR, 'jwtproxy_conf.yaml'), 'w') as f:
- f.write(rendered)
+ with open(os.path.join(CONF_DIR, "jwtproxy_conf.yaml"), "w") as f:
+ f.write(rendered)
def main():
- if not app.config.get('SETUP_COMPLETE', False):
- raise Exception('Your configuration bundle is either not mounted or setup has not been completed')
+ if not app.config.get("SETUP_COMPLETE", False):
+ raise Exception(
+ "Your configuration bundle is either not mounted or setup has not been completed"
+ )
- sync_database_with_config(app.config)
- setup_jwt_proxy()
+ sync_database_with_config(app.config)
+ setup_jwt_proxy()
- # Record deploy
- if release.REGION and release.GIT_HEAD:
- set_region_release(release.SERVICE, release.REGION, release.GIT_HEAD)
+ # Record deploy
+ if release.REGION and release.GIT_HEAD:
+ set_region_release(release.SERVICE, release.REGION, release.GIT_HEAD)
-if __name__ == '__main__':
- main()
+if __name__ == "__main__":
+ main()
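The only non-mechanical piece of boot.py above is get_audience, which derives the JWT audience URL from the configured scheme and hostname and infers a default port when the hostname does not carry one. A minimal sketch of that rule follows; the hostnames are hypothetical, and the patched Python 2 code imports urlunparse from urlparse instead.

from urllib.parse import urlunparse

def audience(scheme, hostname):
    # Hostname already carries an explicit port: use it verbatim.
    if ":" in hostname:
        return urlunparse((scheme, hostname, "", "", "", ""))
    # Otherwise infer the default port from the scheme.
    port = "443" if scheme == "https" else "80"
    return urlunparse((scheme, hostname + ":" + port, "", "", "", ""))

assert audience("https", "quay.example.com") == "https://quay.example.com:443"
assert audience("http", "quay.example.com:8080") == "http://quay.example.com:8080"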
diff --git a/buildman/asyncutil.py b/buildman/asyncutil.py
index accb13542..f913072c4 100644
--- a/buildman/asyncutil.py
+++ b/buildman/asyncutil.py
@@ -5,38 +5,39 @@ from trollius import get_event_loop, coroutine
def wrap_with_threadpool(obj, worker_threads=1):
- """
+ """
Wraps a class in an async executor so that it can be safely used in an event loop like trollius.
"""
- async_executor = ThreadPoolExecutor(worker_threads)
- return AsyncWrapper(obj, executor=async_executor), async_executor
+ async_executor = ThreadPoolExecutor(worker_threads)
+ return AsyncWrapper(obj, executor=async_executor), async_executor
class AsyncWrapper(object):
- """ Wrapper class which will transform a syncronous library to one that can be used with
+    """ Wrapper class which will transform a synchronous library to one that can be used with
trollius coroutines.
"""
- def __init__(self, delegate, loop=None, executor=None):
- self._loop = loop if loop is not None else get_event_loop()
- self._delegate = delegate
- self._executor = executor
- def __getattr__(self, attrib):
- delegate_attr = getattr(self._delegate, attrib)
+ def __init__(self, delegate, loop=None, executor=None):
+ self._loop = loop if loop is not None else get_event_loop()
+ self._delegate = delegate
+ self._executor = executor
- if not callable(delegate_attr):
- return delegate_attr
+ def __getattr__(self, attrib):
+ delegate_attr = getattr(self._delegate, attrib)
- def wrapper(*args, **kwargs):
- """ Wraps the delegate_attr with primitives that will transform sync calls to ones shelled
+ if not callable(delegate_attr):
+ return delegate_attr
+
+ def wrapper(*args, **kwargs):
+ """ Wraps the delegate_attr with primitives that will transform sync calls to ones shelled
out to a thread pool.
"""
- callable_delegate_attr = partial(delegate_attr, *args, **kwargs)
- return self._loop.run_in_executor(self._executor, callable_delegate_attr)
+ callable_delegate_attr = partial(delegate_attr, *args, **kwargs)
+ return self._loop.run_in_executor(self._executor, callable_delegate_attr)
- return wrapper
+ return wrapper
- @coroutine
- def __call__(self, *args, **kwargs):
- callable_delegate_attr = partial(self._delegate, *args, **kwargs)
- return self._loop.run_in_executor(self._executor, callable_delegate_attr)
+ @coroutine
+ def __call__(self, *args, **kwargs):
+ callable_delegate_attr = partial(self._delegate, *args, **kwargs)
+ return self._loop.run_in_executor(self._executor, callable_delegate_attr)
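AsyncWrapper above turns every callable attribute of a blocking object into a call shipped to a thread pool and awaited on the event loop. A minimal sketch of the same idea, written against modern asyncio rather than trollius; BlockingClient is a hypothetical stand-in for the wrapped library, not code from this repository.

import asyncio
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial

class BlockingClient:
    def fetch(self, name):
        time.sleep(0.1)  # stands in for a blocking network or database call
        return "result for %s" % name

async def main():
    executor = ThreadPoolExecutor(1)
    client = BlockingClient()
    loop = asyncio.get_running_loop()
    # AsyncWrapper.__getattr__ does essentially this for every callable attribute:
    result = await loop.run_in_executor(executor, partial(client.fetch, "somens/somerepo"))
    print(result)

asyncio.run(main())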
diff --git a/buildman/builder.py b/buildman/builder.py
index 0261c262d..8c31da891 100644
--- a/buildman/builder.py
+++ b/buildman/builder.py
@@ -18,80 +18,104 @@ from raven.conf import setup_logging
logger = logging.getLogger(__name__)
-BUILD_MANAGERS = {
- 'enterprise': EnterpriseManager,
- 'ephemeral': EphemeralBuilderManager,
-}
+BUILD_MANAGERS = {"enterprise": EnterpriseManager, "ephemeral": EphemeralBuilderManager}
-EXTERNALLY_MANAGED = 'external'
+EXTERNALLY_MANAGED = "external"
DEFAULT_WEBSOCKET_PORT = 8787
DEFAULT_CONTROLLER_PORT = 8686
LOG_FORMAT = "%(asctime)s [%(process)d] [%(levelname)s] [%(name)s] %(message)s"
+
def run_build_manager():
- if not features.BUILD_SUPPORT:
- logger.debug('Building is disabled. Please enable the feature flag')
- while True:
- time.sleep(1000)
- return
+ if not features.BUILD_SUPPORT:
+ logger.debug("Building is disabled. Please enable the feature flag")
+ while True:
+ time.sleep(1000)
+ return
- if app.config.get('REGISTRY_STATE', 'normal') == 'readonly':
- logger.debug('Building is disabled while in read-only mode.')
- while True:
- time.sleep(1000)
- return
+ if app.config.get("REGISTRY_STATE", "normal") == "readonly":
+ logger.debug("Building is disabled while in read-only mode.")
+ while True:
+ time.sleep(1000)
+ return
- build_manager_config = app.config.get('BUILD_MANAGER')
- if build_manager_config is None:
- return
+ build_manager_config = app.config.get("BUILD_MANAGER")
+ if build_manager_config is None:
+ return
- # If the build system is externally managed, then we just sleep this process.
- if build_manager_config[0] == EXTERNALLY_MANAGED:
- logger.debug('Builds are externally managed.')
- while True:
- time.sleep(1000)
- return
+ # If the build system is externally managed, then we just sleep this process.
+ if build_manager_config[0] == EXTERNALLY_MANAGED:
+ logger.debug("Builds are externally managed.")
+ while True:
+ time.sleep(1000)
+ return
- logger.debug('Asking to start build manager with lifecycle "%s"', build_manager_config[0])
- manager_klass = BUILD_MANAGERS.get(build_manager_config[0])
- if manager_klass is None:
- return
+ logger.debug(
+ 'Asking to start build manager with lifecycle "%s"', build_manager_config[0]
+ )
+ manager_klass = BUILD_MANAGERS.get(build_manager_config[0])
+ if manager_klass is None:
+ return
- manager_hostname = os.environ.get('BUILDMAN_HOSTNAME',
- app.config.get('BUILDMAN_HOSTNAME',
- app.config['SERVER_HOSTNAME']))
- websocket_port = int(os.environ.get('BUILDMAN_WEBSOCKET_PORT',
- app.config.get('BUILDMAN_WEBSOCKET_PORT',
- DEFAULT_WEBSOCKET_PORT)))
- controller_port = int(os.environ.get('BUILDMAN_CONTROLLER_PORT',
- app.config.get('BUILDMAN_CONTROLLER_PORT',
- DEFAULT_CONTROLLER_PORT)))
+ manager_hostname = os.environ.get(
+ "BUILDMAN_HOSTNAME",
+ app.config.get("BUILDMAN_HOSTNAME", app.config["SERVER_HOSTNAME"]),
+ )
+ websocket_port = int(
+ os.environ.get(
+ "BUILDMAN_WEBSOCKET_PORT",
+ app.config.get("BUILDMAN_WEBSOCKET_PORT", DEFAULT_WEBSOCKET_PORT),
+ )
+ )
+ controller_port = int(
+ os.environ.get(
+ "BUILDMAN_CONTROLLER_PORT",
+ app.config.get("BUILDMAN_CONTROLLER_PORT", DEFAULT_CONTROLLER_PORT),
+ )
+ )
- logger.debug('Will pass buildman hostname %s to builders for websocket connection',
- manager_hostname)
+ logger.debug(
+ "Will pass buildman hostname %s to builders for websocket connection",
+ manager_hostname,
+ )
- logger.debug('Starting build manager with lifecycle "%s"', build_manager_config[0])
- ssl_context = None
- if os.environ.get('SSL_CONFIG'):
- logger.debug('Loading SSL cert and key')
- ssl_context = SSLContext()
- ssl_context.load_cert_chain(os.path.join(os.environ.get('SSL_CONFIG'), 'ssl.cert'),
- os.path.join(os.environ.get('SSL_CONFIG'), 'ssl.key'))
+ logger.debug('Starting build manager with lifecycle "%s"', build_manager_config[0])
+ ssl_context = None
+ if os.environ.get("SSL_CONFIG"):
+ logger.debug("Loading SSL cert and key")
+ ssl_context = SSLContext()
+ ssl_context.load_cert_chain(
+ os.path.join(os.environ.get("SSL_CONFIG"), "ssl.cert"),
+ os.path.join(os.environ.get("SSL_CONFIG"), "ssl.key"),
+ )
- server = BuilderServer(app.config['SERVER_HOSTNAME'], dockerfile_build_queue, build_logs,
- user_files, manager_klass, build_manager_config[1], manager_hostname)
- server.run('0.0.0.0', websocket_port, controller_port, ssl=ssl_context)
+ server = BuilderServer(
+ app.config["SERVER_HOSTNAME"],
+ dockerfile_build_queue,
+ build_logs,
+ user_files,
+ manager_klass,
+ build_manager_config[1],
+ manager_hostname,
+ )
+ server.run("0.0.0.0", websocket_port, controller_port, ssl=ssl_context)
-if __name__ == '__main__':
- logging.config.fileConfig(logfile_path(debug=True), disable_existing_loggers=False)
- logging.getLogger('peewee').setLevel(logging.WARN)
- logging.getLogger('boto').setLevel(logging.WARN)
- if app.config.get('EXCEPTION_LOG_TYPE', 'FakeSentry') == 'Sentry':
- buildman_name = '%s:buildman' % socket.gethostname()
- setup_logging(SentryHandler(app.config.get('SENTRY_DSN', ''), name=buildman_name,
- level=logging.ERROR))
+if __name__ == "__main__":
+ logging.config.fileConfig(logfile_path(debug=True), disable_existing_loggers=False)
+ logging.getLogger("peewee").setLevel(logging.WARN)
+ logging.getLogger("boto").setLevel(logging.WARN)
- run_build_manager()
+ if app.config.get("EXCEPTION_LOG_TYPE", "FakeSentry") == "Sentry":
+ buildman_name = "%s:buildman" % socket.gethostname()
+ setup_logging(
+ SentryHandler(
+ app.config.get("SENTRY_DSN", ""),
+ name=buildman_name,
+ level=logging.ERROR,
+ )
+ )
+
+ run_build_manager()
diff --git a/buildman/component/basecomponent.py b/buildman/component/basecomponent.py
index bd4032776..8806b5629 100644
--- a/buildman/component/basecomponent.py
+++ b/buildman/component/basecomponent.py
@@ -1,13 +1,15 @@
from autobahn.asyncio.wamp import ApplicationSession
-class BaseComponent(ApplicationSession):
- """ Base class for all registered component sessions in the server. """
- def __init__(self, config, **kwargs):
- ApplicationSession.__init__(self, config)
- self.server = None
- self.parent_manager = None
- self.build_logs = None
- self.user_files = None
- def kind(self):
- raise NotImplementedError
\ No newline at end of file
+class BaseComponent(ApplicationSession):
+ """ Base class for all registered component sessions in the server. """
+
+ def __init__(self, config, **kwargs):
+ ApplicationSession.__init__(self, config)
+ self.server = None
+ self.parent_manager = None
+ self.build_logs = None
+ self.user_files = None
+
+ def kind(self):
+ raise NotImplementedError
diff --git a/buildman/component/buildcomponent.py b/buildman/component/buildcomponent.py
index 62c64e6b8..9e7b4946a 100644
--- a/buildman/component/buildcomponent.py
+++ b/buildman/component/buildcomponent.py
@@ -27,513 +27,632 @@ BUILD_HEARTBEAT_DELAY = datetime.timedelta(seconds=30)
HEARTBEAT_TIMEOUT = 10
INITIAL_TIMEOUT = 25
-SUPPORTED_WORKER_VERSIONS = ['0.3']
+SUPPORTED_WORKER_VERSIONS = ["0.3"]
# Label which marks a manifest with its source build ID.
-INTERNAL_LABEL_BUILD_UUID = 'quay.build.uuid'
+INTERNAL_LABEL_BUILD_UUID = "quay.build.uuid"
logger = logging.getLogger(__name__)
+
class ComponentStatus(object):
- """ ComponentStatus represents the possible states of a component. """
- JOINING = 'joining'
- WAITING = 'waiting'
- RUNNING = 'running'
- BUILDING = 'building'
- TIMED_OUT = 'timeout'
+ """ ComponentStatus represents the possible states of a component. """
+
+ JOINING = "joining"
+ WAITING = "waiting"
+ RUNNING = "running"
+ BUILDING = "building"
+ TIMED_OUT = "timeout"
+
class BuildComponent(BaseComponent):
- """ An application session component which conducts one (or more) builds. """
- def __init__(self, config, realm=None, token=None, **kwargs):
- self.expected_token = token
- self.builder_realm = realm
+ """ An application session component which conducts one (or more) builds. """
- self.parent_manager = None
- self.registry_hostname = None
+ def __init__(self, config, realm=None, token=None, **kwargs):
+ self.expected_token = token
+ self.builder_realm = realm
- self._component_status = ComponentStatus.JOINING
- self._last_heartbeat = None
- self._current_job = None
- self._build_status = None
- self._image_info = None
- self._worker_version = None
+ self.parent_manager = None
+ self.registry_hostname = None
- BaseComponent.__init__(self, config, **kwargs)
+ self._component_status = ComponentStatus.JOINING
+ self._last_heartbeat = None
+ self._current_job = None
+ self._build_status = None
+ self._image_info = None
+ self._worker_version = None
- def kind(self):
- return 'builder'
+ BaseComponent.__init__(self, config, **kwargs)
- def onConnect(self):
- self.join(self.builder_realm)
+ def kind(self):
+ return "builder"
- @trollius.coroutine
- def onJoin(self, details):
- logger.debug('Registering methods and listeners for component %s', self.builder_realm)
- yield From(self.register(self._on_ready, u'io.quay.buildworker.ready'))
- yield From(self.register(self._determine_cache_tag, u'io.quay.buildworker.determinecachetag'))
- yield From(self.register(self._ping, u'io.quay.buildworker.ping'))
- yield From(self.register(self._on_log_message, u'io.quay.builder.logmessagesynchronously'))
+ def onConnect(self):
+ self.join(self.builder_realm)
- yield From(self.subscribe(self._on_heartbeat, u'io.quay.builder.heartbeat'))
+ @trollius.coroutine
+ def onJoin(self, details):
+ logger.debug(
+ "Registering methods and listeners for component %s", self.builder_realm
+ )
+ yield From(self.register(self._on_ready, u"io.quay.buildworker.ready"))
+ yield From(
+ self.register(
+ self._determine_cache_tag, u"io.quay.buildworker.determinecachetag"
+ )
+ )
+ yield From(self.register(self._ping, u"io.quay.buildworker.ping"))
+ yield From(
+ self.register(
+ self._on_log_message, u"io.quay.builder.logmessagesynchronously"
+ )
+ )
- yield From(self._set_status(ComponentStatus.WAITING))
+ yield From(self.subscribe(self._on_heartbeat, u"io.quay.builder.heartbeat"))
- @trollius.coroutine
- def start_build(self, build_job):
- """ Starts a build. """
- if self._component_status not in (ComponentStatus.WAITING, ComponentStatus.RUNNING):
- logger.debug('Could not start build for component %s (build %s, worker version: %s): %s',
- self.builder_realm, build_job.repo_build.uuid, self._worker_version,
- self._component_status)
- raise Return()
+ yield From(self._set_status(ComponentStatus.WAITING))
- logger.debug('Starting build for component %s (build %s, worker version: %s)',
- self.builder_realm, build_job.repo_build.uuid, self._worker_version)
+ @trollius.coroutine
+ def start_build(self, build_job):
+ """ Starts a build. """
+ if self._component_status not in (
+ ComponentStatus.WAITING,
+ ComponentStatus.RUNNING,
+ ):
+ logger.debug(
+ "Could not start build for component %s (build %s, worker version: %s): %s",
+ self.builder_realm,
+ build_job.repo_build.uuid,
+ self._worker_version,
+ self._component_status,
+ )
+ raise Return()
- self._current_job = build_job
- self._build_status = StatusHandler(self.build_logs, build_job.repo_build.uuid)
- self._image_info = {}
+ logger.debug(
+ "Starting build for component %s (build %s, worker version: %s)",
+ self.builder_realm,
+ build_job.repo_build.uuid,
+ self._worker_version,
+ )
- yield From(self._set_status(ComponentStatus.BUILDING))
+ self._current_job = build_job
+ self._build_status = StatusHandler(self.build_logs, build_job.repo_build.uuid)
+ self._image_info = {}
- # Send the notification that the build has started.
- build_job.send_notification('build_start')
+ yield From(self._set_status(ComponentStatus.BUILDING))
- # Parse the build configuration.
- try:
- build_config = build_job.build_config
- except BuildJobLoadException as irbe:
- yield From(self._build_failure('Could not load build job information', irbe))
- raise Return()
+ # Send the notification that the build has started.
+ build_job.send_notification("build_start")
- base_image_information = {}
+ # Parse the build configuration.
+ try:
+ build_config = build_job.build_config
+ except BuildJobLoadException as irbe:
+ yield From(
+ self._build_failure("Could not load build job information", irbe)
+ )
+ raise Return()
- # Add the pull robot information, if any.
- if build_job.pull_credentials:
- base_image_information['username'] = build_job.pull_credentials.get('username', '')
- base_image_information['password'] = build_job.pull_credentials.get('password', '')
+ base_image_information = {}
- # Retrieve the repository's fully qualified name.
- repo = build_job.repo_build.repository
- repository_name = repo.namespace_user.username + '/' + repo.name
+ # Add the pull robot information, if any.
+ if build_job.pull_credentials:
+ base_image_information["username"] = build_job.pull_credentials.get(
+ "username", ""
+ )
+ base_image_information["password"] = build_job.pull_credentials.get(
+ "password", ""
+ )
- # Parse the build queue item into build arguments.
- # build_package: URL to the build package to download and untar/unzip.
- # defaults to empty string to avoid requiring a pointer on the builder.
- # sub_directory: The location within the build package of the Dockerfile and the build context.
- # repository: The repository for which this build is occurring.
- # registry: The registry for which this build is occuring (e.g. 'quay.io').
- # pull_token: The token to use when pulling the cache for building.
- # push_token: The token to use to push the built image.
- # tag_names: The name(s) of the tag(s) for the newly built image.
- # base_image: The image name and credentials to use to conduct the base image pull.
- # username: The username for pulling the base image (if any).
- # password: The password for pulling the base image (if any).
- context, dockerfile_path = self.extract_dockerfile_args(build_config)
- build_arguments = {
- 'build_package': build_job.get_build_package_url(self.user_files),
- 'context': context,
- 'dockerfile_path': dockerfile_path,
- 'repository': repository_name,
- 'registry': self.registry_hostname,
- 'pull_token': build_job.repo_build.access_token.get_code(),
- 'push_token': build_job.repo_build.access_token.get_code(),
- 'tag_names': build_config.get('docker_tags', ['latest']),
- 'base_image': base_image_information,
- }
+ # Retrieve the repository's fully qualified name.
+ repo = build_job.repo_build.repository
+ repository_name = repo.namespace_user.username + "/" + repo.name
- # If the trigger has a private key, it's using git, thus we should add
- # git data to the build args.
- # url: url used to clone the git repository
- # sha: the sha1 identifier of the commit to check out
- # private_key: the key used to get read access to the git repository
+ # Parse the build queue item into build arguments.
+ # build_package: URL to the build package to download and untar/unzip.
+ # defaults to empty string to avoid requiring a pointer on the builder.
+ # sub_directory: The location within the build package of the Dockerfile and the build context.
+ # repository: The repository for which this build is occurring.
+        #   registry: The registry for which this build is occurring (e.g. 'quay.io').
+ # pull_token: The token to use when pulling the cache for building.
+ # push_token: The token to use to push the built image.
+ # tag_names: The name(s) of the tag(s) for the newly built image.
+ # base_image: The image name and credentials to use to conduct the base image pull.
+ # username: The username for pulling the base image (if any).
+ # password: The password for pulling the base image (if any).
+ context, dockerfile_path = self.extract_dockerfile_args(build_config)
+ build_arguments = {
+ "build_package": build_job.get_build_package_url(self.user_files),
+ "context": context,
+ "dockerfile_path": dockerfile_path,
+ "repository": repository_name,
+ "registry": self.registry_hostname,
+ "pull_token": build_job.repo_build.access_token.get_code(),
+ "push_token": build_job.repo_build.access_token.get_code(),
+ "tag_names": build_config.get("docker_tags", ["latest"]),
+ "base_image": base_image_information,
+ }
- # TODO(remove-unenc): Remove legacy field.
- private_key = None
- if build_job.repo_build.trigger is not None and \
- build_job.repo_build.trigger.secure_private_key is not None:
- private_key = build_job.repo_build.trigger.secure_private_key.decrypt()
+ # If the trigger has a private key, it's using git, thus we should add
+ # git data to the build args.
+ # url: url used to clone the git repository
+ # sha: the sha1 identifier of the commit to check out
+ # private_key: the key used to get read access to the git repository
- if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS) and \
- private_key is None and \
- build_job.repo_build.trigger is not None:
- private_key = build_job.repo_build.trigger.private_key
+ # TODO(remove-unenc): Remove legacy field.
+ private_key = None
+ if (
+ build_job.repo_build.trigger is not None
+ and build_job.repo_build.trigger.secure_private_key is not None
+ ):
+ private_key = build_job.repo_build.trigger.secure_private_key.decrypt()
- if private_key is not None:
- build_arguments['git'] = {
- 'url': build_config['trigger_metadata'].get('git_url', ''),
- 'sha': BuildComponent._commit_sha(build_config),
- 'private_key': private_key or '',
- }
+ if (
+ ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS)
+ and private_key is None
+ and build_job.repo_build.trigger is not None
+ ):
+ private_key = build_job.repo_build.trigger.private_key
- # If the build args have no buildpack, mark it as a failure before sending
- # it to a builder instance.
- if not build_arguments['build_package'] and not build_arguments['git']:
- logger.error('%s: insufficient build args: %s',
- self._current_job.repo_build.uuid, build_arguments)
- yield From(self._build_failure('Insufficient build arguments. No buildpack available.'))
- raise Return()
+ if private_key is not None:
+ build_arguments["git"] = {
+ "url": build_config["trigger_metadata"].get("git_url", ""),
+ "sha": BuildComponent._commit_sha(build_config),
+ "private_key": private_key or "",
+ }
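+ # For example, a git-triggered build might produce (hypothetical values):
+ #   {"url": "git@github.com:someorg/somerepo.git",
+ #    "sha": "e83c5163316f89bfbde7d9ab23ca2e25604af290",
+ #    "private_key": "-----BEGIN RSA PRIVATE KEY-----..."}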
- # Invoke the build.
- logger.debug('Invoking build: %s', self.builder_realm)
- logger.debug('With Arguments: %s', build_arguments)
+ # If the build args have no buildpack, mark it as a failure before sending
+ # it to a builder instance.
+ if not build_arguments["build_package"] and not build_arguments.get("git"):
+ logger.error(
+ "%s: insufficient build args: %s",
+ self._current_job.repo_build.uuid,
+ build_arguments,
+ )
+ yield From(
+ self._build_failure(
+ "Insufficient build arguments. No buildpack available."
+ )
+ )
+ raise Return()
- def build_complete_callback(result):
- """ This function is used to execute a coroutine as the callback. """
- trollius.ensure_future(self._build_complete(result))
+ # Invoke the build.
+ logger.debug("Invoking build: %s", self.builder_realm)
+ logger.debug("With Arguments: %s", build_arguments)
- self.call("io.quay.builder.build", **build_arguments).add_done_callback(build_complete_callback)
+ def build_complete_callback(result):
+ """ This function is used to execute a coroutine as the callback. """
+ trollius.ensure_future(self._build_complete(result))
- # Set the heartbeat for the future. If the builder never receives the build call,
- # then this will cause a timeout after 30 seconds. We know the builder has registered
- # by this point, so it makes sense to have a timeout.
- self._last_heartbeat = datetime.datetime.utcnow() + BUILD_HEARTBEAT_DELAY
+ self.call("io.quay.builder.build", **build_arguments).add_done_callback(
+ build_complete_callback
+ )
- @staticmethod
- def extract_dockerfile_args(build_config):
- dockerfile_path = build_config.get('build_subdir', '')
- context = build_config.get('context', '')
- if not (dockerfile_path == '' or context == ''):
- # This should not happen and can be removed when we centralize validating build_config
- dockerfile_abspath = slash_join('', dockerfile_path)
- if ".." in os.path.relpath(dockerfile_abspath, context):
- return os.path.split(dockerfile_path)
- dockerfile_path = os.path.relpath(dockerfile_abspath, context)
+ # Set the heartbeat for the future. If the builder never receives the build call,
+ # then this will cause a timeout after 30 seconds. We know the builder has registered
+ # by this point, so it makes sense to have a timeout.
+ self._last_heartbeat = datetime.datetime.utcnow() + BUILD_HEARTBEAT_DELAY
- return context, dockerfile_path
+ @staticmethod
+ def extract_dockerfile_args(build_config):
+ dockerfile_path = build_config.get("build_subdir", "")
+ context = build_config.get("context", "")
+ if not (dockerfile_path == "" or context == ""):
+ # This should not happen and can be removed when we centralize validating build_config
+ dockerfile_abspath = slash_join("", dockerfile_path)
+ if ".." in os.path.relpath(dockerfile_abspath, context):
+ return os.path.split(dockerfile_path)
+ dockerfile_path = os.path.relpath(dockerfile_abspath, context)
- @staticmethod
- def _commit_sha(build_config):
- """ Determines whether the metadata is using an old schema or not and returns the commit. """
- commit_sha = build_config['trigger_metadata'].get('commit', '')
- old_commit_sha = build_config['trigger_metadata'].get('commit_sha', '')
- return commit_sha or old_commit_sha
+ return context, dockerfile_path
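+ # Worked example (mirrors the cases in test_buildcomponent.py): a build_config of
+ # {"context": "/builddir", "build_subdir": "/builddir/Dockerfile"} yields
+ # ("/builddir", "Dockerfile"), while {"build_subdir": "/builddir/Dockerfile"} alone
+ # yields ("", "/builddir/Dockerfile").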
- @staticmethod
- def name_and_path(subdir):
- """ Returns the dockerfile path and name """
- if subdir.endswith("/"):
- subdir += "Dockerfile"
- elif not subdir.endswith("Dockerfile"):
- subdir += "/Dockerfile"
- return os.path.split(subdir)
+ @staticmethod
+ def _commit_sha(build_config):
+ """ Determines whether the metadata is using an old schema or not and returns the commit. """
+ commit_sha = build_config["trigger_metadata"].get("commit", "")
+ old_commit_sha = build_config["trigger_metadata"].get("commit_sha", "")
+ return commit_sha or old_commit_sha
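+ # For example, newer trigger payloads populate trigger_metadata with {"commit": "<sha1>"},
+ # while older payloads used {"commit_sha": "<sha1>"}; either form resolves here.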
- @staticmethod
- def _total_completion(statuses, total_images):
- """ Returns the current amount completion relative to the total completion of a build. """
- percentage_with_sizes = float(len(statuses.values())) / total_images
- sent_bytes = sum([status['current'] for status in statuses.values()])
- total_bytes = sum([status['total'] for status in statuses.values()])
- return float(sent_bytes) / total_bytes * percentage_with_sizes
+ @staticmethod
+ def name_and_path(subdir):
+ """ Returns the dockerfile path and name """
+ if subdir.endswith("/"):
+ subdir += "Dockerfile"
+ elif not subdir.endswith("Dockerfile"):
+ subdir += "/Dockerfile"
+ return os.path.split(subdir)
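+ # Worked examples (see test_buildcomponent.py): "/somepath/" and "/somepath" both
+ # become ("/somepath", "Dockerfile"), and "/somepath/server.Dockerfile" becomes
+ # ("/somepath", "server.Dockerfile").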
- @staticmethod
- def _process_pushpull_status(status_dict, current_phase, docker_data, images):
- """ Processes the status of a push or pull by updating the provided status_dict and images. """
- if not docker_data:
- return
+ @staticmethod
+ def _total_completion(statuses, total_images):
+ """ Returns the current amount completion relative to the total completion of a build. """
+ percentage_with_sizes = float(len(statuses.values())) / total_images
+ sent_bytes = sum([status["current"] for status in statuses.values()])
+ total_bytes = sum([status["total"] for status in statuses.values()])
+ return float(sent_bytes) / total_bytes * percentage_with_sizes
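+ # Worked example (hypothetical sizes): with total_images=4 and two reporting layers
+ # {"a": {"current": 50, "total": 100}, "b": {"current": 100, "total": 100}}, this
+ # returns (150 / 200) * (2 / 4) = 0.375, i.e. 37.5% overall completion.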
- num_images = 0
- status_completion_key = ''
+ @staticmethod
+ def _process_pushpull_status(status_dict, current_phase, docker_data, images):
+ """ Processes the status of a push or pull by updating the provided status_dict and images. """
+ if not docker_data:
+ return
- if current_phase == 'pushing':
- status_completion_key = 'push_completion'
- num_images = status_dict['total_commands']
- elif current_phase == 'pulling':
- status_completion_key = 'pull_completion'
- elif current_phase == 'priming-cache':
- status_completion_key = 'cache_completion'
- else:
- return
+ num_images = 0
+ status_completion_key = ""
- if 'progressDetail' in docker_data and 'id' in docker_data:
- image_id = docker_data['id']
- detail = docker_data['progressDetail']
+ if current_phase == "pushing":
+ status_completion_key = "push_completion"
+ num_images = status_dict["total_commands"]
+ elif current_phase == "pulling":
+ status_completion_key = "pull_completion"
+ elif current_phase == "priming-cache":
+ status_completion_key = "cache_completion"
+ else:
+ return
- if 'current' in detail and 'total' in detail:
- images[image_id] = detail
- status_dict[status_completion_key] = \
- BuildComponent._total_completion(images, max(len(images), num_images))
+ if "progressDetail" in docker_data and "id" in docker_data:
+ image_id = docker_data["id"]
+ detail = docker_data["progressDetail"]
+ if "current" in detail and "total" in detail:
+ images[image_id] = detail
+ status_dict[status_completion_key] = BuildComponent._total_completion(
+ images, max(len(images), num_images)
+ )
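+ # The docker_data passed in follows the Docker daemon's progress-stream shape, e.g.
+ # (hypothetical values): {"id": "fe90a...", "status": "Pushing",
+ # "progressDetail": {"current": 512000, "total": 2048000}}; entries without both
+ # "current" and "total" are ignored above.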
- @trollius.coroutine
- def _on_log_message(self, phase, json_data):
- """ Tails log messages and updates the build status. """
- # Update the heartbeat.
- self._last_heartbeat = datetime.datetime.utcnow()
+ @trollius.coroutine
+ def _on_log_message(self, phase, json_data):
+ """ Tails log messages and updates the build status. """
+ # Update the heartbeat.
+ self._last_heartbeat = datetime.datetime.utcnow()
- # Parse any of the JSON data logged.
- log_data = {}
- if json_data:
- try:
- log_data = json.loads(json_data)
- except ValueError:
- pass
+ # Parse any of the JSON data logged.
+ log_data = {}
+ if json_data:
+ try:
+ log_data = json.loads(json_data)
+ except ValueError:
+ pass
- # Extract the current status message (if any).
- fully_unwrapped = ''
- keys_to_extract = ['error', 'status', 'stream']
- for key in keys_to_extract:
- if key in log_data:
- fully_unwrapped = log_data[key]
- break
+ # Extract the current status message (if any).
+ fully_unwrapped = ""
+ keys_to_extract = ["error", "status", "stream"]
+ for key in keys_to_extract:
+ if key in log_data:
+ fully_unwrapped = log_data[key]
+ break
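+ # A typical builder log line arrives as JSON such as {"stream": "Step 4/13 : RUN make"}
+ # during the build, or {"error": "..."} when a push or pull fails (illustrative shapes;
+ # actual payloads come from the remote builder).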
- # Determine if this is a step string.
- current_step = None
- current_status_string = str(fully_unwrapped.encode('utf-8'))
+ # Determine if this is a step string.
+ current_step = None
+ current_status_string = str(fully_unwrapped.encode("utf-8"))
- if current_status_string and phase == BUILD_PHASE.BUILDING:
- current_step = extract_current_step(current_status_string)
+ if current_status_string and phase == BUILD_PHASE.BUILDING:
+ current_step = extract_current_step(current_status_string)
+
+ # Parse and update the phase and the status_dict. The status dictionary contains
+ # the pull/push progress, as well as the current step index.
+ with self._build_status as status_dict:
+ try:
+ changed_phase = yield From(
+ self._build_status.set_phase(phase, log_data.get("status_data"))
+ )
+ if changed_phase:
+ logger.debug(
+ "Build %s has entered a new phase: %s",
+ self.builder_realm,
+ phase,
+ )
+ elif self._current_job.repo_build.phase == BUILD_PHASE.CANCELLED:
+ build_id = self._current_job.repo_build.uuid
+ logger.debug(
+ "Trying to move cancelled build into phase: %s with id: %s",
+ phase,
+ build_id,
+ )
+ raise Return(False)
+ except InvalidRepositoryBuildException:
+ build_id = self._current_job.repo_build.uuid
+ logger.warning(
+ "Build %s was not found; repo was probably deleted", build_id
+ )
+ raise Return(False)
+
+ BuildComponent._process_pushpull_status(
+ status_dict, phase, log_data, self._image_info
+ )
+
+ # If the current message represents the beginning of a new step, then update the
+ # current command index.
+ if current_step is not None:
+ status_dict["current_command"] = current_step
+
+ # If the json data contains an error, then something went wrong with a push or pull.
+ if "error" in log_data:
+ yield From(self._build_status.set_error(log_data["error"]))
+
+ if current_step is not None:
+ yield From(self._build_status.set_command(current_status_string))
+ elif phase == BUILD_PHASE.BUILDING:
+ yield From(self._build_status.append_log(current_status_string))
+ raise Return(True)
+
+ @trollius.coroutine
+ def _determine_cache_tag(
+ self, command_comments, base_image_name, base_image_tag, base_image_id
+ ):
+ with self._build_status as status_dict:
+ status_dict["total_commands"] = len(command_comments) + 1
+
+ logger.debug(
+ "Checking cache on realm %s. Base image: %s:%s (%s)",
+ self.builder_realm,
+ base_image_name,
+ base_image_tag,
+ base_image_id,
+ )
+
+ tag_found = self._current_job.determine_cached_tag(
+ base_image_id, command_comments
+ )
+ raise Return(tag_found or "")
+
+ @trollius.coroutine
+ def _build_failure(self, error_message, exception=None):
+ """ Handles and logs a failed build. """
+ yield From(
+ self._build_status.set_error(
+ error_message, {"internal_error": str(exception) if exception else None}
+ )
+ )
- # Parse and update the phase and the status_dict. The status dictionary contains
- # the pull/push progress, as well as the current step index.
- with self._build_status as status_dict:
- try:
- changed_phase = yield From(self._build_status.set_phase(phase, log_data.get('status_data')))
- if changed_phase:
- logger.debug('Build %s has entered a new phase: %s', self.builder_realm, phase)
- elif self._current_job.repo_build.phase == BUILD_PHASE.CANCELLED:
- build_id = self._current_job.repo_build.uuid
- logger.debug('Trying to move cancelled build into phase: %s with id: %s', phase, build_id)
- raise Return(False)
- except InvalidRepositoryBuildException:
build_id = self._current_job.repo_build.uuid
- logger.warning('Build %s was not found; repo was probably deleted', build_id)
- raise Return(False)
+ logger.warning("Build %s failed with message: %s", build_id, error_message)
- BuildComponent._process_pushpull_status(status_dict, phase, log_data, self._image_info)
-
- # If the current message represents the beginning of a new step, then update the
- # current command index.
- if current_step is not None:
- status_dict['current_command'] = current_step
-
- # If the json data contains an error, then something went wrong with a push or pull.
- if 'error' in log_data:
- yield From(self._build_status.set_error(log_data['error']))
-
- if current_step is not None:
- yield From(self._build_status.set_command(current_status_string))
- elif phase == BUILD_PHASE.BUILDING:
- yield From(self._build_status.append_log(current_status_string))
- raise Return(True)
-
- @trollius.coroutine
- def _determine_cache_tag(self, command_comments, base_image_name, base_image_tag, base_image_id):
- with self._build_status as status_dict:
- status_dict['total_commands'] = len(command_comments) + 1
-
- logger.debug('Checking cache on realm %s. Base image: %s:%s (%s)', self.builder_realm,
- base_image_name, base_image_tag, base_image_id)
-
- tag_found = self._current_job.determine_cached_tag(base_image_id, command_comments)
- raise Return(tag_found or '')
-
- @trollius.coroutine
- def _build_failure(self, error_message, exception=None):
- """ Handles and logs a failed build. """
- yield From(self._build_status.set_error(error_message, {
- 'internal_error': str(exception) if exception else None
- }))
-
- build_id = self._current_job.repo_build.uuid
- logger.warning('Build %s failed with message: %s', build_id, error_message)
-
- # Mark that the build has finished (in an error state)
- yield From(self._build_finished(BuildJobResult.ERROR))
-
- @trollius.coroutine
- def _build_complete(self, result):
- """ Wraps up a completed build. Handles any errors and calls self._build_finished. """
- build_id = self._current_job.repo_build.uuid
-
- try:
- # Retrieve the result. This will raise an ApplicationError on any error that occurred.
- result_value = result.result()
- kwargs = {}
-
- # Note: If we are hitting an older builder that didn't return ANY map data, then the result
- # value will be a bool instead of a proper CallResult object.
- # Therefore: we have a try-except guard here to ensure we don't hit this pitfall.
- try:
- kwargs = result_value.kwresults
- except:
- pass
-
- try:
- yield From(self._build_status.set_phase(BUILD_PHASE.COMPLETE))
- except InvalidRepositoryBuildException:
- logger.warning('Build %s was not found; repo was probably deleted', build_id)
- raise Return()
-
- yield From(self._build_finished(BuildJobResult.COMPLETE))
-
- # Label the pushed manifests with the build metadata.
- manifest_digests = kwargs.get('digests') or []
- repository = registry_model.lookup_repository(self._current_job.namespace,
- self._current_job.repo_name)
- if repository is not None:
- for digest in manifest_digests:
- with UseThenDisconnect(app.config):
- manifest = registry_model.lookup_manifest_by_digest(repository, digest,
- require_available=True)
- if manifest is None:
- continue
-
- registry_model.create_manifest_label(manifest, INTERNAL_LABEL_BUILD_UUID,
- build_id, 'internal', 'text/plain')
-
- # Send the notification that the build has completed successfully.
- self._current_job.send_notification('build_success',
- image_id=kwargs.get('image_id'),
- manifest_digests=manifest_digests)
- except ApplicationError as aex:
- worker_error = WorkerError(aex.error, aex.kwargs.get('base_error'))
-
- # Write the error to the log.
- yield From(self._build_status.set_error(worker_error.public_message(),
- worker_error.extra_data(),
- internal_error=worker_error.is_internal_error(),
- requeued=self._current_job.has_retries_remaining()))
-
- # Send the notification that the build has failed.
- self._current_job.send_notification('build_failure',
- error_message=worker_error.public_message())
-
- # Mark the build as completed.
- if worker_error.is_internal_error():
- logger.exception('[BUILD INTERNAL ERROR: Remote] Build ID: %s: %s', build_id,
- worker_error.public_message())
- yield From(self._build_finished(BuildJobResult.INCOMPLETE))
- else:
- logger.debug('Got remote failure exception for build %s: %s', build_id, aex)
+ # Mark that the build has finished (in an error state)
yield From(self._build_finished(BuildJobResult.ERROR))
- # Remove the current job.
- self._current_job = None
+ @trollius.coroutine
+ def _build_complete(self, result):
+ """ Wraps up a completed build. Handles any errors and calls self._build_finished. """
+ build_id = self._current_job.repo_build.uuid
+ try:
+ # Retrieve the result. This will raise an ApplicationError on any error that occurred.
+ result_value = result.result()
+ kwargs = {}
- @trollius.coroutine
- def _build_finished(self, job_status):
- """ Alerts the parent that a build has completed and sets the status back to running. """
- yield From(self.parent_manager.job_completed(self._current_job, job_status, self))
+ # Note: If we are hitting an older builder that didn't return ANY map data, then the result
+ # value will be a bool instead of a proper CallResult object.
+ # Therefore: we have a try-except guard here to ensure we don't hit this pitfall.
+ try:
+ kwargs = result_value.kwresults
+ except:
+ pass
- # Set the component back to a running state.
- yield From(self._set_status(ComponentStatus.RUNNING))
+ try:
+ yield From(self._build_status.set_phase(BUILD_PHASE.COMPLETE))
+ except InvalidRepositoryBuildException:
+ logger.warning(
+ "Build %s was not found; repo was probably deleted", build_id
+ )
+ raise Return()
- @staticmethod
- def _ping():
- """ Ping pong. """
- return 'pong'
+ yield From(self._build_finished(BuildJobResult.COMPLETE))
- @trollius.coroutine
- def _on_ready(self, token, version):
- logger.debug('On ready called (token "%s")', token)
- self._worker_version = version
+ # Label the pushed manifests with the build metadata.
+ manifest_digests = kwargs.get("digests") or []
+ repository = registry_model.lookup_repository(
+ self._current_job.namespace, self._current_job.repo_name
+ )
+ if repository is not None:
+ for digest in manifest_digests:
+ with UseThenDisconnect(app.config):
+ manifest = registry_model.lookup_manifest_by_digest(
+ repository, digest, require_available=True
+ )
+ if manifest is None:
+ continue
- if not version in SUPPORTED_WORKER_VERSIONS:
- logger.warning('Build component (token "%s") is running an out-of-date version: %s', token,
- version)
- raise Return(False)
+ registry_model.create_manifest_label(
+ manifest,
+ INTERNAL_LABEL_BUILD_UUID,
+ build_id,
+ "internal",
+ "text/plain",
+ )
- if self._component_status != ComponentStatus.WAITING:
- logger.warning('Build component (token "%s") is already connected', self.expected_token)
- raise Return(False)
+ # Send the notification that the build has completed successfully.
+ self._current_job.send_notification(
+ "build_success",
+ image_id=kwargs.get("image_id"),
+ manifest_digests=manifest_digests,
+ )
+ except ApplicationError as aex:
+ worker_error = WorkerError(aex.error, aex.kwargs.get("base_error"))
- if token != self.expected_token:
- logger.warning('Builder token mismatch. Expected: "%s". Found: "%s"', self.expected_token,
- token)
- raise Return(False)
+ # Write the error to the log.
+ yield From(
+ self._build_status.set_error(
+ worker_error.public_message(),
+ worker_error.extra_data(),
+ internal_error=worker_error.is_internal_error(),
+ requeued=self._current_job.has_retries_remaining(),
+ )
+ )
- yield From(self._set_status(ComponentStatus.RUNNING))
+ # Send the notification that the build has failed.
+ self._current_job.send_notification(
+ "build_failure", error_message=worker_error.public_message()
+ )
- # Start the heartbeat check and updating loop.
- loop = trollius.get_event_loop()
- loop.create_task(self._heartbeat())
- logger.debug('Build worker %s is connected and ready', self.builder_realm)
- raise Return(True)
+ # Mark the build as completed.
+ if worker_error.is_internal_error():
+ logger.exception(
+ "[BUILD INTERNAL ERROR: Remote] Build ID: %s: %s",
+ build_id,
+ worker_error.public_message(),
+ )
+ yield From(self._build_finished(BuildJobResult.INCOMPLETE))
+ else:
+ logger.debug(
+ "Got remote failure exception for build %s: %s", build_id, aex
+ )
+ yield From(self._build_finished(BuildJobResult.ERROR))
- @trollius.coroutine
- def _set_status(self, phase):
- if phase == ComponentStatus.RUNNING:
- yield From(self.parent_manager.build_component_ready(self))
+ # Remove the current job.
+ self._current_job = None
- self._component_status = phase
+ @trollius.coroutine
+ def _build_finished(self, job_status):
+ """ Alerts the parent that a build has completed and sets the status back to running. """
+ yield From(
+ self.parent_manager.job_completed(self._current_job, job_status, self)
+ )
- def _on_heartbeat(self):
- """ Updates the last known heartbeat. """
- if self._component_status == ComponentStatus.TIMED_OUT:
- return
+ # Set the component back to a running state.
+ yield From(self._set_status(ComponentStatus.RUNNING))
- logger.debug('Got heartbeat on realm %s', self.builder_realm)
- self._last_heartbeat = datetime.datetime.utcnow()
+ @staticmethod
+ def _ping():
+ """ Ping pong. """
+ return "pong"
- @trollius.coroutine
- def _heartbeat(self):
- """ Coroutine that runs every HEARTBEAT_TIMEOUT seconds, both checking the worker's heartbeat
+ @trollius.coroutine
+ def _on_ready(self, token, version):
+ logger.debug('On ready called (token "%s")', token)
+ self._worker_version = version
+
+ if version not in SUPPORTED_WORKER_VERSIONS:
+ logger.warning(
+ 'Build component (token "%s") is running an out-of-date version: %s',
+ token,
+ version,
+ )
+ raise Return(False)
+
+ if self._component_status != ComponentStatus.WAITING:
+ logger.warning(
+ 'Build component (token "%s") is already connected', self.expected_token
+ )
+ raise Return(False)
+
+ if token != self.expected_token:
+ logger.warning(
+ 'Builder token mismatch. Expected: "%s". Found: "%s"',
+ self.expected_token,
+ token,
+ )
+ raise Return(False)
+
+ yield From(self._set_status(ComponentStatus.RUNNING))
+
+ # Start the heartbeat check and updating loop.
+ loop = trollius.get_event_loop()
+ loop.create_task(self._heartbeat())
+ logger.debug("Build worker %s is connected and ready", self.builder_realm)
+ raise Return(True)
+
+ @trollius.coroutine
+ def _set_status(self, phase):
+ if phase == ComponentStatus.RUNNING:
+ yield From(self.parent_manager.build_component_ready(self))
+
+ self._component_status = phase
+
+ def _on_heartbeat(self):
+ """ Updates the last known heartbeat. """
+ if self._component_status == ComponentStatus.TIMED_OUT:
+ return
+
+ logger.debug("Got heartbeat on realm %s", self.builder_realm)
+ self._last_heartbeat = datetime.datetime.utcnow()
+
+ @trollius.coroutine
+ def _heartbeat(self):
+ """ Coroutine that runs every HEARTBEAT_TIMEOUT seconds, both checking the worker's heartbeat
and updating the heartbeat in the build status dictionary (if applicable). This allows
the build system to catch crashes from either end.
"""
- yield From(trollius.sleep(INITIAL_TIMEOUT))
+ yield From(trollius.sleep(INITIAL_TIMEOUT))
- while True:
- # If the component is no longer running or actively building, nothing more to do.
- if (self._component_status != ComponentStatus.RUNNING and
- self._component_status != ComponentStatus.BUILDING):
- raise Return()
+ while True:
+ # If the component is no longer running or actively building, nothing more to do.
+ if (
+ self._component_status != ComponentStatus.RUNNING
+ and self._component_status != ComponentStatus.BUILDING
+ ):
+ raise Return()
- # If there is an active build, write the heartbeat to its status.
- if self._build_status is not None:
- with self._build_status as status_dict:
- status_dict['heartbeat'] = int(time.time())
+ # If there is an active build, write the heartbeat to its status.
+ if self._build_status is not None:
+ with self._build_status as status_dict:
+ status_dict["heartbeat"] = int(time.time())
- # Mark the build item.
- current_job = self._current_job
- if current_job is not None:
- yield From(self.parent_manager.job_heartbeat(current_job))
+ # Mark the build item.
+ current_job = self._current_job
+ if current_job is not None:
+ yield From(self.parent_manager.job_heartbeat(current_job))
- # Check the heartbeat from the worker.
- logger.debug('Checking heartbeat on realm %s', self.builder_realm)
- if (self._last_heartbeat and
- self._last_heartbeat < datetime.datetime.utcnow() - HEARTBEAT_DELTA):
- logger.debug('Heartbeat on realm %s has expired: %s', self.builder_realm,
- self._last_heartbeat)
+ # Check the heartbeat from the worker.
+ logger.debug("Checking heartbeat on realm %s", self.builder_realm)
+ if (
+ self._last_heartbeat
+ and self._last_heartbeat < datetime.datetime.utcnow() - HEARTBEAT_DELTA
+ ):
+ logger.debug(
+ "Heartbeat on realm %s has expired: %s",
+ self.builder_realm,
+ self._last_heartbeat,
+ )
- yield From(self._timeout())
- raise Return()
+ yield From(self._timeout())
+ raise Return()
- logger.debug('Heartbeat on realm %s is valid: %s (%s).', self.builder_realm,
- self._last_heartbeat, self._component_status)
+ logger.debug(
+ "Heartbeat on realm %s is valid: %s (%s).",
+ self.builder_realm,
+ self._last_heartbeat,
+ self._component_status,
+ )
- yield From(trollius.sleep(HEARTBEAT_TIMEOUT))
+ yield From(trollius.sleep(HEARTBEAT_TIMEOUT))
- @trollius.coroutine
- def _timeout(self):
- if self._component_status == ComponentStatus.TIMED_OUT:
- raise Return()
+ @trollius.coroutine
+ def _timeout(self):
+ if self._component_status == ComponentStatus.TIMED_OUT:
+ raise Return()
- yield From(self._set_status(ComponentStatus.TIMED_OUT))
- logger.warning('Build component with realm %s has timed out', self.builder_realm)
+ yield From(self._set_status(ComponentStatus.TIMED_OUT))
+ logger.warning(
+ "Build component with realm %s has timed out", self.builder_realm
+ )
- # If we still have a running job, then it has not completed and we need to tell the parent
- # manager.
- if self._current_job is not None:
- yield From(self._build_status.set_error('Build worker timed out', internal_error=True,
- requeued=self._current_job.has_retries_remaining()))
+ # If we still have a running job, then it has not completed and we need to tell the parent
+ # manager.
+ if self._current_job is not None:
+ yield From(
+ self._build_status.set_error(
+ "Build worker timed out",
+ internal_error=True,
+ requeued=self._current_job.has_retries_remaining(),
+ )
+ )
- build_id = self._current_job.build_uuid
- logger.error('[BUILD INTERNAL ERROR: Timeout] Build ID: %s', build_id)
- yield From(self.parent_manager.job_completed(self._current_job,
- BuildJobResult.INCOMPLETE,
- self))
+ build_id = self._current_job.build_uuid
+ logger.error("[BUILD INTERNAL ERROR: Timeout] Build ID: %s", build_id)
+ yield From(
+ self.parent_manager.job_completed(
+ self._current_job, BuildJobResult.INCOMPLETE, self
+ )
+ )
- # Unregister the current component so that it cannot be invoked again.
- self.parent_manager.build_component_disposed(self, True)
+ # Unregister the current component so that it cannot be invoked again.
+ self.parent_manager.build_component_disposed(self, True)
- # Remove the job reference.
- self._current_job = None
+ # Remove the job reference.
+ self._current_job = None
- @trollius.coroutine
- def cancel_build(self):
- self.parent_manager.build_component_disposed(self, True)
- self._current_job = None
- yield From(self._set_status(ComponentStatus.RUNNING))
+ @trollius.coroutine
+ def cancel_build(self):
+ self.parent_manager.build_component_disposed(self, True)
+ self._current_job = None
+ yield From(self._set_status(ComponentStatus.RUNNING))
diff --git a/buildman/component/buildparse.py b/buildman/component/buildparse.py
index 3560c0861..18d678cae 100644
--- a/buildman/component/buildparse.py
+++ b/buildman/component/buildparse.py
@@ -1,15 +1,16 @@
import re
+
def extract_current_step(current_status_string):
- """ Attempts to extract the current step numeric identifier from the given status string. Returns the step
+ """ Attempts to extract the current step numeric identifier from the given status string. Returns the step
number or None if none.
"""
- # Older format: `Step 12 :`
- # Newer format: `Step 4/13 :`
- step_increment = re.search(r'Step ([0-9]+)/([0-9]+) :', current_status_string)
- if step_increment:
- return int(step_increment.group(1))
+ # Older format: `Step 12 :`
+ # Newer format: `Step 4/13 :`
+ step_increment = re.search(r"Step ([0-9]+)/([0-9]+) :", current_status_string)
+ if step_increment:
+ return int(step_increment.group(1))
- step_increment = re.search(r'Step ([0-9]+) :', current_status_string)
- if step_increment:
- return int(step_increment.group(1))
+ step_increment = re.search(r"Step ([0-9]+) :", current_status_string)
+ if step_increment:
+ return int(step_increment.group(1))
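+ # Usage examples (taken from test_buildparse.py): extract_current_step("Step 4/13 : ARG somearg=foo")
+ # returns 4, extract_current_step("Step 1 :") returns 1, and a non-matching string returns None.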
diff --git a/buildman/component/test/test_buildcomponent.py b/buildman/component/test/test_buildcomponent.py
index c4e026916..98d70dab0 100644
--- a/buildman/component/test/test_buildcomponent.py
+++ b/buildman/component/test/test_buildcomponent.py
@@ -3,34 +3,62 @@ import pytest
from buildman.component.buildcomponent import BuildComponent
-@pytest.mark.parametrize('input,expected_path,expected_file', [
- ("", "/", "Dockerfile"),
- ("/", "/", "Dockerfile"),
- ("/Dockerfile", "/", "Dockerfile"),
- ("/server.Dockerfile", "/", "server.Dockerfile"),
- ("/somepath", "/somepath", "Dockerfile"),
- ("/somepath/", "/somepath", "Dockerfile"),
- ("/somepath/Dockerfile", "/somepath", "Dockerfile"),
- ("/somepath/server.Dockerfile", "/somepath", "server.Dockerfile"),
- ("/somepath/some_other_path", "/somepath/some_other_path", "Dockerfile"),
- ("/somepath/some_other_path/", "/somepath/some_other_path", "Dockerfile"),
- ("/somepath/some_other_path/Dockerfile", "/somepath/some_other_path", "Dockerfile"),
- ("/somepath/some_other_path/server.Dockerfile", "/somepath/some_other_path", "server.Dockerfile"),
-])
+@pytest.mark.parametrize(
+ "input,expected_path,expected_file",
+ [
+ ("", "/", "Dockerfile"),
+ ("/", "/", "Dockerfile"),
+ ("/Dockerfile", "/", "Dockerfile"),
+ ("/server.Dockerfile", "/", "server.Dockerfile"),
+ ("/somepath", "/somepath", "Dockerfile"),
+ ("/somepath/", "/somepath", "Dockerfile"),
+ ("/somepath/Dockerfile", "/somepath", "Dockerfile"),
+ ("/somepath/server.Dockerfile", "/somepath", "server.Dockerfile"),
+ ("/somepath/some_other_path", "/somepath/some_other_path", "Dockerfile"),
+ ("/somepath/some_other_path/", "/somepath/some_other_path", "Dockerfile"),
+ (
+ "/somepath/some_other_path/Dockerfile",
+ "/somepath/some_other_path",
+ "Dockerfile",
+ ),
+ (
+ "/somepath/some_other_path/server.Dockerfile",
+ "/somepath/some_other_path",
+ "server.Dockerfile",
+ ),
+ ],
+)
def test_path_is_dockerfile(input, expected_path, expected_file):
- actual_path, actual_file = BuildComponent.name_and_path(input)
- assert actual_path == expected_path
- assert actual_file == expected_file
+ actual_path, actual_file = BuildComponent.name_and_path(input)
+ assert actual_path == expected_path
+ assert actual_file == expected_file
-@pytest.mark.parametrize('build_config,context,dockerfile_path', [
- ({}, '', ''),
- ({'build_subdir': '/builddir/Dockerfile'}, '', '/builddir/Dockerfile'),
- ({'context': '/builddir'}, '/builddir', ''),
- ({'context': '/builddir', 'build_subdir': '/builddir/Dockerfile'}, '/builddir', 'Dockerfile'),
- ({'context': '/some_other_dir/Dockerfile', 'build_subdir': '/builddir/Dockerfile'}, '/builddir', 'Dockerfile'),
- ({'context': '/', 'build_subdir':'Dockerfile'}, '/', 'Dockerfile')
-])
+
+@pytest.mark.parametrize(
+ "build_config,context,dockerfile_path",
+ [
+ ({}, "", ""),
+ ({"build_subdir": "/builddir/Dockerfile"}, "", "/builddir/Dockerfile"),
+ ({"context": "/builddir"}, "/builddir", ""),
+ (
+ {"context": "/builddir", "build_subdir": "/builddir/Dockerfile"},
+ "/builddir",
+ "Dockerfile",
+ ),
+ (
+ {
+ "context": "/some_other_dir/Dockerfile",
+ "build_subdir": "/builddir/Dockerfile",
+ },
+ "/builddir",
+ "Dockerfile",
+ ),
+ ({"context": "/", "build_subdir": "Dockerfile"}, "/", "Dockerfile"),
+ ],
+)
def test_extract_dockerfile_args(build_config, context, dockerfile_path):
- actual_context, actual_dockerfile_path = BuildComponent.extract_dockerfile_args(build_config)
- assert context == actual_context
- assert dockerfile_path == actual_dockerfile_path
+ actual_context, actual_dockerfile_path = BuildComponent.extract_dockerfile_args(
+ build_config
+ )
+ assert context == actual_context
+ assert dockerfile_path == actual_dockerfile_path
diff --git a/buildman/component/test/test_buildparse.py b/buildman/component/test/test_buildparse.py
index 3bdb7295e..e40b20189 100644
--- a/buildman/component/test/test_buildparse.py
+++ b/buildman/component/test/test_buildparse.py
@@ -3,14 +3,17 @@ import pytest
from buildman.component.buildparse import extract_current_step
-@pytest.mark.parametrize('input,expected_step', [
- ("", None),
- ("Step a :", None),
- ("Step 1 :", 1),
- ("Step 1 : ", 1),
- ("Step 1/2 : ", 1),
- ("Step 2/17 : ", 2),
- ("Step 4/13 : ARG somearg=foo", 4),
-])
+@pytest.mark.parametrize(
+ "input,expected_step",
+ [
+ ("", None),
+ ("Step a :", None),
+ ("Step 1 :", 1),
+ ("Step 1 : ", 1),
+ ("Step 1/2 : ", 1),
+ ("Step 2/17 : ", 2),
+ ("Step 4/13 : ARG somearg=foo", 4),
+ ],
+)
def test_extract_current_step(input, expected_step):
- assert extract_current_step(input) == expected_step
+ assert extract_current_step(input) == expected_step
diff --git a/buildman/enums.py b/buildman/enums.py
index f88d2b690..a7fe7bb99 100644
--- a/buildman/enums.py
+++ b/buildman/enums.py
@@ -1,21 +1,25 @@
from data.database import BUILD_PHASE
+
class BuildJobResult(object):
- """ Build job result enum """
- INCOMPLETE = 'incomplete'
- COMPLETE = 'complete'
- ERROR = 'error'
+ """ Build job result enum """
+
+ INCOMPLETE = "incomplete"
+ COMPLETE = "complete"
+ ERROR = "error"
class BuildServerStatus(object):
- """ Build server status enum """
- STARTING = 'starting'
- RUNNING = 'running'
- SHUTDOWN = 'shutting_down'
- EXCEPTION = 'exception'
+ """ Build server status enum """
+
+ STARTING = "starting"
+ RUNNING = "running"
+ SHUTDOWN = "shutting_down"
+ EXCEPTION = "exception"
+
RESULT_PHASES = {
- BuildJobResult.INCOMPLETE: BUILD_PHASE.INTERNAL_ERROR,
- BuildJobResult.COMPLETE: BUILD_PHASE.COMPLETE,
- BuildJobResult.ERROR: BUILD_PHASE.ERROR,
+ BuildJobResult.INCOMPLETE: BUILD_PHASE.INTERNAL_ERROR,
+ BuildJobResult.COMPLETE: BUILD_PHASE.COMPLETE,
+ BuildJobResult.ERROR: BUILD_PHASE.ERROR,
}
diff --git a/buildman/jobutil/buildjob.py b/buildman/jobutil/buildjob.py
index f245ce2bf..8ffbe5cd8 100644
--- a/buildman/jobutil/buildjob.py
+++ b/buildman/jobutil/buildjob.py
@@ -14,170 +14,196 @@ logger = logging.getLogger(__name__)
class BuildJobLoadException(Exception):
- """ Exception raised if a build job could not be instantiated for some reason. """
- pass
+ """ Exception raised if a build job could not be instantiated for some reason. """
+
+ pass
class BuildJob(object):
- """ Represents a single in-progress build job. """
- def __init__(self, job_item):
- self.job_item = job_item
+ """ Represents a single in-progress build job. """
- try:
- self.job_details = json.loads(job_item.body)
- self.build_notifier = BuildJobNotifier(self.build_uuid)
- except ValueError:
- raise BuildJobLoadException(
- 'Could not parse build queue item config with ID %s' % self.job_details['build_uuid']
- )
+ def __init__(self, job_item):
+ self.job_item = job_item
- @property
- def retries_remaining(self):
- return self.job_item.retries_remaining
+ try:
+ self.job_details = json.loads(job_item.body)
+ self.build_notifier = BuildJobNotifier(self.build_uuid)
+ except ValueError:
+ raise BuildJobLoadException(
+ "Could not parse build queue item config with ID %s"
+ % self.job_details["build_uuid"]
+ )
- def has_retries_remaining(self):
- return self.job_item.retries_remaining > 0
+ @property
+ def retries_remaining(self):
+ return self.job_item.retries_remaining
- def send_notification(self, kind, error_message=None, image_id=None, manifest_digests=None):
- self.build_notifier.send_notification(kind, error_message, image_id, manifest_digests)
+ def has_retries_remaining(self):
+ return self.job_item.retries_remaining > 0
- @lru_cache(maxsize=1)
- def _load_repo_build(self):
- with UseThenDisconnect(app.config):
- try:
- return model.build.get_repository_build(self.build_uuid)
- except model.InvalidRepositoryBuildException:
- raise BuildJobLoadException(
- 'Could not load repository build with ID %s' % self.build_uuid)
+ def send_notification(
+ self, kind, error_message=None, image_id=None, manifest_digests=None
+ ):
+ self.build_notifier.send_notification(
+ kind, error_message, image_id, manifest_digests
+ )
- @property
- def build_uuid(self):
- """ Returns the unique UUID for this build job. """
- return self.job_details['build_uuid']
+ @lru_cache(maxsize=1)
+ def _load_repo_build(self):
+ with UseThenDisconnect(app.config):
+ try:
+ return model.build.get_repository_build(self.build_uuid)
+ except model.InvalidRepositoryBuildException:
+ raise BuildJobLoadException(
+ "Could not load repository build with ID %s" % self.build_uuid
+ )
- @property
- def namespace(self):
- """ Returns the namespace under which this build is running. """
- return self.repo_build.repository.namespace_user.username
+ @property
+ def build_uuid(self):
+ """ Returns the unique UUID for this build job. """
+ return self.job_details["build_uuid"]
- @property
- def repo_name(self):
- """ Returns the name of the repository under which this build is running. """
- return self.repo_build.repository.name
+ @property
+ def namespace(self):
+ """ Returns the namespace under which this build is running. """
+ return self.repo_build.repository.namespace_user.username
- @property
- def repo_build(self):
- return self._load_repo_build()
+ @property
+ def repo_name(self):
+ """ Returns the name of the repository under which this build is running. """
+ return self.repo_build.repository.name
- def get_build_package_url(self, user_files):
- """ Returns the URL of the build package for this build, if any or empty string if none. """
- archive_url = self.build_config.get('archive_url', None)
- if archive_url:
- return archive_url
+ @property
+ def repo_build(self):
+ return self._load_repo_build()
- if not self.repo_build.resource_key:
- return ''
+ def get_build_package_url(self, user_files):
+ """ Returns the URL of the build package for this build, if any or empty string if none. """
+ archive_url = self.build_config.get("archive_url", None)
+ if archive_url:
+ return archive_url
- return user_files.get_file_url(self.repo_build.resource_key, '127.0.0.1', requires_cors=False)
+ if not self.repo_build.resource_key:
+ return ""
- @property
- def pull_credentials(self):
- """ Returns the pull credentials for this job, or None if none. """
- return self.job_details.get('pull_credentials')
+ return user_files.get_file_url(
+ self.repo_build.resource_key, "127.0.0.1", requires_cors=False
+ )
- @property
- def build_config(self):
- try:
- return json.loads(self.repo_build.job_config)
- except ValueError:
- raise BuildJobLoadException(
- 'Could not parse repository build job config with ID %s' % self.job_details['build_uuid']
- )
+ @property
+ def pull_credentials(self):
+ """ Returns the pull credentials for this job, or None if none. """
+ return self.job_details.get("pull_credentials")
- def determine_cached_tag(self, base_image_id=None, cache_comments=None):
- """ Returns the tag to pull to prime the cache or None if none. """
- cached_tag = self._determine_cached_tag_by_tag()
- logger.debug('Determined cached tag %s for %s: %s', cached_tag, base_image_id, cache_comments)
- return cached_tag
+ @property
+ def build_config(self):
+ try:
+ return json.loads(self.repo_build.job_config)
+ except ValueError:
+ raise BuildJobLoadException(
+ "Could not parse repository build job config with ID %s"
+ % self.job_details["build_uuid"]
+ )
- def _determine_cached_tag_by_tag(self):
- """ Determines the cached tag by looking for one of the tags being built, and seeing if it
+ def determine_cached_tag(self, base_image_id=None, cache_comments=None):
+ """ Returns the tag to pull to prime the cache or None if none. """
+ cached_tag = self._determine_cached_tag_by_tag()
+ logger.debug(
+ "Determined cached tag %s for %s: %s",
+ cached_tag,
+ base_image_id,
+ cache_comments,
+ )
+ return cached_tag
+
+ def _determine_cached_tag_by_tag(self):
+ """ Determines the cached tag by looking for one of the tags being built, and seeing if it
exists in the repository. This is a fallback for when no comment information is available.
"""
- with UseThenDisconnect(app.config):
- tags = self.build_config.get('docker_tags', ['latest'])
- repository = RepositoryReference.for_repo_obj(self.repo_build.repository)
- matching_tag = registry_model.find_matching_tag(repository, tags)
- if matching_tag is not None:
- return matching_tag.name
+ with UseThenDisconnect(app.config):
+ tags = self.build_config.get("docker_tags", ["latest"])
+ repository = RepositoryReference.for_repo_obj(self.repo_build.repository)
+ matching_tag = registry_model.find_matching_tag(repository, tags)
+ if matching_tag is not None:
+ return matching_tag.name
- most_recent_tag = registry_model.get_most_recent_tag(repository)
- if most_recent_tag is not None:
- return most_recent_tag.name
+ most_recent_tag = registry_model.get_most_recent_tag(repository)
+ if most_recent_tag is not None:
+ return most_recent_tag.name
- return None
+ return None
class BuildJobNotifier(object):
- """ A class for sending notifications to a job that only relies on the build_uuid """
+ """ A class for sending notifications to a job that only relies on the build_uuid """
- def __init__(self, build_uuid):
- self.build_uuid = build_uuid
+ def __init__(self, build_uuid):
+ self.build_uuid = build_uuid
- @property
- def repo_build(self):
- return self._load_repo_build()
+ @property
+ def repo_build(self):
+ return self._load_repo_build()
- @lru_cache(maxsize=1)
- def _load_repo_build(self):
- try:
- return model.build.get_repository_build(self.build_uuid)
- except model.InvalidRepositoryBuildException:
- raise BuildJobLoadException(
- 'Could not load repository build with ID %s' % self.build_uuid)
+ @lru_cache(maxsize=1)
+ def _load_repo_build(self):
+ try:
+ return model.build.get_repository_build(self.build_uuid)
+ except model.InvalidRepositoryBuildException:
+ raise BuildJobLoadException(
+ "Could not load repository build with ID %s" % self.build_uuid
+ )
- @property
- def build_config(self):
- try:
- return json.loads(self.repo_build.job_config)
- except ValueError:
- raise BuildJobLoadException(
- 'Could not parse repository build job config with ID %s' % self.repo_build.uuid
- )
+ @property
+ def build_config(self):
+ try:
+ return json.loads(self.repo_build.job_config)
+ except ValueError:
+ raise BuildJobLoadException(
+ "Could not parse repository build job config with ID %s"
+ % self.repo_build.uuid
+ )
- def send_notification(self, kind, error_message=None, image_id=None, manifest_digests=None):
- with UseThenDisconnect(app.config):
- tags = self.build_config.get('docker_tags', ['latest'])
- trigger = self.repo_build.trigger
- if trigger is not None and trigger.id is not None:
- trigger_kind = trigger.service.name
- else:
- trigger_kind = None
+ def send_notification(
+ self, kind, error_message=None, image_id=None, manifest_digests=None
+ ):
+ with UseThenDisconnect(app.config):
+ tags = self.build_config.get("docker_tags", ["latest"])
+ trigger = self.repo_build.trigger
+ if trigger is not None and trigger.id is not None:
+ trigger_kind = trigger.service.name
+ else:
+ trigger_kind = None
- event_data = {
- 'build_id': self.repo_build.uuid,
- 'build_name': self.repo_build.display_name,
- 'docker_tags': tags,
- 'trigger_id': trigger.uuid if trigger is not None else None,
- 'trigger_kind': trigger_kind,
- 'trigger_metadata': self.build_config.get('trigger_metadata', {})
- }
+ event_data = {
+ "build_id": self.repo_build.uuid,
+ "build_name": self.repo_build.display_name,
+ "docker_tags": tags,
+ "trigger_id": trigger.uuid if trigger is not None else None,
+ "trigger_kind": trigger_kind,
+ "trigger_metadata": self.build_config.get("trigger_metadata", {}),
+ }
- if image_id is not None:
- event_data['image_id'] = image_id
+ if image_id is not None:
+ event_data["image_id"] = image_id
- if manifest_digests:
- event_data['manifest_digests'] = manifest_digests
+ if manifest_digests:
+ event_data["manifest_digests"] = manifest_digests
- if error_message is not None:
- event_data['error_message'] = error_message
+ if error_message is not None:
+ event_data["error_message"] = error_message
- # TODO: remove when more endpoints have been converted to using
- # interfaces
- repo = AttrDict({
- 'namespace_name': self.repo_build.repository.namespace_user.username,
- 'name': self.repo_build.repository.name,
- })
- spawn_notification(repo, kind, event_data,
- subpage='build/%s' % self.repo_build.uuid,
- pathargs=['build', self.repo_build.uuid])
+ # TODO: remove when more endpoints have been converted to using
+ # interfaces
+ repo = AttrDict(
+ {
+ "namespace_name": self.repo_build.repository.namespace_user.username,
+ "name": self.repo_build.repository.name,
+ }
+ )
+ spawn_notification(
+ repo,
+ kind,
+ event_data,
+ subpage="build/%s" % self.repo_build.uuid,
+ pathargs=["build", self.repo_build.uuid],
+ )
diff --git a/buildman/jobutil/buildstatus.py b/buildman/jobutil/buildstatus.py
index 662dbaa10..f7bf4a767 100644
--- a/buildman/jobutil/buildstatus.py
+++ b/buildman/jobutil/buildstatus.py
@@ -13,76 +13,94 @@ logger = logging.getLogger(__name__)
class StatusHandler(object):
- """ Context wrapper for writing status to build logs. """
+ """ Context wrapper for writing status to build logs. """
- def __init__(self, build_logs, repository_build_uuid):
- self._current_phase = None
- self._current_command = None
- self._uuid = repository_build_uuid
- self._build_logs = AsyncWrapper(build_logs)
- self._sync_build_logs = build_logs
- self._build_model = AsyncWrapper(model.build)
+ def __init__(self, build_logs, repository_build_uuid):
+ self._current_phase = None
+ self._current_command = None
+ self._uuid = repository_build_uuid
+ self._build_logs = AsyncWrapper(build_logs)
+ self._sync_build_logs = build_logs
+ self._build_model = AsyncWrapper(model.build)
- self._status = {
- 'total_commands': 0,
- 'current_command': None,
- 'push_completion': 0.0,
- 'pull_completion': 0.0,
- }
+ self._status = {
+ "total_commands": 0,
+ "current_command": None,
+ "push_completion": 0.0,
+ "pull_completion": 0.0,
+ }
- # Write the initial status.
- self.__exit__(None, None, None)
+ # Write the initial status.
+ self.__exit__(None, None, None)
- @coroutine
- def _append_log_message(self, log_message, log_type=None, log_data=None):
- log_data = log_data or {}
- log_data['datetime'] = str(datetime.datetime.now())
+ @coroutine
+ def _append_log_message(self, log_message, log_type=None, log_data=None):
+ log_data = log_data or {}
+ log_data["datetime"] = str(datetime.datetime.now())
- try:
- yield From(self._build_logs.append_log_message(self._uuid, log_message, log_type, log_data))
- except RedisError:
- logger.exception('Could not save build log for build %s: %s', self._uuid, log_message)
+ try:
+ yield From(
+ self._build_logs.append_log_message(
+ self._uuid, log_message, log_type, log_data
+ )
+ )
+ except RedisError:
+ logger.exception(
+ "Could not save build log for build %s: %s", self._uuid, log_message
+ )
- @coroutine
- def append_log(self, log_message, extra_data=None):
- if log_message is None:
- return
+ @coroutine
+ def append_log(self, log_message, extra_data=None):
+ if log_message is None:
+ return
- yield From(self._append_log_message(log_message, log_data=extra_data))
+ yield From(self._append_log_message(log_message, log_data=extra_data))
- @coroutine
- def set_command(self, command, extra_data=None):
- if self._current_command == command:
- raise Return()
+ @coroutine
+ def set_command(self, command, extra_data=None):
+ if self._current_command == command:
+ raise Return()
- self._current_command = command
- yield From(self._append_log_message(command, self._build_logs.COMMAND, extra_data))
+ self._current_command = command
+ yield From(
+ self._append_log_message(command, self._build_logs.COMMAND, extra_data)
+ )
- @coroutine
- def set_error(self, error_message, extra_data=None, internal_error=False, requeued=False):
- error_phase = BUILD_PHASE.INTERNAL_ERROR if internal_error and requeued else BUILD_PHASE.ERROR
- yield From(self.set_phase(error_phase))
+ @coroutine
+ def set_error(
+ self, error_message, extra_data=None, internal_error=False, requeued=False
+ ):
+ error_phase = (
+ BUILD_PHASE.INTERNAL_ERROR
+ if internal_error and requeued
+ else BUILD_PHASE.ERROR
+ )
+ yield From(self.set_phase(error_phase))
- extra_data = extra_data or {}
- extra_data['internal_error'] = internal_error
- yield From(self._append_log_message(error_message, self._build_logs.ERROR, extra_data))
+ extra_data = extra_data or {}
+ extra_data["internal_error"] = internal_error
+ yield From(
+ self._append_log_message(error_message, self._build_logs.ERROR, extra_data)
+ )
- @coroutine
- def set_phase(self, phase, extra_data=None):
- if phase == self._current_phase:
- raise Return(False)
+ @coroutine
+ def set_phase(self, phase, extra_data=None):
+ if phase == self._current_phase:
+ raise Return(False)
- self._current_phase = phase
- yield From(self._append_log_message(phase, self._build_logs.PHASE, extra_data))
+ self._current_phase = phase
+ yield From(self._append_log_message(phase, self._build_logs.PHASE, extra_data))
- # Update the repository build with the new phase
- raise Return(self._build_model.update_phase_then_close(self._uuid, phase))
+ # Update the repository build with the new phase
+ raise Return(self._build_model.update_phase_then_close(self._uuid, phase))
- def __enter__(self):
- return self._status
+ def __enter__(self):
+ return self._status
- def __exit__(self, exc_type, value, traceback):
- try:
- self._sync_build_logs.set_status(self._uuid, self._status)
- except RedisError:
- logger.exception('Could not set status of build %s to %s', self._uuid, self._status)
+ def __exit__(self, exc_type, value, traceback):
+ try:
+ self._sync_build_logs.set_status(self._uuid, self._status)
+ except RedisError:
+ logger.exception(
+ "Could not set status of build %s to %s", self._uuid, self._status
+ )
diff --git a/buildman/jobutil/workererror.py b/buildman/jobutil/workererror.py
index 9245f312e..111ffad2d 100644
--- a/buildman/jobutil/workererror.py
+++ b/buildman/jobutil/workererror.py
@@ -1,119 +1,99 @@
class WorkerError(object):
- """ Helper class which represents errors raised by a build worker. """
- def __init__(self, error_code, base_message=None):
- self._error_code = error_code
- self._base_message = base_message
+ """ Helper class which represents errors raised by a build worker. """
- self._error_handlers = {
- 'io.quay.builder.buildpackissue': {
- 'message': 'Could not load build package',
- 'is_internal': True,
- },
+ def __init__(self, error_code, base_message=None):
+ self._error_code = error_code
+ self._base_message = base_message
- 'io.quay.builder.gitfailure': {
- 'message': 'Could not clone git repository',
- 'show_base_error': True,
- },
+ self._error_handlers = {
+ "io.quay.builder.buildpackissue": {
+ "message": "Could not load build package",
+ "is_internal": True,
+ },
+ "io.quay.builder.gitfailure": {
+ "message": "Could not clone git repository",
+ "show_base_error": True,
+ },
+ "io.quay.builder.gitcheckout": {
+ "message": "Could not checkout git ref. If you force pushed recently, "
+ + "the commit may be missing.",
+ "show_base_error": True,
+ },
+ "io.quay.builder.cannotextractbuildpack": {
+ "message": "Could not extract the contents of the build package"
+ },
+ "io.quay.builder.cannotpullforcache": {
+ "message": "Could not pull cached image",
+ "is_internal": True,
+ },
+ "io.quay.builder.dockerfileissue": {
+ "message": "Could not find or parse Dockerfile",
+ "show_base_error": True,
+ },
+ "io.quay.builder.cannotpullbaseimage": {
+ "message": "Could not pull base image",
+ "show_base_error": True,
+ },
+ "io.quay.builder.internalerror": {
+ "message": "An internal error occurred while building. Please submit a ticket.",
+ "is_internal": True,
+ },
+ "io.quay.builder.buildrunerror": {
+ "message": "Could not start the build process",
+ "is_internal": True,
+ },
+ "io.quay.builder.builderror": {
+ "message": "A build step failed",
+ "show_base_error": True,
+ },
+ "io.quay.builder.tagissue": {
+ "message": "Could not tag built image",
+ "is_internal": True,
+ },
+ "io.quay.builder.pushissue": {
+ "message": "Could not push built image",
+ "show_base_error": True,
+ "is_internal": True,
+ },
+ "io.quay.builder.dockerconnecterror": {
+ "message": "Could not connect to Docker daemon",
+ "is_internal": True,
+ },
+ "io.quay.builder.missingorinvalidargument": {
+ "message": "Missing required arguments for builder",
+ "is_internal": True,
+ },
+ "io.quay.builder.cachelookupissue": {
+ "message": "Error checking for a cached tag",
+ "is_internal": True,
+ },
+ "io.quay.builder.errorduringphasetransition": {
+ "message": "Error during phase transition. If this problem persists "
+ + "please contact customer support.",
+ "is_internal": True,
+ },
+ "io.quay.builder.clientrejectedtransition": {
+ "message": "Build can not be finished due to user cancellation."
+ },
+ }
- 'io.quay.builder.gitcheckout': {
- 'message': 'Could not checkout git ref. If you force pushed recently, ' +
- 'the commit may be missing.',
- 'show_base_error': True,
- },
+ def is_internal_error(self):
+ handler = self._error_handlers.get(self._error_code)
+ return handler.get("is_internal", False) if handler else True
- 'io.quay.builder.cannotextractbuildpack': {
- 'message': 'Could not extract the contents of the build package'
- },
+ def public_message(self):
+ handler = self._error_handlers.get(self._error_code)
+ if not handler:
+ return "An unknown error occurred"
- 'io.quay.builder.cannotpullforcache': {
- 'message': 'Could not pull cached image',
- 'is_internal': True
- },
+ message = handler["message"]
+ if handler.get("show_base_error", False) and self._base_message:
+ message = message + ": " + self._base_message
- 'io.quay.builder.dockerfileissue': {
- 'message': 'Could not find or parse Dockerfile',
- 'show_base_error': True
- },
+ return message
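+ # Worked example (hypothetical base message): WorkerError("io.quay.builder.gitfailure",
+ # "authentication failed").public_message() yields
+ # "Could not clone git repository: authentication failed", because that handler sets
+ # show_base_error.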
- 'io.quay.builder.cannotpullbaseimage': {
- 'message': 'Could not pull base image',
- 'show_base_error': True
- },
+ def extra_data(self):
+ if self._base_message:
+ return {"base_error": self._base_message, "error_code": self._error_code}
- 'io.quay.builder.internalerror': {
- 'message': 'An internal error occurred while building. Please submit a ticket.',
- 'is_internal': True
- },
-
- 'io.quay.builder.buildrunerror': {
- 'message': 'Could not start the build process',
- 'is_internal': True
- },
-
- 'io.quay.builder.builderror': {
- 'message': 'A build step failed',
- 'show_base_error': True
- },
-
- 'io.quay.builder.tagissue': {
- 'message': 'Could not tag built image',
- 'is_internal': True
- },
-
- 'io.quay.builder.pushissue': {
- 'message': 'Could not push built image',
- 'show_base_error': True,
- 'is_internal': True
- },
-
- 'io.quay.builder.dockerconnecterror': {
- 'message': 'Could not connect to Docker daemon',
- 'is_internal': True
- },
-
- 'io.quay.builder.missingorinvalidargument': {
- 'message': 'Missing required arguments for builder',
- 'is_internal': True
- },
-
- 'io.quay.builder.cachelookupissue': {
- 'message': 'Error checking for a cached tag',
- 'is_internal': True
- },
-
- 'io.quay.builder.errorduringphasetransition': {
- 'message': 'Error during phase transition. If this problem persists ' +
- 'please contact customer support.',
- 'is_internal': True
- },
-
- 'io.quay.builder.clientrejectedtransition': {
- 'message': 'Build can not be finished due to user cancellation.',
- }
- }
-
- def is_internal_error(self):
- handler = self._error_handlers.get(self._error_code)
- return handler.get('is_internal', False) if handler else True
-
- def public_message(self):
- handler = self._error_handlers.get(self._error_code)
- if not handler:
- return 'An unknown error occurred'
-
- message = handler['message']
- if handler.get('show_base_error', False) and self._base_message:
- message = message + ': ' + self._base_message
-
- return message
-
- def extra_data(self):
- if self._base_message:
- return {
- 'base_error': self._base_message,
- 'error_code': self._error_code
- }
-
- return {
- 'error_code': self._error_code
- }
+ return {"error_code": self._error_code}
diff --git a/buildman/manager/basemanager.py b/buildman/manager/basemanager.py
index 23627830a..996a4eacc 100644
--- a/buildman/manager/basemanager.py
+++ b/buildman/manager/basemanager.py
@@ -1,71 +1,80 @@
from trollius import coroutine
+
class BaseManager(object):
- """ Base for all worker managers. """
- def __init__(self, register_component, unregister_component, job_heartbeat_callback,
- job_complete_callback, manager_hostname, heartbeat_period_sec):
- self.register_component = register_component
- self.unregister_component = unregister_component
- self.job_heartbeat_callback = job_heartbeat_callback
- self.job_complete_callback = job_complete_callback
- self.manager_hostname = manager_hostname
- self.heartbeat_period_sec = heartbeat_period_sec
+ """ Base for all worker managers. """
- @coroutine
- def job_heartbeat(self, build_job):
- """ Method invoked to tell the manager that a job is still running. This method will be called
+ def __init__(
+ self,
+ register_component,
+ unregister_component,
+ job_heartbeat_callback,
+ job_complete_callback,
+ manager_hostname,
+ heartbeat_period_sec,
+ ):
+ self.register_component = register_component
+ self.unregister_component = unregister_component
+ self.job_heartbeat_callback = job_heartbeat_callback
+ self.job_complete_callback = job_complete_callback
+ self.manager_hostname = manager_hostname
+ self.heartbeat_period_sec = heartbeat_period_sec
+
+ @coroutine
+ def job_heartbeat(self, build_job):
+ """ Method invoked to tell the manager that a job is still running. This method will be called
every few minutes. """
- self.job_heartbeat_callback(build_job)
+ self.job_heartbeat_callback(build_job)
- def overall_setup_time(self):
- """ Returns the number of seconds that the build system should wait before allowing the job
+ def overall_setup_time(self):
+ """ Returns the number of seconds that the build system should wait before allowing the job
to be picked up again after called 'schedule'.
"""
- raise NotImplementedError
+ raise NotImplementedError
- def shutdown(self):
- """ Indicates that the build controller server is in a shutdown state and that no new jobs
+ def shutdown(self):
+ """ Indicates that the build controller server is in a shutdown state and that no new jobs
or workers should be performed. Existing workers should be cleaned up once their jobs
have completed
"""
- raise NotImplementedError
+ raise NotImplementedError
- @coroutine
- def schedule(self, build_job):
- """ Schedules a queue item to be built. Returns a 2-tuple with (True, None) if the item was
+ @coroutine
+ def schedule(self, build_job):
+ """ Schedules a queue item to be built. Returns a 2-tuple with (True, None) if the item was
properly scheduled and (False, a retry timeout in seconds) if all workers are busy or an
error occurs.
"""
- raise NotImplementedError
+ raise NotImplementedError
- def initialize(self, manager_config):
- """ Runs any initialization code for the manager. Called once the server is in a ready state.
+ def initialize(self, manager_config):
+ """ Runs any initialization code for the manager. Called once the server is in a ready state.
"""
- raise NotImplementedError
+ raise NotImplementedError
- @coroutine
- def build_component_ready(self, build_component):
- """ Method invoked whenever a build component announces itself as ready.
+ @coroutine
+ def build_component_ready(self, build_component):
+ """ Method invoked whenever a build component announces itself as ready.
"""
- raise NotImplementedError
+ raise NotImplementedError
- def build_component_disposed(self, build_component, timed_out):
- """ Method invoked whenever a build component has been disposed. The timed_out boolean indicates
+ def build_component_disposed(self, build_component, timed_out):
+ """ Method invoked whenever a build component has been disposed. The timed_out boolean indicates
whether the component's heartbeat timed out.
"""
- raise NotImplementedError
+ raise NotImplementedError
- @coroutine
- def job_completed(self, build_job, job_status, build_component):
- """ Method invoked once a job_item has completed, in some manner. The job_status will be
+ @coroutine
+ def job_completed(self, build_job, job_status, build_component):
+ """ Method invoked once a job_item has completed, in some manner. The job_status will be
one of: incomplete, error, complete. Implementations of this method should call coroutine
self.job_complete_callback with a status of Incomplete if they wish for the job to be
automatically requeued.
"""
- raise NotImplementedError
+ raise NotImplementedError
- def num_workers(self):
- """ Returns the number of active build workers currently registered. This includes those
+ def num_workers(self):
+ """ Returns the number of active build workers currently registered. This includes those
that are currently busy and awaiting more work.
"""
- raise NotImplementedError
+ raise NotImplementedError
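The interface above is the full contract a concrete manager implements. A minimal, partial sketch of a subclass under that contract (the class name and its single-worker behaviour are hypothetical and only illustrate the schedule/job_completed return conventions):

# Illustrative BaseManager subclass; names and behaviour are made up for the sketch.
from trollius import From, Return, coroutine

from buildman.manager.basemanager import BaseManager


class SingleWorkerManager(BaseManager):
    """ Hypothetical manager that schedules jobs onto a single worker. """

    def initialize(self, manager_config):
        self._busy = False

    def overall_setup_time(self):
        return 60  # allow a minute before a job can be picked up again

    @coroutine
    def schedule(self, build_job):
        if self._busy:
            raise Return(False, 30)  # all workers busy: retry in 30 seconds
        self._busy = True
        raise Return(True, None)

    @coroutine
    def job_completed(self, build_job, job_status, build_component):
        self._busy = False
        yield From(self.job_complete_callback(build_job, job_status))

    def num_workers(self):
        return 1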
diff --git a/buildman/manager/buildcanceller.py b/buildman/manager/buildcanceller.py
index dd49e9f38..c2ab2d9ad 100644
--- a/buildman/manager/buildcanceller.py
+++ b/buildman/manager/buildcanceller.py
@@ -5,23 +5,23 @@ from buildman.manager.noop_canceller import NoopCanceller
logger = logging.getLogger(__name__)
-CANCELLERS = {'ephemeral': OrchestratorCanceller}
+CANCELLERS = {"ephemeral": OrchestratorCanceller}
class BuildCanceller(object):
- """ A class to manage cancelling a build """
+ """ A class to manage cancelling a build """
- def __init__(self, app=None):
- self.build_manager_config = app.config.get('BUILD_MANAGER')
- if app is None or self.build_manager_config is None:
- self.handler = NoopCanceller()
- else:
- self.handler = None
+ def __init__(self, app=None):
+        self.build_manager_config = app.config.get("BUILD_MANAGER") if app else None
+ if app is None or self.build_manager_config is None:
+ self.handler = NoopCanceller()
+ else:
+ self.handler = None
- def try_cancel_build(self, uuid):
- """ A method to kill a running build """
- if self.handler is None:
- canceller = CANCELLERS.get(self.build_manager_config[0], NoopCanceller)
- self.handler = canceller(self.build_manager_config[1])
+ def try_cancel_build(self, uuid):
+ """ A method to kill a running build """
+ if self.handler is None:
+ canceller = CANCELLERS.get(self.build_manager_config[0], NoopCanceller)
+ self.handler = canceller(self.build_manager_config[1])
- return self.handler.try_cancel_build(uuid)
+ return self.handler.try_cancel_build(uuid)
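BuildCanceller indexes the BUILD_MANAGER setting positionally, so the value is expected to be a two-element sequence: the canceller name (looked up in CANCELLERS, falling back to NoopCanceller) and the config object handed to the canceller's constructor. An illustrative shape, inferred from that access pattern (the empty dict is a placeholder, not a documented default):

# Assumed shape of the BUILD_MANAGER app config value, inferred from the indexing above.
BUILD_MANAGER = (
    "ephemeral",  # canceller name; anything not in CANCELLERS falls back to NoopCanceller
    {},           # passed verbatim to the canceller constructor (orchestrator settings in practice)
)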
diff --git a/buildman/manager/enterprise.py b/buildman/manager/enterprise.py
index 3d32a61d0..0be01269c 100644
--- a/buildman/manager/enterprise.py
+++ b/buildman/manager/enterprise.py
@@ -7,86 +7,89 @@ from buildman.manager.basemanager import BaseManager
from trollius import From, Return, coroutine
-REGISTRATION_REALM = 'registration'
+REGISTRATION_REALM = "registration"
RETRY_TIMEOUT = 5
logger = logging.getLogger(__name__)
+
class DynamicRegistrationComponent(BaseComponent):
- """ Component session that handles dynamic registration of the builder components. """
+ """ Component session that handles dynamic registration of the builder components. """
- def onConnect(self):
- self.join(REGISTRATION_REALM)
+ def onConnect(self):
+ self.join(REGISTRATION_REALM)
- def onJoin(self, details):
- logger.debug('Registering registration method')
- yield From(self.register(self._worker_register, u'io.quay.buildworker.register'))
+ def onJoin(self, details):
+ logger.debug("Registering registration method")
+ yield From(
+ self.register(self._worker_register, u"io.quay.buildworker.register")
+ )
- def _worker_register(self):
- realm = self.parent_manager.add_build_component()
- logger.debug('Registering new build component+worker with realm %s', realm)
- return realm
+ def _worker_register(self):
+ realm = self.parent_manager.add_build_component()
+ logger.debug("Registering new build component+worker with realm %s", realm)
+ return realm
- def kind(self):
- return 'registration'
+ def kind(self):
+ return "registration"
class EnterpriseManager(BaseManager):
- """ Build manager implementation for the Enterprise Registry. """
+ """ Build manager implementation for the Enterprise Registry. """
- def __init__(self, *args, **kwargs):
- self.ready_components = set()
- self.all_components = set()
- self.shutting_down = False
+ def __init__(self, *args, **kwargs):
+ self.ready_components = set()
+ self.all_components = set()
+ self.shutting_down = False
- super(EnterpriseManager, self).__init__(*args, **kwargs)
+ super(EnterpriseManager, self).__init__(*args, **kwargs)
- def initialize(self, manager_config):
- # Add a component which is used by build workers for dynamic registration. Unlike
- # production, build workers in enterprise are long-lived and register dynamically.
- self.register_component(REGISTRATION_REALM, DynamicRegistrationComponent)
+ def initialize(self, manager_config):
+ # Add a component which is used by build workers for dynamic registration. Unlike
+ # production, build workers in enterprise are long-lived and register dynamically.
+ self.register_component(REGISTRATION_REALM, DynamicRegistrationComponent)
- def overall_setup_time(self):
- # Builders are already registered, so the setup time should be essentially instant. We therefore
- # only return a minute here.
- return 60
+ def overall_setup_time(self):
+ # Builders are already registered, so the setup time should be essentially instant. We therefore
+ # only return a minute here.
+ return 60
- def add_build_component(self):
- """ Adds a new build component for an Enterprise Registry. """
- # Generate a new unique realm ID for the build worker.
- realm = str(uuid.uuid4())
- new_component = self.register_component(realm, BuildComponent, token="")
- self.all_components.add(new_component)
- return realm
+ def add_build_component(self):
+ """ Adds a new build component for an Enterprise Registry. """
+ # Generate a new unique realm ID for the build worker.
+ realm = str(uuid.uuid4())
+ new_component = self.register_component(realm, BuildComponent, token="")
+ self.all_components.add(new_component)
+ return realm
- @coroutine
- def schedule(self, build_job):
- """ Schedules a build for an Enterprise Registry. """
- if self.shutting_down or not self.ready_components:
- raise Return(False, RETRY_TIMEOUT)
+ @coroutine
+ def schedule(self, build_job):
+ """ Schedules a build for an Enterprise Registry. """
+ if self.shutting_down or not self.ready_components:
+ raise Return(False, RETRY_TIMEOUT)
- component = self.ready_components.pop()
+ component = self.ready_components.pop()
- yield From(component.start_build(build_job))
+ yield From(component.start_build(build_job))
- raise Return(True, None)
+ raise Return(True, None)
- @coroutine
- def build_component_ready(self, build_component):
- self.ready_components.add(build_component)
+ @coroutine
+ def build_component_ready(self, build_component):
+ self.ready_components.add(build_component)
- def shutdown(self):
- self.shutting_down = True
+ def shutdown(self):
+ self.shutting_down = True
- @coroutine
- def job_completed(self, build_job, job_status, build_component):
- yield From(self.job_complete_callback(build_job, job_status))
+ @coroutine
+ def job_completed(self, build_job, job_status, build_component):
+ yield From(self.job_complete_callback(build_job, job_status))
- def build_component_disposed(self, build_component, timed_out):
- self.all_components.remove(build_component)
- if build_component in self.ready_components:
- self.ready_components.remove(build_component)
+ def build_component_disposed(self, build_component, timed_out):
+ self.all_components.remove(build_component)
+ if build_component in self.ready_components:
+ self.ready_components.remove(build_component)
- self.unregister_component(build_component)
+ self.unregister_component(build_component)
- def num_workers(self):
- return len(self.all_components)
+ def num_workers(self):
+ return len(self.all_components)
diff --git a/buildman/manager/etcd_canceller.py b/buildman/manager/etcd_canceller.py
index ce92a1bbc..d4b129e52 100644
--- a/buildman/manager/etcd_canceller.py
+++ b/buildman/manager/etcd_canceller.py
@@ -5,33 +5,36 @@ logger = logging.getLogger(__name__)
class EtcdCanceller(object):
- """ A class that sends a message to etcd to cancel a build """
+ """ A class that sends a message to etcd to cancel a build """
- def __init__(self, config):
- etcd_host = config.get('ETCD_HOST', '127.0.0.1')
- etcd_port = config.get('ETCD_PORT', 2379)
- etcd_ca_cert = config.get('ETCD_CA_CERT', None)
- etcd_auth = config.get('ETCD_CERT_AND_KEY', None)
- if etcd_auth is not None:
- etcd_auth = tuple(etcd_auth)
+ def __init__(self, config):
+ etcd_host = config.get("ETCD_HOST", "127.0.0.1")
+ etcd_port = config.get("ETCD_PORT", 2379)
+ etcd_ca_cert = config.get("ETCD_CA_CERT", None)
+ etcd_auth = config.get("ETCD_CERT_AND_KEY", None)
+ if etcd_auth is not None:
+ etcd_auth = tuple(etcd_auth)
- etcd_protocol = 'http' if etcd_auth is None else 'https'
- logger.debug('Connecting to etcd on %s:%s', etcd_host, etcd_port)
- self._cancel_prefix = config.get('ETCD_CANCEL_PREFIX', 'cancel/')
- self._etcd_client = etcd.Client(
- host=etcd_host,
- port=etcd_port,
- cert=etcd_auth,
- ca_cert=etcd_ca_cert,
- protocol=etcd_protocol,
- read_timeout=5)
+ etcd_protocol = "http" if etcd_auth is None else "https"
+ logger.debug("Connecting to etcd on %s:%s", etcd_host, etcd_port)
+ self._cancel_prefix = config.get("ETCD_CANCEL_PREFIX", "cancel/")
+ self._etcd_client = etcd.Client(
+ host=etcd_host,
+ port=etcd_port,
+ cert=etcd_auth,
+ ca_cert=etcd_ca_cert,
+ protocol=etcd_protocol,
+ read_timeout=5,
+ )
- def try_cancel_build(self, build_uuid):
- """ Writes etcd message to cancel build_uuid. """
- logger.info("Cancelling build %s".format(build_uuid))
- try:
- self._etcd_client.write("{}{}".format(self._cancel_prefix, build_uuid), build_uuid, ttl=60)
- return True
- except etcd.EtcdException:
- logger.exception("Failed to write to etcd client %s", build_uuid)
- return False
+ def try_cancel_build(self, build_uuid):
+ """ Writes etcd message to cancel build_uuid. """
+        logger.info("Cancelling build %s", build_uuid)
+ try:
+ self._etcd_client.write(
+ "{}{}".format(self._cancel_prefix, build_uuid), build_uuid, ttl=60
+ )
+ return True
+ except etcd.EtcdException:
+ logger.exception("Failed to write to etcd client %s", build_uuid)
+ return False
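The cancellation signal is nothing more than a short-lived etcd key, <ETCD_CANCEL_PREFIX><build_uuid>, whose value is the build UUID and which expires after 60 seconds. A small sketch of the key construction (the UUID is the test fixture value used later in this diff):

# Sketch of the cancel key layout written by try_cancel_build above.
cancel_prefix = "cancel/"  # default ETCD_CANCEL_PREFIX
build_uuid = "deadbeef-dead-beef-dead-deadbeefdead"
cancel_key = "{}{}".format(cancel_prefix, build_uuid)
assert cancel_key == "cancel/deadbeef-dead-beef-dead-deadbeefdead"
# The key's value is the build UUID itself and it is written with ttl=60 seconds.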
diff --git a/buildman/manager/executor.py b/buildman/manager/executor.py
index e82d7a316..7921adbcc 100644
--- a/buildman/manager/executor.py
+++ b/buildman/manager/executor.py
@@ -29,532 +29,605 @@ from _init import ROOT_DIR
logger = logging.getLogger(__name__)
-ONE_HOUR = 60*60
+ONE_HOUR = 60 * 60
-_TAG_RETRY_COUNT = 3 # Number of times to retry adding tags.
-_TAG_RETRY_SLEEP = 2 # Number of seconds to wait between tag retries.
+_TAG_RETRY_COUNT = 3 # Number of times to retry adding tags.
+_TAG_RETRY_SLEEP = 2 # Number of seconds to wait between tag retries.
ENV = Environment(loader=FileSystemLoader(os.path.join(ROOT_DIR, "buildman/templates")))
-TEMPLATE = ENV.get_template('cloudconfig.yaml')
+TEMPLATE = ENV.get_template("cloudconfig.yaml")
CloudConfigContext().populate_jinja_environment(ENV)
+
class ExecutorException(Exception):
- """ Exception raised when there is a problem starting or stopping a builder.
+ """ Exception raised when there is a problem starting or stopping a builder.
"""
- pass
+
+ pass
class BuilderExecutor(object):
- def __init__(self, executor_config, manager_hostname):
- """ Interface which can be plugged into the EphemeralNodeManager to provide a strategy for
+ def __init__(self, executor_config, manager_hostname):
+ """ Interface which can be plugged into the EphemeralNodeManager to provide a strategy for
starting and stopping builders.
"""
- self.executor_config = executor_config
- self.manager_hostname = manager_hostname
+ self.executor_config = executor_config
+ self.manager_hostname = manager_hostname
- default_websocket_scheme = 'wss' if app.config['PREFERRED_URL_SCHEME'] == 'https' else 'ws'
- self.websocket_scheme = executor_config.get("WEBSOCKET_SCHEME", default_websocket_scheme)
+ default_websocket_scheme = (
+ "wss" if app.config["PREFERRED_URL_SCHEME"] == "https" else "ws"
+ )
+ self.websocket_scheme = executor_config.get(
+ "WEBSOCKET_SCHEME", default_websocket_scheme
+ )
- @property
- def name(self):
- """ Name returns the unique name for this executor. """
- return self.executor_config.get('NAME') or self.__class__.__name__
+ @property
+ def name(self):
+ """ Name returns the unique name for this executor. """
+ return self.executor_config.get("NAME") or self.__class__.__name__
- @property
- def setup_time(self):
- """ Returns the amount of time (in seconds) to wait for the execution to start for the build.
+ @property
+ def setup_time(self):
+ """ Returns the amount of time (in seconds) to wait for the execution to start for the build.
If None, the manager's default will be used.
"""
- return self.executor_config.get('SETUP_TIME')
+ return self.executor_config.get("SETUP_TIME")
- @coroutine
- def start_builder(self, realm, token, build_uuid):
- """ Create a builder with the specified config. Returns a unique id which can be used to manage
+ @coroutine
+ def start_builder(self, realm, token, build_uuid):
+ """ Create a builder with the specified config. Returns a unique id which can be used to manage
the builder.
"""
- raise NotImplementedError
+ raise NotImplementedError
- @coroutine
- def stop_builder(self, builder_id):
- """ Stop a builder which is currently running.
+ @coroutine
+ def stop_builder(self, builder_id):
+ """ Stop a builder which is currently running.
"""
- raise NotImplementedError
+ raise NotImplementedError
- def allowed_for_namespace(self, namespace):
- """ Returns true if this executor can be used for builds in the given namespace. """
+ def allowed_for_namespace(self, namespace):
+ """ Returns true if this executor can be used for builds in the given namespace. """
- # Check for an explicit namespace whitelist.
- namespace_whitelist = self.executor_config.get('NAMESPACE_WHITELIST')
- if namespace_whitelist is not None and namespace in namespace_whitelist:
- return True
+ # Check for an explicit namespace whitelist.
+ namespace_whitelist = self.executor_config.get("NAMESPACE_WHITELIST")
+ if namespace_whitelist is not None and namespace in namespace_whitelist:
+ return True
- # Check for a staged rollout percentage. If found, we hash the namespace and, if it is found
- # in the first X% of the character space, we allow this executor to be used.
- staged_rollout = self.executor_config.get('STAGED_ROLLOUT')
- if staged_rollout is not None:
- bucket = int(hashlib.sha256(namespace).hexdigest()[-2:], 16)
- return bucket < (256 * staged_rollout)
+ # Check for a staged rollout percentage. If found, we hash the namespace and, if it is found
+ # in the first X% of the character space, we allow this executor to be used.
+ staged_rollout = self.executor_config.get("STAGED_ROLLOUT")
+ if staged_rollout is not None:
+ bucket = int(hashlib.sha256(namespace).hexdigest()[-2:], 16)
+ return bucket < (256 * staged_rollout)
- # If there are no restrictions in place, we are free to use this executor.
- return staged_rollout is None and namespace_whitelist is None
+ # If there are no restrictions in place, we are free to use this executor.
+ return staged_rollout is None and namespace_whitelist is None
- @property
- def minimum_retry_threshold(self):
- """ Returns the minimum number of retries required for this executor to be used or 0 if
+ @property
+ def minimum_retry_threshold(self):
+ """ Returns the minimum number of retries required for this executor to be used or 0 if
none. """
- return self.executor_config.get('MINIMUM_RETRY_THRESHOLD', 0)
+ return self.executor_config.get("MINIMUM_RETRY_THRESHOLD", 0)
- def generate_cloud_config(self, realm, token, build_uuid, coreos_channel,
- manager_hostname, quay_username=None,
- quay_password=None):
- if quay_username is None:
- quay_username = self.executor_config['QUAY_USERNAME']
+ def generate_cloud_config(
+ self,
+ realm,
+ token,
+ build_uuid,
+ coreos_channel,
+ manager_hostname,
+ quay_username=None,
+ quay_password=None,
+ ):
+ if quay_username is None:
+ quay_username = self.executor_config["QUAY_USERNAME"]
- if quay_password is None:
- quay_password = self.executor_config['QUAY_PASSWORD']
+ if quay_password is None:
+ quay_password = self.executor_config["QUAY_PASSWORD"]
- return TEMPLATE.render(
- realm=realm,
- token=token,
- build_uuid=build_uuid,
- quay_username=quay_username,
- quay_password=quay_password,
- manager_hostname=manager_hostname,
- websocket_scheme=self.websocket_scheme,
- coreos_channel=coreos_channel,
- worker_image=self.executor_config.get('WORKER_IMAGE', 'quay.io/coreos/registry-build-worker'),
- worker_tag=self.executor_config['WORKER_TAG'],
- logentries_token=self.executor_config.get('LOGENTRIES_TOKEN', None),
- volume_size=self.executor_config.get('VOLUME_SIZE', '42G'),
- max_lifetime_s=self.executor_config.get('MAX_LIFETIME_S', 10800),
- ssh_authorized_keys=self.executor_config.get('SSH_AUTHORIZED_KEYS', []),
- )
+ return TEMPLATE.render(
+ realm=realm,
+ token=token,
+ build_uuid=build_uuid,
+ quay_username=quay_username,
+ quay_password=quay_password,
+ manager_hostname=manager_hostname,
+ websocket_scheme=self.websocket_scheme,
+ coreos_channel=coreos_channel,
+ worker_image=self.executor_config.get(
+ "WORKER_IMAGE", "quay.io/coreos/registry-build-worker"
+ ),
+ worker_tag=self.executor_config["WORKER_TAG"],
+ logentries_token=self.executor_config.get("LOGENTRIES_TOKEN", None),
+ volume_size=self.executor_config.get("VOLUME_SIZE", "42G"),
+ max_lifetime_s=self.executor_config.get("MAX_LIFETIME_S", 10800),
+ ssh_authorized_keys=self.executor_config.get("SSH_AUTHORIZED_KEYS", []),
+ )
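allowed_for_namespace above implements the staged rollout by hashing the namespace: the last two hex digits of sha256(namespace) give a bucket in 0..255, and the executor is used when that bucket falls below 256 * STAGED_ROLLOUT. A worked example follows; the namespace and rollout fraction are made up, and the call mirrors the Python 2 form used above (Python 3 would need namespace.encode()):

# Worked example of the staged-rollout bucket check in allowed_for_namespace.
import hashlib

namespace = "exampleorg"   # hypothetical namespace
staged_rollout = 0.25      # allow this executor for roughly 25% of namespaces

bucket = int(hashlib.sha256(namespace).hexdigest()[-2:], 16)  # value in 0..255
allowed = bucket < (256 * staged_rollout)                     # True only for buckets 0..63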
class EC2Executor(BuilderExecutor):
- """ Implementation of BuilderExecutor which uses libcloud to start machines on a variety of cloud
+ """ Implementation of BuilderExecutor which uses libcloud to start machines on a variety of cloud
providers.
"""
- COREOS_STACK_URL = 'http://%s.release.core-os.net/amd64-usr/current/coreos_production_ami_hvm.txt'
- def __init__(self, *args, **kwargs):
- self._loop = get_event_loop()
- super(EC2Executor, self).__init__(*args, **kwargs)
-
- def _get_conn(self):
- """ Creates an ec2 connection which can be used to manage instances.
- """
- return AsyncWrapper(boto.ec2.connect_to_region(
- self.executor_config['EC2_REGION'],
- aws_access_key_id=self.executor_config['AWS_ACCESS_KEY'],
- aws_secret_access_key=self.executor_config['AWS_SECRET_KEY'],
- ))
-
- @classmethod
- @cachetools.func.ttl_cache(ttl=ONE_HOUR)
- def _get_coreos_ami(cls, ec2_region, coreos_channel):
- """ Retrieve the CoreOS AMI id from the canonical listing.
- """
- stack_list_string = requests.get(EC2Executor.COREOS_STACK_URL % coreos_channel).text
- stack_amis = dict([stack.split('=') for stack in stack_list_string.split('|')])
- return stack_amis[ec2_region]
-
- @coroutine
- @duration_collector_async(metric_queue.builder_time_to_start, ['ec2'])
- def start_builder(self, realm, token, build_uuid):
- region = self.executor_config['EC2_REGION']
- channel = self.executor_config.get('COREOS_CHANNEL', 'stable')
-
- coreos_ami = self.executor_config.get('COREOS_AMI', None)
- if coreos_ami is None:
- get_ami_callable = partial(self._get_coreos_ami, region, channel)
- coreos_ami = yield From(self._loop.run_in_executor(None, get_ami_callable))
-
- user_data = self.generate_cloud_config(realm, token, build_uuid, channel, self.manager_hostname)
- logger.debug('Generated cloud config for build %s: %s', build_uuid, user_data)
-
- ec2_conn = self._get_conn()
-
- ssd_root_ebs = boto.ec2.blockdevicemapping.BlockDeviceType(
- size=int(self.executor_config.get('BLOCK_DEVICE_SIZE', 48)),
- volume_type='gp2',
- delete_on_termination=True,
+ COREOS_STACK_URL = (
+ "http://%s.release.core-os.net/amd64-usr/current/coreos_production_ami_hvm.txt"
)
- block_devices = boto.ec2.blockdevicemapping.BlockDeviceMapping()
- block_devices['/dev/xvda'] = ssd_root_ebs
- interfaces = None
- if self.executor_config.get('EC2_VPC_SUBNET_ID', None) is not None:
- interface = boto.ec2.networkinterface.NetworkInterfaceSpecification(
- subnet_id=self.executor_config['EC2_VPC_SUBNET_ID'],
- groups=self.executor_config['EC2_SECURITY_GROUP_IDS'],
- associate_public_ip_address=True,
- )
- interfaces = boto.ec2.networkinterface.NetworkInterfaceCollection(interface)
+ def __init__(self, *args, **kwargs):
+ self._loop = get_event_loop()
+ super(EC2Executor, self).__init__(*args, **kwargs)
- try:
- reservation = yield From(ec2_conn.run_instances(
- coreos_ami,
- instance_type=self.executor_config['EC2_INSTANCE_TYPE'],
- key_name=self.executor_config.get('EC2_KEY_NAME', None),
- user_data=user_data,
- instance_initiated_shutdown_behavior='terminate',
- block_device_map=block_devices,
- network_interfaces=interfaces,
- ))
- except boto.exception.EC2ResponseError as ec2e:
- logger.exception('Unable to spawn builder instance')
- metric_queue.ephemeral_build_worker_failure.Inc()
- raise ec2e
+ def _get_conn(self):
+ """ Creates an ec2 connection which can be used to manage instances.
+ """
+ return AsyncWrapper(
+ boto.ec2.connect_to_region(
+ self.executor_config["EC2_REGION"],
+ aws_access_key_id=self.executor_config["AWS_ACCESS_KEY"],
+ aws_secret_access_key=self.executor_config["AWS_SECRET_KEY"],
+ )
+ )
- if not reservation.instances:
- raise ExecutorException('Unable to spawn builder instance.')
- elif len(reservation.instances) != 1:
- raise ExecutorException('EC2 started wrong number of instances!')
+ @classmethod
+ @cachetools.func.ttl_cache(ttl=ONE_HOUR)
+ def _get_coreos_ami(cls, ec2_region, coreos_channel):
+ """ Retrieve the CoreOS AMI id from the canonical listing.
+ """
+ stack_list_string = requests.get(
+ EC2Executor.COREOS_STACK_URL % coreos_channel
+ ).text
+ stack_amis = dict([stack.split("=") for stack in stack_list_string.split("|")])
+ return stack_amis[ec2_region]
- launched = AsyncWrapper(reservation.instances[0])
+ @coroutine
+ @duration_collector_async(metric_queue.builder_time_to_start, ["ec2"])
+ def start_builder(self, realm, token, build_uuid):
+ region = self.executor_config["EC2_REGION"]
+ channel = self.executor_config.get("COREOS_CHANNEL", "stable")
- # Sleep a few seconds to wait for AWS to spawn the instance.
- yield From(trollius.sleep(_TAG_RETRY_SLEEP))
+ coreos_ami = self.executor_config.get("COREOS_AMI", None)
+ if coreos_ami is None:
+ get_ami_callable = partial(self._get_coreos_ami, region, channel)
+ coreos_ami = yield From(self._loop.run_in_executor(None, get_ami_callable))
- # Tag the instance with its metadata.
- for i in range(0, _TAG_RETRY_COUNT):
- try:
- yield From(launched.add_tags({
- 'Name': 'Quay Ephemeral Builder',
- 'Realm': realm,
- 'Token': token,
- 'BuildUUID': build_uuid,
- }))
- except boto.exception.EC2ResponseError as ec2e:
- if ec2e.error_code == 'InvalidInstanceID.NotFound':
- if i < _TAG_RETRY_COUNT - 1:
- logger.warning('Failed to write EC2 tags for instance %s for build %s (attempt #%s)',
- launched.id, build_uuid, i)
- yield From(trollius.sleep(_TAG_RETRY_SLEEP))
- continue
+ user_data = self.generate_cloud_config(
+ realm, token, build_uuid, channel, self.manager_hostname
+ )
+ logger.debug("Generated cloud config for build %s: %s", build_uuid, user_data)
- raise ExecutorException('Unable to find builder instance.')
+ ec2_conn = self._get_conn()
- logger.exception('Failed to write EC2 tags (attempt #%s)', i)
+ ssd_root_ebs = boto.ec2.blockdevicemapping.BlockDeviceType(
+ size=int(self.executor_config.get("BLOCK_DEVICE_SIZE", 48)),
+ volume_type="gp2",
+ delete_on_termination=True,
+ )
+ block_devices = boto.ec2.blockdevicemapping.BlockDeviceMapping()
+ block_devices["/dev/xvda"] = ssd_root_ebs
- logger.debug('Machine with ID %s started for build %s', launched.id, build_uuid)
- raise Return(launched.id)
+ interfaces = None
+ if self.executor_config.get("EC2_VPC_SUBNET_ID", None) is not None:
+ interface = boto.ec2.networkinterface.NetworkInterfaceSpecification(
+ subnet_id=self.executor_config["EC2_VPC_SUBNET_ID"],
+ groups=self.executor_config["EC2_SECURITY_GROUP_IDS"],
+ associate_public_ip_address=True,
+ )
+ interfaces = boto.ec2.networkinterface.NetworkInterfaceCollection(interface)
- @coroutine
- def stop_builder(self, builder_id):
- try:
- ec2_conn = self._get_conn()
- terminated_instances = yield From(ec2_conn.terminate_instances([builder_id]))
- except boto.exception.EC2ResponseError as ec2e:
- if ec2e.error_code == 'InvalidInstanceID.NotFound':
- logger.debug('Instance %s already terminated', builder_id)
- return
+ try:
+ reservation = yield From(
+ ec2_conn.run_instances(
+ coreos_ami,
+ instance_type=self.executor_config["EC2_INSTANCE_TYPE"],
+ key_name=self.executor_config.get("EC2_KEY_NAME", None),
+ user_data=user_data,
+ instance_initiated_shutdown_behavior="terminate",
+ block_device_map=block_devices,
+ network_interfaces=interfaces,
+ )
+ )
+ except boto.exception.EC2ResponseError as ec2e:
+ logger.exception("Unable to spawn builder instance")
+ metric_queue.ephemeral_build_worker_failure.Inc()
+ raise ec2e
- logger.exception('Exception when trying to terminate instance %s', builder_id)
- raise
+ if not reservation.instances:
+ raise ExecutorException("Unable to spawn builder instance.")
+ elif len(reservation.instances) != 1:
+ raise ExecutorException("EC2 started wrong number of instances!")
- if builder_id not in [si.id for si in terminated_instances]:
- raise ExecutorException('Unable to terminate instance: %s' % builder_id)
+ launched = AsyncWrapper(reservation.instances[0])
+
+ # Sleep a few seconds to wait for AWS to spawn the instance.
+ yield From(trollius.sleep(_TAG_RETRY_SLEEP))
+
+ # Tag the instance with its metadata.
+ for i in range(0, _TAG_RETRY_COUNT):
+ try:
+ yield From(
+ launched.add_tags(
+ {
+ "Name": "Quay Ephemeral Builder",
+ "Realm": realm,
+ "Token": token,
+ "BuildUUID": build_uuid,
+ }
+ )
+ )
+ except boto.exception.EC2ResponseError as ec2e:
+ if ec2e.error_code == "InvalidInstanceID.NotFound":
+ if i < _TAG_RETRY_COUNT - 1:
+ logger.warning(
+ "Failed to write EC2 tags for instance %s for build %s (attempt #%s)",
+ launched.id,
+ build_uuid,
+ i,
+ )
+ yield From(trollius.sleep(_TAG_RETRY_SLEEP))
+ continue
+
+ raise ExecutorException("Unable to find builder instance.")
+
+ logger.exception("Failed to write EC2 tags (attempt #%s)", i)
+
+ logger.debug("Machine with ID %s started for build %s", launched.id, build_uuid)
+ raise Return(launched.id)
+
+ @coroutine
+ def stop_builder(self, builder_id):
+ try:
+ ec2_conn = self._get_conn()
+ terminated_instances = yield From(
+ ec2_conn.terminate_instances([builder_id])
+ )
+ except boto.exception.EC2ResponseError as ec2e:
+ if ec2e.error_code == "InvalidInstanceID.NotFound":
+ logger.debug("Instance %s already terminated", builder_id)
+ return
+
+ logger.exception(
+ "Exception when trying to terminate instance %s", builder_id
+ )
+ raise
+
+ if builder_id not in [si.id for si in terminated_instances]:
+ raise ExecutorException("Unable to terminate instance: %s" % builder_id)
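For reference, the EC2 executor is configured entirely through executor_config; the keys read above imply a configuration roughly like the following. All values are illustrative placeholders, not recommended settings:

# Illustrative EC2 executor configuration, assembled from the keys referenced above.
EC2_EXECUTOR_CONFIG = {
    "EC2_REGION": "us-east-1",
    "AWS_ACCESS_KEY": "<access-key-id>",
    "AWS_SECRET_KEY": "<secret-access-key>",
    "EC2_INSTANCE_TYPE": "m3.medium",  # placeholder instance type
    "EC2_KEY_NAME": None,              # optional SSH key pair
    "EC2_VPC_SUBNET_ID": None,         # set to launch into a VPC subnet with a public IP
    "EC2_SECURITY_GROUP_IDS": [],      # used together with EC2_VPC_SUBNET_ID
    "COREOS_CHANNEL": "stable",
    "COREOS_AMI": None,                # None resolves the AMI from the CoreOS release listing
    "BLOCK_DEVICE_SIZE": 48,           # root volume size passed to BlockDeviceType
    "QUAY_USERNAME": "<builder-robot>",
    "QUAY_PASSWORD": "<builder-robot-token>",
    "WORKER_TAG": "<worker-image-tag>",
}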
class PopenExecutor(BuilderExecutor):
- """ Implementation of BuilderExecutor which uses Popen to fork a quay-builder process.
+ """ Implementation of BuilderExecutor which uses Popen to fork a quay-builder process.
"""
- def __init__(self, executor_config, manager_hostname):
- self._jobs = {}
- super(PopenExecutor, self).__init__(executor_config, manager_hostname)
+ def __init__(self, executor_config, manager_hostname):
+ self._jobs = {}
- """ Executor which uses Popen to fork a quay-builder process.
+ super(PopenExecutor, self).__init__(executor_config, manager_hostname)
+
+ """ Executor which uses Popen to fork a quay-builder process.
"""
- @coroutine
- @duration_collector_async(metric_queue.builder_time_to_start, ['fork'])
- def start_builder(self, realm, token, build_uuid):
- # Now start a machine for this job, adding the machine id to the etcd information
- logger.debug('Forking process for build')
- ws_host = os.environ.get("BUILDMAN_WS_HOST", "localhost")
- ws_port = os.environ.get("BUILDMAN_WS_PORT", "8787")
- builder_env = {
- 'TOKEN': token,
- 'REALM': realm,
- 'ENDPOINT': 'ws://%s:%s' % (ws_host, ws_port),
- 'DOCKER_TLS_VERIFY': os.environ.get('DOCKER_TLS_VERIFY', ''),
- 'DOCKER_CERT_PATH': os.environ.get('DOCKER_CERT_PATH', ''),
- 'DOCKER_HOST': os.environ.get('DOCKER_HOST', ''),
- 'PATH': "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
- }
+ @coroutine
+ @duration_collector_async(metric_queue.builder_time_to_start, ["fork"])
+ def start_builder(self, realm, token, build_uuid):
+ # Now start a machine for this job, adding the machine id to the etcd information
+ logger.debug("Forking process for build")
- logpipe = LogPipe(logging.INFO)
- spawned = subprocess.Popen(os.environ.get('BUILDER_BINARY_LOCATION',
- '/usr/local/bin/quay-builder'),
- stdout=logpipe,
- stderr=logpipe,
- env=builder_env)
+ ws_host = os.environ.get("BUILDMAN_WS_HOST", "localhost")
+ ws_port = os.environ.get("BUILDMAN_WS_PORT", "8787")
+ builder_env = {
+ "TOKEN": token,
+ "REALM": realm,
+ "ENDPOINT": "ws://%s:%s" % (ws_host, ws_port),
+ "DOCKER_TLS_VERIFY": os.environ.get("DOCKER_TLS_VERIFY", ""),
+ "DOCKER_CERT_PATH": os.environ.get("DOCKER_CERT_PATH", ""),
+ "DOCKER_HOST": os.environ.get("DOCKER_HOST", ""),
+ "PATH": "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
+ }
- builder_id = str(uuid.uuid4())
- self._jobs[builder_id] = (spawned, logpipe)
- logger.debug('Builder spawned with id: %s', builder_id)
- raise Return(builder_id)
+ logpipe = LogPipe(logging.INFO)
+ spawned = subprocess.Popen(
+ os.environ.get("BUILDER_BINARY_LOCATION", "/usr/local/bin/quay-builder"),
+ stdout=logpipe,
+ stderr=logpipe,
+ env=builder_env,
+ )
- @coroutine
- def stop_builder(self, builder_id):
- if builder_id not in self._jobs:
- raise ExecutorException('Builder id not being tracked by executor.')
+ builder_id = str(uuid.uuid4())
+ self._jobs[builder_id] = (spawned, logpipe)
+ logger.debug("Builder spawned with id: %s", builder_id)
+ raise Return(builder_id)
- logger.debug('Killing builder with id: %s', builder_id)
- spawned, logpipe = self._jobs[builder_id]
+ @coroutine
+ def stop_builder(self, builder_id):
+ if builder_id not in self._jobs:
+ raise ExecutorException("Builder id not being tracked by executor.")
- if spawned.poll() is None:
- spawned.kill()
- logpipe.close()
+ logger.debug("Killing builder with id: %s", builder_id)
+ spawned, logpipe = self._jobs[builder_id]
+
+ if spawned.poll() is None:
+ spawned.kill()
+ logpipe.close()
class KubernetesExecutor(BuilderExecutor):
- """ Executes build jobs by creating Kubernetes jobs which run a qemu-kvm virtual
+ """ Executes build jobs by creating Kubernetes jobs which run a qemu-kvm virtual
machine in a pod """
- def __init__(self, *args, **kwargs):
- super(KubernetesExecutor, self).__init__(*args, **kwargs)
- self._loop = get_event_loop()
- self.namespace = self.executor_config.get('BUILDER_NAMESPACE', 'builder')
- self.image = self.executor_config.get('BUILDER_VM_CONTAINER_IMAGE',
- 'quay.io/quay/quay-builder-qemu-coreos:stable')
- @coroutine
- def _request(self, method, path, **kwargs):
- request_options = dict(kwargs)
+ def __init__(self, *args, **kwargs):
+ super(KubernetesExecutor, self).__init__(*args, **kwargs)
+ self._loop = get_event_loop()
+ self.namespace = self.executor_config.get("BUILDER_NAMESPACE", "builder")
+ self.image = self.executor_config.get(
+ "BUILDER_VM_CONTAINER_IMAGE", "quay.io/quay/quay-builder-qemu-coreos:stable"
+ )
- tls_cert = self.executor_config.get('K8S_API_TLS_CERT')
- tls_key = self.executor_config.get('K8S_API_TLS_KEY')
- tls_ca = self.executor_config.get('K8S_API_TLS_CA')
- service_account_token = self.executor_config.get('SERVICE_ACCOUNT_TOKEN')
+ @coroutine
+ def _request(self, method, path, **kwargs):
+ request_options = dict(kwargs)
- if 'timeout' not in request_options:
- request_options['timeout'] = self.executor_config.get("K8S_API_TIMEOUT", 20)
+ tls_cert = self.executor_config.get("K8S_API_TLS_CERT")
+ tls_key = self.executor_config.get("K8S_API_TLS_KEY")
+ tls_ca = self.executor_config.get("K8S_API_TLS_CA")
+ service_account_token = self.executor_config.get("SERVICE_ACCOUNT_TOKEN")
- if service_account_token:
- scheme = 'https'
- request_options['headers'] = {'Authorization': 'Bearer ' + service_account_token}
- logger.debug('Using service account token for Kubernetes authentication')
- elif tls_cert and tls_key:
- scheme = 'https'
- request_options['cert'] = (tls_cert, tls_key)
- logger.debug('Using tls certificate and key for Kubernetes authentication')
- if tls_ca:
- request_options['verify'] = tls_ca
- else:
- scheme = 'http'
+ if "timeout" not in request_options:
+ request_options["timeout"] = self.executor_config.get("K8S_API_TIMEOUT", 20)
- server = self.executor_config.get('K8S_API_SERVER', 'localhost:8080')
- url = '%s://%s%s' % (scheme, server, path)
+ if service_account_token:
+ scheme = "https"
+ request_options["headers"] = {
+ "Authorization": "Bearer " + service_account_token
+ }
+ logger.debug("Using service account token for Kubernetes authentication")
+ elif tls_cert and tls_key:
+ scheme = "https"
+ request_options["cert"] = (tls_cert, tls_key)
+ logger.debug("Using tls certificate and key for Kubernetes authentication")
+ if tls_ca:
+ request_options["verify"] = tls_ca
+ else:
+ scheme = "http"
- logger.debug('Executor config: %s', self.executor_config)
- logger.debug('Kubernetes request: %s %s: %s', method, url, request_options)
- res = requests.request(method, url, **request_options)
- logger.debug('Kubernetes response: %s: %s', res.status_code, res.text)
- raise Return(res)
+ server = self.executor_config.get("K8S_API_SERVER", "localhost:8080")
+ url = "%s://%s%s" % (scheme, server, path)
- def _jobs_path(self):
- return '/apis/batch/v1/namespaces/%s/jobs' % self.namespace
+ logger.debug("Executor config: %s", self.executor_config)
+ logger.debug("Kubernetes request: %s %s: %s", method, url, request_options)
+ res = requests.request(method, url, **request_options)
+ logger.debug("Kubernetes response: %s: %s", res.status_code, res.text)
+ raise Return(res)
- def _job_path(self, build_uuid):
- return '%s/%s' % (self._jobs_path(), build_uuid)
+ def _jobs_path(self):
+ return "/apis/batch/v1/namespaces/%s/jobs" % self.namespace
- def _kubernetes_distribution(self):
- return self.executor_config.get('KUBERNETES_DISTRIBUTION', 'basic').lower()
+ def _job_path(self, build_uuid):
+ return "%s/%s" % (self._jobs_path(), build_uuid)
- def _is_basic_kubernetes_distribution(self):
- return self._kubernetes_distribution() == 'basic'
+ def _kubernetes_distribution(self):
+ return self.executor_config.get("KUBERNETES_DISTRIBUTION", "basic").lower()
- def _is_openshift_kubernetes_distribution(self):
- return self._kubernetes_distribution() == 'openshift'
+ def _is_basic_kubernetes_distribution(self):
+ return self._kubernetes_distribution() == "basic"
- def _build_job_container_resources(self):
- # Minimum acceptable free resources for this container to "fit" in a quota
- # These may be lower than the absolute limits if the cluster is knowingly
- # oversubscribed by some amount.
- container_requests = {
- 'memory' : self.executor_config.get('CONTAINER_MEMORY_REQUEST', '3968Mi'),
- }
+ def _is_openshift_kubernetes_distribution(self):
+ return self._kubernetes_distribution() == "openshift"
- container_limits = {
- 'memory' : self.executor_config.get('CONTAINER_MEMORY_LIMITS', '5120Mi'),
- 'cpu' : self.executor_config.get('CONTAINER_CPU_LIMITS', '1000m'),
- }
+ def _build_job_container_resources(self):
+ # Minimum acceptable free resources for this container to "fit" in a quota
+ # These may be lower than the absolute limits if the cluster is knowingly
+ # oversubscribed by some amount.
+ container_requests = {
+ "memory": self.executor_config.get("CONTAINER_MEMORY_REQUEST", "3968Mi")
+ }
- resources = {
- 'requests': container_requests,
- }
+ container_limits = {
+ "memory": self.executor_config.get("CONTAINER_MEMORY_LIMITS", "5120Mi"),
+ "cpu": self.executor_config.get("CONTAINER_CPU_LIMITS", "1000m"),
+ }
- if self._is_openshift_kubernetes_distribution():
- resources['requests']['cpu'] = self.executor_config.get('CONTAINER_CPU_REQUEST', '500m')
- resources['limits'] = container_limits
+ resources = {"requests": container_requests}
- return resources
+ if self._is_openshift_kubernetes_distribution():
+ resources["requests"]["cpu"] = self.executor_config.get(
+ "CONTAINER_CPU_REQUEST", "500m"
+ )
+ resources["limits"] = container_limits
- def _build_job_containers(self, user_data):
- vm_memory_limit = self.executor_config.get('VM_MEMORY_LIMIT', '4G')
- vm_volume_size = self.executor_config.get('VOLUME_SIZE', '32G')
+ return resources
- container = {
- 'name': 'builder',
- 'imagePullPolicy': 'IfNotPresent',
- 'image': self.image,
- 'securityContext': {'privileged': True},
- 'env': [
- {'name': 'USERDATA', 'value': user_data},
- {'name': 'VM_MEMORY', 'value': vm_memory_limit},
- {'name': 'VM_VOLUME_SIZE', 'value': vm_volume_size},
- ],
- 'resources': self._build_job_container_resources(),
- }
+ def _build_job_containers(self, user_data):
+ vm_memory_limit = self.executor_config.get("VM_MEMORY_LIMIT", "4G")
+ vm_volume_size = self.executor_config.get("VOLUME_SIZE", "32G")
- if self._is_basic_kubernetes_distribution():
- container['volumeMounts'] = [{'name': 'secrets-mask','mountPath': '/var/run/secrets/kubernetes.io/serviceaccount'}]
+ container = {
+ "name": "builder",
+ "imagePullPolicy": "IfNotPresent",
+ "image": self.image,
+ "securityContext": {"privileged": True},
+ "env": [
+ {"name": "USERDATA", "value": user_data},
+ {"name": "VM_MEMORY", "value": vm_memory_limit},
+ {"name": "VM_VOLUME_SIZE", "value": vm_volume_size},
+ ],
+ "resources": self._build_job_container_resources(),
+ }
- return container
+ if self._is_basic_kubernetes_distribution():
+ container["volumeMounts"] = [
+ {
+ "name": "secrets-mask",
+ "mountPath": "/var/run/secrets/kubernetes.io/serviceaccount",
+ }
+ ]
- def _job_resource(self, build_uuid, user_data, coreos_channel='stable'):
- image_pull_secret_name = self.executor_config.get('IMAGE_PULL_SECRET_NAME', 'builder')
- service_account = self.executor_config.get('SERVICE_ACCOUNT_NAME', 'quay-builder-sa')
- node_selector_label_key = self.executor_config.get('NODE_SELECTOR_LABEL_KEY', 'beta.kubernetes.io/instance-type')
- node_selector_label_value = self.executor_config.get('NODE_SELECTOR_LABEL_VALUE', '')
+ return container
- node_selector = {
- node_selector_label_key : node_selector_label_value
- }
+ def _job_resource(self, build_uuid, user_data, coreos_channel="stable"):
+ image_pull_secret_name = self.executor_config.get(
+ "IMAGE_PULL_SECRET_NAME", "builder"
+ )
+ service_account = self.executor_config.get(
+ "SERVICE_ACCOUNT_NAME", "quay-builder-sa"
+ )
+ node_selector_label_key = self.executor_config.get(
+ "NODE_SELECTOR_LABEL_KEY", "beta.kubernetes.io/instance-type"
+ )
+ node_selector_label_value = self.executor_config.get(
+ "NODE_SELECTOR_LABEL_VALUE", ""
+ )
- release_sha = release.GIT_HEAD or 'none'
- if ' ' in release_sha:
- release_sha = 'HEAD'
+ node_selector = {node_selector_label_key: node_selector_label_value}
- job_resource = {
- 'apiVersion': 'batch/v1',
- 'kind': 'Job',
- 'metadata': {
- 'namespace': self.namespace,
- 'generateName': build_uuid + '-',
- 'labels': {
- 'build': build_uuid,
- 'time': datetime.datetime.now().strftime('%Y-%m-%d-%H'),
- 'manager': socket.gethostname(),
- 'quay-sha': release_sha,
- },
- },
- 'spec' : {
- 'activeDeadlineSeconds': self.executor_config.get('MAXIMUM_JOB_TIME', 7200),
- 'template': {
- 'metadata': {
- 'labels': {
- 'build': build_uuid,
- 'time': datetime.datetime.now().strftime('%Y-%m-%d-%H'),
- 'manager': socket.gethostname(),
- 'quay-sha': release_sha,
+ release_sha = release.GIT_HEAD or "none"
+ if " " in release_sha:
+ release_sha = "HEAD"
+
+ job_resource = {
+ "apiVersion": "batch/v1",
+ "kind": "Job",
+ "metadata": {
+ "namespace": self.namespace,
+ "generateName": build_uuid + "-",
+ "labels": {
+ "build": build_uuid,
+ "time": datetime.datetime.now().strftime("%Y-%m-%d-%H"),
+ "manager": socket.gethostname(),
+ "quay-sha": release_sha,
+ },
},
- },
- 'spec': {
- 'imagePullSecrets': [{ 'name': image_pull_secret_name }],
- 'restartPolicy': 'Never',
- 'dnsPolicy': 'Default',
- 'containers': [self._build_job_containers(user_data)],
- },
- },
- },
- }
+ "spec": {
+ "activeDeadlineSeconds": self.executor_config.get(
+ "MAXIMUM_JOB_TIME", 7200
+ ),
+ "template": {
+ "metadata": {
+ "labels": {
+ "build": build_uuid,
+ "time": datetime.datetime.now().strftime("%Y-%m-%d-%H"),
+ "manager": socket.gethostname(),
+ "quay-sha": release_sha,
+ }
+ },
+ "spec": {
+ "imagePullSecrets": [{"name": image_pull_secret_name}],
+ "restartPolicy": "Never",
+ "dnsPolicy": "Default",
+ "containers": [self._build_job_containers(user_data)],
+ },
+ },
+ },
+ }
- if self._is_openshift_kubernetes_distribution():
- # Setting `automountServiceAccountToken` to false will prevent automounting API credentials for a service account.
- job_resource['spec']['template']['spec']['automountServiceAccountToken'] = False
+ if self._is_openshift_kubernetes_distribution():
+ # Setting `automountServiceAccountToken` to false will prevent automounting API credentials for a service account.
+ job_resource["spec"]["template"]["spec"][
+ "automountServiceAccountToken"
+ ] = False
- # Use dedicated service account that has no authorization to any resources.
- job_resource['spec']['template']['spec']['serviceAccount'] = service_account
+ # Use dedicated service account that has no authorization to any resources.
+ job_resource["spec"]["template"]["spec"]["serviceAccount"] = service_account
- # Setting `enableServiceLinks` to false prevents information about other services from being injected into pod's
- # environment variables. Pod has no visibility into other services on the cluster.
- job_resource['spec']['template']['spec']['enableServiceLinks'] = False
+ # Setting `enableServiceLinks` to false prevents information about other services from being injected into pod's
+ # environment variables. Pod has no visibility into other services on the cluster.
+ job_resource["spec"]["template"]["spec"]["enableServiceLinks"] = False
- if node_selector_label_value.strip() != '':
- job_resource['spec']['template']['spec']['nodeSelector'] = node_selector
+ if node_selector_label_value.strip() != "":
+ job_resource["spec"]["template"]["spec"]["nodeSelector"] = node_selector
- if self._is_basic_kubernetes_distribution():
- # This volume is a hack to mask the token for the namespace's
- # default service account, which is placed in a file mounted under
- # `/var/run/secrets/kubernetes.io/serviceaccount` in all pods.
- # There's currently no other way to just disable the service
- # account at either the pod or namespace level.
- #
- # https://github.com/kubernetes/kubernetes/issues/16779
- #
- job_resource['spec']['template']['spec']['volumes'] = [{'name': 'secrets-mask','emptyDir': {'medium': 'Memory'}}]
+ if self._is_basic_kubernetes_distribution():
+ # This volume is a hack to mask the token for the namespace's
+ # default service account, which is placed in a file mounted under
+ # `/var/run/secrets/kubernetes.io/serviceaccount` in all pods.
+ # There's currently no other way to just disable the service
+ # account at either the pod or namespace level.
+ #
+ # https://github.com/kubernetes/kubernetes/issues/16779
+ #
+ job_resource["spec"]["template"]["spec"]["volumes"] = [
+ {"name": "secrets-mask", "emptyDir": {"medium": "Memory"}}
+ ]
- return job_resource
+ return job_resource
- @coroutine
- @duration_collector_async(metric_queue.builder_time_to_start, ['k8s'])
- def start_builder(self, realm, token, build_uuid):
- # generate resource
- channel = self.executor_config.get('COREOS_CHANNEL', 'stable')
- user_data = self.generate_cloud_config(realm, token, build_uuid, channel, self.manager_hostname)
- resource = self._job_resource(build_uuid, user_data, channel)
- logger.debug('Using Kubernetes Distribution: %s', self._kubernetes_distribution())
- logger.debug('Generated kubernetes resource:\n%s', resource)
+ @coroutine
+ @duration_collector_async(metric_queue.builder_time_to_start, ["k8s"])
+ def start_builder(self, realm, token, build_uuid):
+ # generate resource
+ channel = self.executor_config.get("COREOS_CHANNEL", "stable")
+ user_data = self.generate_cloud_config(
+ realm, token, build_uuid, channel, self.manager_hostname
+ )
+ resource = self._job_resource(build_uuid, user_data, channel)
+ logger.debug(
+ "Using Kubernetes Distribution: %s", self._kubernetes_distribution()
+ )
+ logger.debug("Generated kubernetes resource:\n%s", resource)
- # schedule
- create_job = yield From(self._request('POST', self._jobs_path(), json=resource))
- if int(create_job.status_code / 100) != 2:
- raise ExecutorException('Failed to create job: %s: %s: %s' %
- (build_uuid, create_job.status_code, create_job.text))
+ # schedule
+ create_job = yield From(self._request("POST", self._jobs_path(), json=resource))
+ if int(create_job.status_code / 100) != 2:
+ raise ExecutorException(
+ "Failed to create job: %s: %s: %s"
+ % (build_uuid, create_job.status_code, create_job.text)
+ )
- job = create_job.json()
- raise Return(job['metadata']['name'])
+ job = create_job.json()
+ raise Return(job["metadata"]["name"])
- @coroutine
- def stop_builder(self, builder_id):
- pods_path = '/api/v1/namespaces/%s/pods' % self.namespace
+ @coroutine
+ def stop_builder(self, builder_id):
+ pods_path = "/api/v1/namespaces/%s/pods" % self.namespace
- # Delete the job itself.
- try:
- yield From(self._request('DELETE', self._job_path(builder_id)))
- except:
- logger.exception('Failed to send delete job call for job %s', builder_id)
+ # Delete the job itself.
+ try:
+ yield From(self._request("DELETE", self._job_path(builder_id)))
+ except:
+ logger.exception("Failed to send delete job call for job %s", builder_id)
- # Delete the pod(s) for the job.
- selectorString = "job-name=%s" % builder_id
- try:
- yield From(self._request('DELETE', pods_path, params=dict(labelSelector=selectorString)))
- except:
- logger.exception("Failed to send delete pod call for job %s", builder_id)
+ # Delete the pod(s) for the job.
+ selectorString = "job-name=%s" % builder_id
+ try:
+ yield From(
+ self._request(
+ "DELETE", pods_path, params=dict(labelSelector=selectorString)
+ )
+ )
+ except:
+ logger.exception("Failed to send delete pod call for job %s", builder_id)
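Similarly, the Kubernetes executor is driven by executor_config; the keys consulted above suggest a configuration along these lines. Values are placeholders or the defaults already visible in the code, not deployment advice:

# Illustrative Kubernetes executor configuration, based on the keys read above.
K8S_EXECUTOR_CONFIG = {
    "K8S_API_SERVER": "localhost:8080",
    "K8S_API_TIMEOUT": 20,
    "SERVICE_ACCOUNT_TOKEN": None,       # when set, bearer-token auth over https is used
    "K8S_API_TLS_CERT": None,            # client cert/key pair is the alternative https auth path
    "K8S_API_TLS_KEY": None,
    "K8S_API_TLS_CA": None,
    "BUILDER_NAMESPACE": "builder",
    "BUILDER_VM_CONTAINER_IMAGE": "quay.io/quay/quay-builder-qemu-coreos:stable",
    "KUBERNETES_DISTRIBUTION": "basic",  # or "openshift"
    "CONTAINER_MEMORY_REQUEST": "3968Mi",
    "CONTAINER_MEMORY_LIMITS": "5120Mi",
    "CONTAINER_CPU_LIMITS": "1000m",
    "VM_MEMORY_LIMIT": "4G",
    "VOLUME_SIZE": "32G",
    "MAXIMUM_JOB_TIME": 7200,            # becomes activeDeadlineSeconds on the Job
    "NODE_SELECTOR_LABEL_KEY": "beta.kubernetes.io/instance-type",
    "NODE_SELECTOR_LABEL_VALUE": "",     # empty string disables the nodeSelector
    "IMAGE_PULL_SECRET_NAME": "builder",
    "SERVICE_ACCOUNT_NAME": "quay-builder-sa",
}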
class LogPipe(threading.Thread):
- """ Adapted from http://codereview.stackexchange.com/a/17959
+ """ Adapted from http://codereview.stackexchange.com/a/17959
"""
- def __init__(self, level):
- """Setup the object with a logger and a loglevel
+
+ def __init__(self, level):
+ """Setup the object with a logger and a loglevel
and start the thread
"""
- threading.Thread.__init__(self)
- self.daemon = False
- self.level = level
- self.fd_read, self.fd_write = os.pipe()
- self.pipe_reader = os.fdopen(self.fd_read)
- self.start()
+ threading.Thread.__init__(self)
+ self.daemon = False
+ self.level = level
+ self.fd_read, self.fd_write = os.pipe()
+ self.pipe_reader = os.fdopen(self.fd_read)
+ self.start()
- def fileno(self):
- """Return the write file descriptor of the pipe
+ def fileno(self):
+ """Return the write file descriptor of the pipe
"""
- return self.fd_write
+ return self.fd_write
- def run(self):
- """Run the thread, logging everything.
+ def run(self):
+ """Run the thread, logging everything.
"""
- for line in iter(self.pipe_reader.readline, ''):
- logging.log(self.level, line.strip('\n'))
+ for line in iter(self.pipe_reader.readline, ""):
+ logging.log(self.level, line.strip("\n"))
- self.pipe_reader.close()
+ self.pipe_reader.close()
- def close(self):
- """Close the write end of the pipe.
+ def close(self):
+ """Close the write end of the pipe.
"""
- os.close(self.fd_write)
+ os.close(self.fd_write)
diff --git a/buildman/manager/noop_canceller.py b/buildman/manager/noop_canceller.py
index 2adf17ad7..51c023fcc 100644
--- a/buildman/manager/noop_canceller.py
+++ b/buildman/manager/noop_canceller.py
@@ -1,8 +1,9 @@
class NoopCanceller(object):
- """ A class that can not cancel a build """
- def __init__(self, config=None):
- pass
+ """ A class that can not cancel a build """
- def try_cancel_build(self, uuid):
- """ Does nothing and fails to cancel build. """
- return False
+ def __init__(self, config=None):
+ pass
+
+ def try_cancel_build(self, uuid):
+ """ Does nothing and fails to cancel build. """
+ return False
diff --git a/buildman/manager/orchestrator_canceller.py b/buildman/manager/orchestrator_canceller.py
index f3f821d5e..64ae4f8d7 100644
--- a/buildman/manager/orchestrator_canceller.py
+++ b/buildman/manager/orchestrator_canceller.py
@@ -7,20 +7,23 @@ from util import slash_join
logger = logging.getLogger(__name__)
-CANCEL_PREFIX = 'cancel/'
+CANCEL_PREFIX = "cancel/"
class OrchestratorCanceller(object):
- """ An asynchronous way to cancel a build with any Orchestrator. """
- def __init__(self, config):
- self._orchestrator = orchestrator_from_config(config, canceller_only=True)
+ """ An asynchronous way to cancel a build with any Orchestrator. """
- def try_cancel_build(self, build_uuid):
- logger.info('Cancelling build %s', build_uuid)
- cancel_key = slash_join(CANCEL_PREFIX, build_uuid)
- try:
- self._orchestrator.set_key_sync(cancel_key, build_uuid, expiration=60)
- return True
- except OrchestratorError:
- logger.exception('Failed to write cancel action to redis with uuid %s', build_uuid)
- return False
+ def __init__(self, config):
+ self._orchestrator = orchestrator_from_config(config, canceller_only=True)
+
+ def try_cancel_build(self, build_uuid):
+ logger.info("Cancelling build %s", build_uuid)
+ cancel_key = slash_join(CANCEL_PREFIX, build_uuid)
+ try:
+ self._orchestrator.set_key_sync(cancel_key, build_uuid, expiration=60)
+ return True
+ except OrchestratorError:
+ logger.exception(
+ "Failed to write cancel action to redis with uuid %s", build_uuid
+ )
+ return False
diff --git a/buildman/test/test_buildman.py b/buildman/test/test_buildman.py
index 49b9a20fc..ec6192ae2 100644
--- a/buildman/test/test_buildman.py
+++ b/buildman/test/test_buildman.py
@@ -9,8 +9,7 @@ from trollius import coroutine, get_event_loop, From, Future, Return
from app import metric_queue
from buildman.asyncutil import AsyncWrapper
from buildman.component.buildcomponent import BuildComponent
-from buildman.manager.ephemeral import (EphemeralBuilderManager, REALM_PREFIX,
- JOB_PREFIX)
+from buildman.manager.ephemeral import EphemeralBuilderManager, REALM_PREFIX, JOB_PREFIX
from buildman.manager.executor import BuilderExecutor, ExecutorException
from buildman.orchestrator import KeyEvent, KeyChange
from buildman.server import BuildJobResult
@@ -18,662 +17,767 @@ from util import slash_join
from util.metrics.metricqueue import duration_collector_async
-BUILD_UUID = 'deadbeef-dead-beef-dead-deadbeefdead'
-REALM_ID = '1234-realm'
+BUILD_UUID = "deadbeef-dead-beef-dead-deadbeefdead"
+REALM_ID = "1234-realm"
def async_test(f):
- def wrapper(*args, **kwargs):
- coro = coroutine(f)
- future = coro(*args, **kwargs)
- loop = get_event_loop()
- loop.run_until_complete(future)
- return wrapper
+ def wrapper(*args, **kwargs):
+ coro = coroutine(f)
+ future = coro(*args, **kwargs)
+ loop = get_event_loop()
+ loop.run_until_complete(future)
+
+ return wrapper
class TestExecutor(BuilderExecutor):
- job_started = None
- job_stopped = None
+ job_started = None
+ job_stopped = None
- @coroutine
- @duration_collector_async(metric_queue.builder_time_to_start, labelvalues=["testlabel"])
- def start_builder(self, realm, token, build_uuid):
- self.job_started = str(uuid.uuid4())
- raise Return(self.job_started)
+ @coroutine
+ @duration_collector_async(
+ metric_queue.builder_time_to_start, labelvalues=["testlabel"]
+ )
+ def start_builder(self, realm, token, build_uuid):
+ self.job_started = str(uuid.uuid4())
+ raise Return(self.job_started)
- @coroutine
- def stop_builder(self, execution_id):
- self.job_stopped = execution_id
+ @coroutine
+ def stop_builder(self, execution_id):
+ self.job_stopped = execution_id
class BadExecutor(BuilderExecutor):
- @coroutine
- @duration_collector_async(metric_queue.builder_time_to_start, labelvalues=["testlabel"])
- def start_builder(self, realm, token, build_uuid):
- raise ExecutorException('raised on purpose!')
+ @coroutine
+ @duration_collector_async(
+ metric_queue.builder_time_to_start, labelvalues=["testlabel"]
+ )
+ def start_builder(self, realm, token, build_uuid):
+ raise ExecutorException("raised on purpose!")
class EphemeralBuilderTestCase(unittest.TestCase):
- def __init__(self, *args, **kwargs):
- self.etcd_client_mock = None
- super(EphemeralBuilderTestCase, self).__init__(*args, **kwargs)
+ def __init__(self, *args, **kwargs):
+ self.etcd_client_mock = None
+ super(EphemeralBuilderTestCase, self).__init__(*args, **kwargs)
- @staticmethod
- def _create_completed_future(result=None):
- def inner(*args, **kwargs):
- new_future = Future()
- new_future.set_result(result)
- return new_future
- return inner
+ @staticmethod
+ def _create_completed_future(result=None):
+ def inner(*args, **kwargs):
+ new_future = Future()
+ new_future.set_result(result)
+ return new_future
- def setUp(self):
- self._existing_executors = dict(EphemeralBuilderManager.EXECUTORS)
+ return inner
- def tearDown(self):
- EphemeralBuilderManager.EXECUTORS = self._existing_executors
+ def setUp(self):
+ self._existing_executors = dict(EphemeralBuilderManager.EXECUTORS)
- @coroutine
- def _register_component(self, realm_spec, build_component, token):
- raise Return('hello')
+ def tearDown(self):
+ EphemeralBuilderManager.EXECUTORS = self._existing_executors
- def _create_build_job(self, namespace='namespace', retries=3):
- mock_job = Mock()
- mock_job.job_details = {'build_uuid': BUILD_UUID}
- mock_job.job_item = {
- 'body': json.dumps(mock_job.job_details),
- 'id': 1,
- }
+ @coroutine
+ def _register_component(self, realm_spec, build_component, token):
+ raise Return("hello")
- mock_job.namespace = namespace
- mock_job.retries_remaining = retries
- mock_job.build_uuid = BUILD_UUID
- return mock_job
+ def _create_build_job(self, namespace="namespace", retries=3):
+ mock_job = Mock()
+ mock_job.job_details = {"build_uuid": BUILD_UUID}
+ mock_job.job_item = {"body": json.dumps(mock_job.job_details), "id": 1}
+
+ mock_job.namespace = namespace
+ mock_job.retries_remaining = retries
+ mock_job.build_uuid = BUILD_UUID
+ return mock_job
class TestEphemeralLifecycle(EphemeralBuilderTestCase):
- """ Tests the various lifecycles of the ephemeral builder and its interaction with etcd. """
+ """ Tests the various lifecycles of the ephemeral builder and its interaction with etcd. """
- def __init__(self, *args, **kwargs):
- super(TestEphemeralLifecycle, self).__init__(*args, **kwargs)
- self.etcd_client_mock = None
- self.test_executor = None
+ def __init__(self, *args, **kwargs):
+ super(TestEphemeralLifecycle, self).__init__(*args, **kwargs)
+ self.etcd_client_mock = None
+ self.test_executor = None
- def _create_completed_future(self, result=None):
- def inner(*args, **kwargs):
- new_future = Future()
- new_future.set_result(result)
- return new_future
- return inner
+ def _create_completed_future(self, result=None):
+ def inner(*args, **kwargs):
+ new_future = Future()
+ new_future.set_result(result)
+ return new_future
- def _create_mock_executor(self, *args, **kwargs):
- self.test_executor = Mock(spec=BuilderExecutor)
- self.test_executor.start_builder = Mock(side_effect=self._create_completed_future('123'))
- self.test_executor.stop_builder = Mock(side_effect=self._create_completed_future())
- self.test_executor.setup_time = 60
- self.test_executor.name = 'MockExecutor'
- self.test_executor.minimum_retry_threshold = 0
- return self.test_executor
+ return inner
- def setUp(self):
- super(TestEphemeralLifecycle, self).setUp()
+ def _create_mock_executor(self, *args, **kwargs):
+ self.test_executor = Mock(spec=BuilderExecutor)
+ self.test_executor.start_builder = Mock(
+ side_effect=self._create_completed_future("123")
+ )
+ self.test_executor.stop_builder = Mock(
+ side_effect=self._create_completed_future()
+ )
+ self.test_executor.setup_time = 60
+ self.test_executor.name = "MockExecutor"
+ self.test_executor.minimum_retry_threshold = 0
+ return self.test_executor
- EphemeralBuilderManager.EXECUTORS['test'] = self._create_mock_executor
+ def setUp(self):
+ super(TestEphemeralLifecycle, self).setUp()
- self.register_component_callback = Mock()
- self.unregister_component_callback = Mock()
- self.job_heartbeat_callback = Mock()
- self.job_complete_callback = AsyncWrapper(Mock())
+ EphemeralBuilderManager.EXECUTORS["test"] = self._create_mock_executor
- self.manager = EphemeralBuilderManager(
- self.register_component_callback,
- self.unregister_component_callback,
- self.job_heartbeat_callback,
- self.job_complete_callback,
- '127.0.0.1',
- 30,
- )
+ self.register_component_callback = Mock()
+ self.unregister_component_callback = Mock()
+ self.job_heartbeat_callback = Mock()
+ self.job_complete_callback = AsyncWrapper(Mock())
- self.manager.initialize({
- 'EXECUTOR': 'test',
- 'ORCHESTRATOR': {'MEM_CONFIG': None},
- })
+ self.manager = EphemeralBuilderManager(
+ self.register_component_callback,
+ self.unregister_component_callback,
+ self.job_heartbeat_callback,
+ self.job_complete_callback,
+ "127.0.0.1",
+ 30,
+ )
- # Ensure that that the realm and building callbacks have been registered
- callback_keys = [key for key in self.manager._orchestrator.callbacks]
- self.assertIn(REALM_PREFIX, callback_keys)
- self.assertIn(JOB_PREFIX, callback_keys)
+ self.manager.initialize(
+ {"EXECUTOR": "test", "ORCHESTRATOR": {"MEM_CONFIG": None}}
+ )
- self.mock_job = self._create_build_job()
- self.mock_job_key = slash_join('building', BUILD_UUID)
+ # Ensure that the realm and building callbacks have been registered
+ callback_keys = [key for key in self.manager._orchestrator.callbacks]
+ self.assertIn(REALM_PREFIX, callback_keys)
+ self.assertIn(JOB_PREFIX, callback_keys)
- def tearDown(self):
- super(TestEphemeralLifecycle, self).tearDown()
- self.manager.shutdown()
+ self.mock_job = self._create_build_job()
+ self.mock_job_key = slash_join("building", BUILD_UUID)
+ def tearDown(self):
+ super(TestEphemeralLifecycle, self).tearDown()
+ self.manager.shutdown()
- @coroutine
- def _setup_job_for_managers(self):
- test_component = Mock(spec=BuildComponent)
- test_component.builder_realm = REALM_ID
- test_component.start_build = Mock(side_effect=self._create_completed_future())
- self.register_component_callback.return_value = test_component
+ @coroutine
+ def _setup_job_for_managers(self):
+ test_component = Mock(spec=BuildComponent)
+ test_component.builder_realm = REALM_ID
+ test_component.start_build = Mock(side_effect=self._create_completed_future())
+ self.register_component_callback.return_value = test_component
- is_scheduled = yield From(self.manager.schedule(self.mock_job))
- self.assertTrue(is_scheduled)
- self.assertEqual(self.test_executor.start_builder.call_count, 1)
+ is_scheduled = yield From(self.manager.schedule(self.mock_job))
+ self.assertTrue(is_scheduled)
+ self.assertEqual(self.test_executor.start_builder.call_count, 1)
- # Ensure that that the job, realm, and metric callbacks have been registered
- callback_keys = [key for key in self.manager._orchestrator.callbacks]
- self.assertIn(self.mock_job_key, self.manager._orchestrator.state)
- self.assertIn(REALM_PREFIX, callback_keys)
- # TODO: assert metric key has been set
+ # Ensure that the job, realm, and metric callbacks have been registered
+ callback_keys = [key for key in self.manager._orchestrator.callbacks]
+ self.assertIn(self.mock_job_key, self.manager._orchestrator.state)
+ self.assertIn(REALM_PREFIX, callback_keys)
+ # TODO: assert metric key has been set
- realm_for_build = self._find_realm_key(self.manager._orchestrator, BUILD_UUID)
+ realm_for_build = self._find_realm_key(self.manager._orchestrator, BUILD_UUID)
- raw_realm_data = yield From(self.manager._orchestrator.get_key(slash_join('realm',
- realm_for_build)))
- realm_data = json.loads(raw_realm_data)
- realm_data['realm'] = REALM_ID
+ raw_realm_data = yield From(
+ self.manager._orchestrator.get_key(slash_join("realm", realm_for_build))
+ )
+ realm_data = json.loads(raw_realm_data)
+ realm_data["realm"] = REALM_ID
- # Right now the job is not registered with any managers because etcd has not accepted the job
- self.assertEqual(self.register_component_callback.call_count, 0)
+ # Right now the job is not registered with any managers because etcd has not accepted the job
+ self.assertEqual(self.register_component_callback.call_count, 0)
- # Fire off a realm changed with the same data.
- yield From(self.manager._realm_callback(
- KeyChange(KeyEvent.CREATE,
- slash_join(REALM_PREFIX, REALM_ID),
- json.dumps(realm_data))))
+ # Fire off a realm changed with the same data.
+ yield From(
+ self.manager._realm_callback(
+ KeyChange(
+ KeyEvent.CREATE,
+ slash_join(REALM_PREFIX, REALM_ID),
+ json.dumps(realm_data),
+ )
+ )
+ )
- # Ensure that we have at least one component node.
- self.assertEqual(self.register_component_callback.call_count, 1)
- self.assertEqual(1, self.manager.num_workers())
+ # Ensure that we have at least one component node.
+ self.assertEqual(self.register_component_callback.call_count, 1)
+ self.assertEqual(1, self.manager.num_workers())
- # Ensure that the build info exists.
- self.assertIsNotNone(self.manager._build_uuid_to_info.get(BUILD_UUID))
+ # Ensure that the build info exists.
+ self.assertIsNotNone(self.manager._build_uuid_to_info.get(BUILD_UUID))
- raise Return(test_component)
+ raise Return(test_component)
- @staticmethod
- def _find_realm_key(orchestrator, build_uuid):
- for key, value in iteritems(orchestrator.state):
- if key.startswith(REALM_PREFIX):
- parsed_value = json.loads(value)
- body = json.loads(parsed_value['job_queue_item']['body'])
- if body['build_uuid'] == build_uuid:
- return parsed_value['realm']
- continue
- raise KeyError
+ @staticmethod
+ def _find_realm_key(orchestrator, build_uuid):
+ for key, value in iteritems(orchestrator.state):
+ if key.startswith(REALM_PREFIX):
+ parsed_value = json.loads(value)
+ body = json.loads(parsed_value["job_queue_item"]["body"])
+ if body["build_uuid"] == build_uuid:
+ return parsed_value["realm"]
+ continue
+ raise KeyError
+ @async_test
+ def test_schedule_and_complete(self):
+ # Test that a job is properly registered with all of the managers
+ test_component = yield From(self._setup_job_for_managers())
- @async_test
- def test_schedule_and_complete(self):
- # Test that a job is properly registered with all of the managers
- test_component = yield From(self._setup_job_for_managers())
+ # Take the job ourselves
+ yield From(self.manager.build_component_ready(test_component))
- # Take the job ourselves
- yield From(self.manager.build_component_ready(test_component))
+ self.assertIsNotNone(self.manager._build_uuid_to_info.get(BUILD_UUID))
- self.assertIsNotNone(self.manager._build_uuid_to_info.get(BUILD_UUID))
+ # Finish the job
+ yield From(
+ self.manager.job_completed(
+ self.mock_job, BuildJobResult.COMPLETE, test_component
+ )
+ )
- # Finish the job
- yield From(self.manager.job_completed(self.mock_job, BuildJobResult.COMPLETE, test_component))
+ # Ensure that the executor kills the job.
+ self.assertEqual(self.test_executor.stop_builder.call_count, 1)
- # Ensure that the executor kills the job.
- self.assertEqual(self.test_executor.stop_builder.call_count, 1)
+ # Ensure the build information is cleaned up.
+ self.assertIsNone(self.manager._build_uuid_to_info.get(BUILD_UUID))
+ self.assertEqual(0, self.manager.num_workers())
- # Ensure the build information is cleaned up.
- self.assertIsNone(self.manager._build_uuid_to_info.get(BUILD_UUID))
- self.assertEqual(0, self.manager.num_workers())
+ @async_test
+ def test_another_manager_takes_job(self):
+ # Prepare a job to be taken by another manager
+ test_component = yield From(self._setup_job_for_managers())
- @async_test
- def test_another_manager_takes_job(self):
- # Prepare a job to be taken by another manager
- test_component = yield From(self._setup_job_for_managers())
+ yield From(
+ self.manager._realm_callback(
+ KeyChange(
+ KeyEvent.DELETE,
+ slash_join(REALM_PREFIX, REALM_ID),
+ json.dumps(
+ {
+ "realm": REALM_ID,
+ "token": "beef",
+ "execution_id": "123",
+ "job_queue_item": self.mock_job.job_item,
+ }
+ ),
+ )
+ )
+ )
- yield From(self.manager._realm_callback(
- KeyChange(KeyEvent.DELETE,
- slash_join(REALM_PREFIX, REALM_ID),
- json.dumps({'realm': REALM_ID,
- 'token': 'beef',
- 'execution_id': '123',
- 'job_queue_item': self.mock_job.job_item}))))
+ self.unregister_component_callback.assert_called_once_with(test_component)
- self.unregister_component_callback.assert_called_once_with(test_component)
+ # Ensure that the executor does not kill the job.
+ self.assertEqual(self.test_executor.stop_builder.call_count, 0)
- # Ensure that the executor does not kill the job.
- self.assertEqual(self.test_executor.stop_builder.call_count, 0)
+ # Ensure that we still have the build info, but not the component.
+ self.assertEqual(0, self.manager.num_workers())
+ self.assertIsNotNone(self.manager._build_uuid_to_info.get(BUILD_UUID))
- # Ensure that we still have the build info, but not the component.
- self.assertEqual(0, self.manager.num_workers())
- self.assertIsNotNone(self.manager._build_uuid_to_info.get(BUILD_UUID))
+ # Delete the job once it has "completed".
+ yield From(
+ self.manager._job_callback(
+ KeyChange(
+ KeyEvent.DELETE,
+ self.mock_job_key,
+ json.dumps(
+ {
+ "had_heartbeat": False,
+ "job_queue_item": self.mock_job.job_item,
+ }
+ ),
+ )
+ )
+ )
- # Delete the job once it has "completed".
- yield From(self.manager._job_callback(
- KeyChange(KeyEvent.DELETE,
- self.mock_job_key,
- json.dumps({'had_heartbeat': False,
- 'job_queue_item': self.mock_job.job_item}))))
+ # Ensure the job was removed from the info, but stop was not called.
+ self.assertIsNone(self.manager._build_uuid_to_info.get(BUILD_UUID))
+ self.assertEqual(self.test_executor.stop_builder.call_count, 0)
- # Ensure the job was removed from the info, but stop was not called.
- self.assertIsNone(self.manager._build_uuid_to_info.get(BUILD_UUID))
- self.assertEqual(self.test_executor.stop_builder.call_count, 0)
+ @async_test
+ def test_job_started_by_other_manager(self):
+ # Ensure that the building callbacks have been registered
+ callback_keys = [key for key in self.manager._orchestrator.callbacks]
+ self.assertIn(JOB_PREFIX, callback_keys)
- @async_test
- def test_job_started_by_other_manager(self):
- # Ensure that that the building callbacks have been registered
- callback_keys = [key for key in self.manager._orchestrator.callbacks]
- self.assertIn(JOB_PREFIX, callback_keys)
+ # Send a signal to the callback that the job has been created.
+ yield From(
+ self.manager._job_callback(
+ KeyChange(
+ KeyEvent.CREATE,
+ self.mock_job_key,
+ json.dumps(
+ {
+ "had_heartbeat": False,
+ "job_queue_item": self.mock_job.job_item,
+ }
+ ),
+ )
+ )
+ )
- # Send a signal to the callback that the job has been created.
- yield From(self.manager._job_callback(
- KeyChange(KeyEvent.CREATE,
- self.mock_job_key,
- json.dumps({'had_heartbeat': False,
- 'job_queue_item': self.mock_job.job_item}))))
+ # Ensure the create does nothing.
+ self.assertEqual(self.test_executor.stop_builder.call_count, 0)
- # Ensure the create does nothing.
- self.assertEqual(self.test_executor.stop_builder.call_count, 0)
+ @async_test
+ def test_expiring_worker_not_started(self):
+ # Ensure that the building callbacks have been registered
+ callback_keys = [key for key in self.manager._orchestrator.callbacks]
+ self.assertIn(JOB_PREFIX, callback_keys)
- @async_test
- def test_expiring_worker_not_started(self):
- # Ensure that that the building callbacks have been registered
- callback_keys = [key for key in self.manager._orchestrator.callbacks]
- self.assertIn(JOB_PREFIX, callback_keys)
+ # Send a signal to the callback that a worker has expired
+ yield From(
+ self.manager._job_callback(
+ KeyChange(
+ KeyEvent.EXPIRE,
+ self.mock_job_key,
+ json.dumps(
+ {
+ "had_heartbeat": True,
+ "job_queue_item": self.mock_job.job_item,
+ }
+ ),
+ )
+ )
+ )
- # Send a signal to the callback that a worker has expired
- yield From(self.manager._job_callback(
- KeyChange(KeyEvent.EXPIRE,
- self.mock_job_key,
- json.dumps({'had_heartbeat': True,
- 'job_queue_item': self.mock_job.job_item}))))
+ # Since the realm was never registered, expiration should do nothing.
+ self.assertEqual(self.test_executor.stop_builder.call_count, 0)
- # Since the realm was never registered, expiration should do nothing.
- self.assertEqual(self.test_executor.stop_builder.call_count, 0)
+ @async_test
+ def test_expiring_worker_started(self):
+ test_component = yield From(self._setup_job_for_managers())
- @async_test
- def test_expiring_worker_started(self):
- test_component = yield From(self._setup_job_for_managers())
+ # Ensure that the building callbacks have been registered
+ callback_keys = [key for key in self.manager._orchestrator.callbacks]
+ self.assertIn(JOB_PREFIX, callback_keys)
- # Ensure that that the building callbacks have been registered
- callback_keys = [key for key in self.manager._orchestrator.callbacks]
- self.assertIn(JOB_PREFIX, callback_keys)
+ yield From(
+ self.manager._job_callback(
+ KeyChange(
+ KeyEvent.EXPIRE,
+ self.mock_job_key,
+ json.dumps(
+ {
+ "had_heartbeat": True,
+ "job_queue_item": self.mock_job.job_item,
+ }
+ ),
+ )
+ )
+ )
- yield From(self.manager._job_callback(
- KeyChange(KeyEvent.EXPIRE,
- self.mock_job_key,
- json.dumps({'had_heartbeat': True,
- 'job_queue_item': self.mock_job.job_item}))))
+ self.test_executor.stop_builder.assert_called_once_with("123")
+ self.assertEqual(self.test_executor.stop_builder.call_count, 1)
- self.test_executor.stop_builder.assert_called_once_with('123')
- self.assertEqual(self.test_executor.stop_builder.call_count, 1)
+ @async_test
+ def test_buildjob_deleted(self):
+ test_component = yield From(self._setup_job_for_managers())
- @async_test
- def test_buildjob_deleted(self):
- test_component = yield From(self._setup_job_for_managers())
+ # Ensure that the building callbacks have been registered
+ callback_keys = [key for key in self.manager._orchestrator.callbacks]
+ self.assertIn(JOB_PREFIX, callback_keys)
- # Ensure that that the building callbacks have been registered
- callback_keys = [key for key in self.manager._orchestrator.callbacks]
- self.assertIn(JOB_PREFIX, callback_keys)
+ # Send a signal to the callback that a worker has expired
+ yield From(
+ self.manager._job_callback(
+ KeyChange(
+ KeyEvent.DELETE,
+ self.mock_job_key,
+ json.dumps(
+ {
+ "had_heartbeat": False,
+ "job_queue_item": self.mock_job.job_item,
+ }
+ ),
+ )
+ )
+ )
- # Send a signal to the callback that a worker has expired
- yield From(self.manager._job_callback(
- KeyChange(KeyEvent.DELETE,
- self.mock_job_key,
- json.dumps({'had_heartbeat': False,
- 'job_queue_item': self.mock_job.job_item}))))
+ self.assertEqual(self.test_executor.stop_builder.call_count, 0)
+ self.assertEqual(self.job_complete_callback.call_count, 0)
+ self.assertIsNone(self.manager._build_uuid_to_info.get(BUILD_UUID))
- self.assertEqual(self.test_executor.stop_builder.call_count, 0)
- self.assertEqual(self.job_complete_callback.call_count, 0)
- self.assertIsNone(self.manager._build_uuid_to_info.get(BUILD_UUID))
+ @async_test
+ def test_builder_never_starts(self):
+ test_component = yield From(self._setup_job_for_managers())
- @async_test
- def test_builder_never_starts(self):
- test_component = yield From(self._setup_job_for_managers())
+ # Ensure that the building callbacks have been registered
+ callback_keys = [key for key in self.manager._orchestrator.callbacks]
+ self.assertIn(JOB_PREFIX, callback_keys)
- # Ensure that that the building callbacks have been registered
- callback_keys = [key for key in self.manager._orchestrator.callbacks]
- self.assertIn(JOB_PREFIX, callback_keys)
+ # Send a signal to the callback that a worker has expired
+ yield From(
+ self.manager._job_callback(
+ KeyChange(
+ KeyEvent.EXPIRE,
+ self.mock_job_key,
+ json.dumps(
+ {
+ "had_heartbeat": False,
+ "job_queue_item": self.mock_job.job_item,
+ }
+ ),
+ )
+ )
+ )
- # Send a signal to the callback that a worker has expired
- yield From(self.manager._job_callback(
- KeyChange(KeyEvent.EXPIRE,
- self.mock_job_key,
- json.dumps({'had_heartbeat': False,
- 'job_queue_item': self.mock_job.job_item}))))
+ self.test_executor.stop_builder.assert_called_once_with("123")
+ self.assertEqual(self.test_executor.stop_builder.call_count, 1)
- self.test_executor.stop_builder.assert_called_once_with('123')
- self.assertEqual(self.test_executor.stop_builder.call_count, 1)
+ # Ensure the job was marked as incomplete, with update_phase set to True (so the DB record and
+ # logs are updated as well)
+ yield From(
+ self.job_complete_callback.assert_called_once_with(
+ ANY, BuildJobResult.INCOMPLETE, "MockExecutor", update_phase=True
+ )
+ )
- # Ensure the job was marked as incomplete, with an update_phase to True (so the DB record and
- # logs are updated as well)
- yield From(self.job_complete_callback.assert_called_once_with(ANY, BuildJobResult.INCOMPLETE,
- 'MockExecutor',
- update_phase=True))
+ @async_test
+ def test_change_worker(self):
+ # Send a signal to the callback that a worker key has been changed
+ self.manager._job_callback(KeyChange(KeyEvent.SET, self.mock_job_key, "value"))
+ self.assertEqual(self.test_executor.stop_builder.call_count, 0)
- @async_test
- def test_change_worker(self):
- # Send a signal to the callback that a worker key has been changed
- self.manager._job_callback(KeyChange(KeyEvent.SET, self.mock_job_key, 'value'))
- self.assertEqual(self.test_executor.stop_builder.call_count, 0)
+ @async_test
+ def test_realm_expired(self):
+ test_component = yield From(self._setup_job_for_managers())
- @async_test
- def test_realm_expired(self):
- test_component = yield From(self._setup_job_for_managers())
+ # Send a signal to the callback that a realm has expired
+ yield From(
+ self.manager._realm_callback(
+ KeyChange(
+ KeyEvent.EXPIRE,
+ self.mock_job_key,
+ json.dumps(
+ {
+ "realm": REALM_ID,
+ "execution_id": "foobar",
+ "executor_name": "MockExecutor",
+ "job_queue_item": {"body": '{"build_uuid": "fakeid"}'},
+ }
+ ),
+ )
+ )
+ )
- # Send a signal to the callback that a realm has expired
- yield From(self.manager._realm_callback(KeyChange(
- KeyEvent.EXPIRE,
- self.mock_job_key,
- json.dumps({
- 'realm': REALM_ID,
- 'execution_id': 'foobar',
- 'executor_name': 'MockExecutor',
- 'job_queue_item': {'body': '{"build_uuid": "fakeid"}'},
- }))))
-
- # Ensure that the cleanup code for the executor was called.
- self.test_executor.stop_builder.assert_called_once_with('foobar')
- self.assertEqual(self.test_executor.stop_builder.call_count, 1)
+ # Ensure that the cleanup code for the executor was called.
+ self.test_executor.stop_builder.assert_called_once_with("foobar")
+ self.assertEqual(self.test_executor.stop_builder.call_count, 1)
class TestEphemeral(EphemeralBuilderTestCase):
- """ Simple unit tests for the ephemeral builder around config management, starting and stopping
+ """ Simple unit tests for the ephemeral builder around config management, starting and stopping
jobs.
"""
- def setUp(self):
- super(TestEphemeral, self).setUp()
+ def setUp(self):
+ super(TestEphemeral, self).setUp()
- unregister_component_callback = Mock()
- job_heartbeat_callback = Mock()
+ unregister_component_callback = Mock()
+ job_heartbeat_callback = Mock()
- @coroutine
- def job_complete_callback(*args, **kwargs):
- raise Return()
+ @coroutine
+ def job_complete_callback(*args, **kwargs):
+ raise Return()
- self.manager = EphemeralBuilderManager(
- self._register_component,
- unregister_component_callback,
- job_heartbeat_callback,
- job_complete_callback,
- '127.0.0.1',
- 30,
- )
+ self.manager = EphemeralBuilderManager(
+ self._register_component,
+ unregister_component_callback,
+ job_heartbeat_callback,
+ job_complete_callback,
+ "127.0.0.1",
+ 30,
+ )
- def tearDown(self):
- super(TestEphemeral, self).tearDown()
- self.manager.shutdown()
+ def tearDown(self):
+ super(TestEphemeral, self).tearDown()
+ self.manager.shutdown()
- def test_verify_executor_oldconfig(self):
- EphemeralBuilderManager.EXECUTORS['test'] = TestExecutor
- self.manager.initialize({
- 'EXECUTOR': 'test',
- 'EXECUTOR_CONFIG': dict(MINIMUM_RETRY_THRESHOLD=42),
- 'ORCHESTRATOR': {'MEM_CONFIG': None},
- })
+ def test_verify_executor_oldconfig(self):
+ EphemeralBuilderManager.EXECUTORS["test"] = TestExecutor
+ self.manager.initialize(
+ {
+ "EXECUTOR": "test",
+ "EXECUTOR_CONFIG": dict(MINIMUM_RETRY_THRESHOLD=42),
+ "ORCHESTRATOR": {"MEM_CONFIG": None},
+ }
+ )
- # Ensure that we have a single test executor.
- self.assertEqual(1, len(self.manager.registered_executors))
- self.assertEqual(42, self.manager.registered_executors[0].minimum_retry_threshold)
- self.assertEqual('TestExecutor', self.manager.registered_executors[0].name)
+ # Ensure that we have a single test executor.
+ self.assertEqual(1, len(self.manager.registered_executors))
+ self.assertEqual(
+ 42, self.manager.registered_executors[0].minimum_retry_threshold
+ )
+ self.assertEqual("TestExecutor", self.manager.registered_executors[0].name)
- def test_verify_executor_newconfig(self):
- EphemeralBuilderManager.EXECUTORS['test'] = TestExecutor
- self.manager.initialize({
- 'EXECUTORS': [{
- 'EXECUTOR': 'test',
- 'MINIMUM_RETRY_THRESHOLD': 42
- }],
- 'ORCHESTRATOR': {'MEM_CONFIG': None},
- })
+ def test_verify_executor_newconfig(self):
+ EphemeralBuilderManager.EXECUTORS["test"] = TestExecutor
+ self.manager.initialize(
+ {
+ "EXECUTORS": [{"EXECUTOR": "test", "MINIMUM_RETRY_THRESHOLD": 42}],
+ "ORCHESTRATOR": {"MEM_CONFIG": None},
+ }
+ )
- # Ensure that we have a single test executor.
- self.assertEqual(1, len(self.manager.registered_executors))
- self.assertEqual(42, self.manager.registered_executors[0].minimum_retry_threshold)
+ # Ensure that we have a single test executor.
+ self.assertEqual(1, len(self.manager.registered_executors))
+ self.assertEqual(
+ 42, self.manager.registered_executors[0].minimum_retry_threshold
+ )
+
+ def test_multiple_executors_samename(self):
+ EphemeralBuilderManager.EXECUTORS["test"] = TestExecutor
+ EphemeralBuilderManager.EXECUTORS["anotherexecutor"] = TestExecutor
+
+ with self.assertRaises(Exception):
+ self.manager.initialize(
+ {
+ "EXECUTORS": [
+ {
+ "NAME": "primary",
+ "EXECUTOR": "test",
+ "MINIMUM_RETRY_THRESHOLD": 42,
+ },
+ {
+ "NAME": "primary",
+ "EXECUTOR": "anotherexecutor",
+ "MINIMUM_RETRY_THRESHOLD": 24,
+ },
+ ],
+ "ORCHESTRATOR": {"MEM_CONFIG": None},
+ }
+ )
+
+ def test_verify_multiple_executors(self):
+ EphemeralBuilderManager.EXECUTORS["test"] = TestExecutor
+ EphemeralBuilderManager.EXECUTORS["anotherexecutor"] = TestExecutor
+
+ self.manager.initialize(
+ {
+ "EXECUTORS": [
+ {
+ "NAME": "primary",
+ "EXECUTOR": "test",
+ "MINIMUM_RETRY_THRESHOLD": 42,
+ },
+ {
+ "NAME": "secondary",
+ "EXECUTOR": "anotherexecutor",
+ "MINIMUM_RETRY_THRESHOLD": 24,
+ },
+ ],
+ "ORCHESTRATOR": {"MEM_CONFIG": None},
+ }
+ )
+
+ # Ensure that we have two test executors.
+ self.assertEqual(2, len(self.manager.registered_executors))
+ self.assertEqual(
+ 42, self.manager.registered_executors[0].minimum_retry_threshold
+ )
+ self.assertEqual(
+ 24, self.manager.registered_executors[1].minimum_retry_threshold
+ )
+
+ def test_skip_invalid_executor(self):
+ self.manager.initialize(
+ {
+ "EXECUTORS": [{"EXECUTOR": "unknown", "MINIMUM_RETRY_THRESHOLD": 42}],
+ "ORCHESTRATOR": {"MEM_CONFIG": None},
+ }
+ )
+
+ self.assertEqual(0, len(self.manager.registered_executors))
+
+ @async_test
+ def test_schedule_job_namespace_filter(self):
+ EphemeralBuilderManager.EXECUTORS["test"] = TestExecutor
+ self.manager.initialize(
+ {
+ "EXECUTORS": [
+ {"EXECUTOR": "test", "NAMESPACE_WHITELIST": ["something"]}
+ ],
+ "ORCHESTRATOR": {"MEM_CONFIG": None},
+ }
+ )
+
+ # Try with a build job in an invalid namespace.
+ build_job = self._create_build_job(namespace="somethingelse")
+ result = yield From(self.manager.schedule(build_job))
+ self.assertFalse(result[0])
+
+ # Try with a valid namespace.
+ build_job = self._create_build_job(namespace="something")
+ result = yield From(self.manager.schedule(build_job))
+ self.assertTrue(result[0])
+
+ @async_test
+ def test_schedule_job_retries_filter(self):
+ EphemeralBuilderManager.EXECUTORS["test"] = TestExecutor
+ self.manager.initialize(
+ {
+ "EXECUTORS": [{"EXECUTOR": "test", "MINIMUM_RETRY_THRESHOLD": 2}],
+ "ORCHESTRATOR": {"MEM_CONFIG": None},
+ }
+ )
+
+ # Try with a build job that has too few retries.
+ build_job = self._create_build_job(retries=1)
+ result = yield From(self.manager.schedule(build_job))
+ self.assertFalse(result[0])
+
+ # Try with a valid job.
+ build_job = self._create_build_job(retries=2)
+ result = yield From(self.manager.schedule(build_job))
+ self.assertTrue(result[0])
+
+ @async_test
+ def test_schedule_job_executor_fallback(self):
+ EphemeralBuilderManager.EXECUTORS["primary"] = TestExecutor
+ EphemeralBuilderManager.EXECUTORS["secondary"] = TestExecutor
+
+ self.manager.initialize(
+ {
+ "EXECUTORS": [
+ {
+ "NAME": "primary",
+ "EXECUTOR": "primary",
+ "NAMESPACE_WHITELIST": ["something"],
+ "MINIMUM_RETRY_THRESHOLD": 3,
+ },
+ {
+ "NAME": "secondary",
+ "EXECUTOR": "secondary",
+ "MINIMUM_RETRY_THRESHOLD": 2,
+ },
+ ],
+ "ALLOWED_WORKER_COUNT": 5,
+ "ORCHESTRATOR": {"MEM_CONFIG": None},
+ }
+ )
+
+ # Try a job not matching the primary's namespace filter. Should schedule on secondary.
+ build_job = self._create_build_job(namespace="somethingelse")
+ result = yield From(self.manager.schedule(build_job))
+ self.assertTrue(result[0])
+
+ self.assertIsNone(self.manager.registered_executors[0].job_started)
+ self.assertIsNotNone(self.manager.registered_executors[1].job_started)
+
+ self.manager.registered_executors[0].job_started = None
+ self.manager.registered_executors[1].job_started = None
+
+ # Try a job not matching the primary's retry minimum. Should schedule on secondary.
+ build_job = self._create_build_job(namespace="something", retries=2)
+ result = yield From(self.manager.schedule(build_job))
+ self.assertTrue(result[0])
+
+ self.assertIsNone(self.manager.registered_executors[0].job_started)
+ self.assertIsNotNone(self.manager.registered_executors[1].job_started)
+
+ self.manager.registered_executors[0].job_started = None
+ self.manager.registered_executors[1].job_started = None
+
+ # Try a job matching the primary. Should schedule on the primary.
+ build_job = self._create_build_job(namespace="something", retries=3)
+ result = yield From(self.manager.schedule(build_job))
+ self.assertTrue(result[0])
+
+ self.assertIsNotNone(self.manager.registered_executors[0].job_started)
+ self.assertIsNone(self.manager.registered_executors[1].job_started)
+
+ self.manager.registered_executors[0].job_started = None
+ self.manager.registered_executors[1].job_started = None
+
+ # Try a job not matching either's restrictions.
+ build_job = self._create_build_job(namespace="somethingelse", retries=1)
+ result = yield From(self.manager.schedule(build_job))
+ self.assertFalse(result[0])
+
+ self.assertIsNone(self.manager.registered_executors[0].job_started)
+ self.assertIsNone(self.manager.registered_executors[1].job_started)
+
+ self.manager.registered_executors[0].job_started = None
+ self.manager.registered_executors[1].job_started = None
+
+ @async_test
+ def test_schedule_job_single_executor(self):
+ EphemeralBuilderManager.EXECUTORS["test"] = TestExecutor
+
+ self.manager.initialize(
+ {
+ "EXECUTOR": "test",
+ "EXECUTOR_CONFIG": {},
+ "ALLOWED_WORKER_COUNT": 5,
+ "ORCHESTRATOR": {"MEM_CONFIG": None},
+ }
+ )
+
+ build_job = self._create_build_job(namespace="something", retries=3)
+ result = yield From(self.manager.schedule(build_job))
+ self.assertTrue(result[0])
+
+ self.assertIsNotNone(self.manager.registered_executors[0].job_started)
+ self.manager.registered_executors[0].job_started = None
+
+ build_job = self._create_build_job(namespace="something", retries=0)
+ result = yield From(self.manager.schedule(build_job))
+ self.assertTrue(result[0])
+
+ self.assertIsNotNone(self.manager.registered_executors[0].job_started)
+ self.manager.registered_executors[0].job_started = None
+
+ @async_test
+ def test_executor_exception(self):
+ EphemeralBuilderManager.EXECUTORS["bad"] = BadExecutor
+
+ self.manager.initialize(
+ {
+ "EXECUTOR": "bad",
+ "EXECUTOR_CONFIG": {},
+ "ORCHESTRATOR": {"MEM_CONFIG": None},
+ }
+ )
+
+ build_job = self._create_build_job(namespace="something", retries=3)
+ result = yield From(self.manager.schedule(build_job))
+ self.assertFalse(result[0])
+
+ @async_test
+ def test_schedule_and_stop(self):
+ EphemeralBuilderManager.EXECUTORS["test"] = TestExecutor
+
+ self.manager.initialize(
+ {
+ "EXECUTOR": "test",
+ "EXECUTOR_CONFIG": {},
+ "ORCHESTRATOR": {"MEM_CONFIG": None},
+ }
+ )
+
+ # Start the build job.
+ build_job = self._create_build_job(namespace="something", retries=3)
+ result = yield From(self.manager.schedule(build_job))
+ self.assertTrue(result[0])
+
+ executor = self.manager.registered_executors[0]
+ self.assertIsNotNone(executor.job_started)
+
+ # Register the realm so the build information is added.
+ yield From(
+ self.manager._register_realm(
+ {
+ "realm": str(uuid.uuid4()),
+ "token": str(uuid.uuid4()),
+ "execution_id": executor.job_started,
+ "executor_name": "TestExecutor",
+ "build_uuid": build_job.build_uuid,
+ "job_queue_item": build_job.job_item,
+ }
+ )
+ )
+
+ # Stop the build job.
+ yield From(self.manager.kill_builder_executor(build_job.build_uuid))
+ self.assertEqual(executor.job_stopped, executor.job_started)
- def test_multiple_executors_samename(self):
- EphemeralBuilderManager.EXECUTORS['test'] = TestExecutor
- EphemeralBuilderManager.EXECUTORS['anotherexecutor'] = TestExecutor
-
- with self.assertRaises(Exception):
- self.manager.initialize({
- 'EXECUTORS': [
- {
- 'NAME': 'primary',
- 'EXECUTOR': 'test',
- 'MINIMUM_RETRY_THRESHOLD': 42
- },
- {
- 'NAME': 'primary',
- 'EXECUTOR': 'anotherexecutor',
- 'MINIMUM_RETRY_THRESHOLD': 24
- },
- ],
- 'ORCHESTRATOR': {'MEM_CONFIG': None},
- })
-
-
- def test_verify_multiple_executors(self):
- EphemeralBuilderManager.EXECUTORS['test'] = TestExecutor
- EphemeralBuilderManager.EXECUTORS['anotherexecutor'] = TestExecutor
-
- self.manager.initialize({
- 'EXECUTORS': [
- {
- 'NAME': 'primary',
- 'EXECUTOR': 'test',
- 'MINIMUM_RETRY_THRESHOLD': 42
- },
- {
- 'NAME': 'secondary',
- 'EXECUTOR': 'anotherexecutor',
- 'MINIMUM_RETRY_THRESHOLD': 24
- },
- ],
- 'ORCHESTRATOR': {'MEM_CONFIG': None},
- })
-
- # Ensure that we have a two test executors.
- self.assertEqual(2, len(self.manager.registered_executors))
- self.assertEqual(42, self.manager.registered_executors[0].minimum_retry_threshold)
- self.assertEqual(24, self.manager.registered_executors[1].minimum_retry_threshold)
-
- def test_skip_invalid_executor(self):
- self.manager.initialize({
- 'EXECUTORS': [
- {
- 'EXECUTOR': 'unknown',
- 'MINIMUM_RETRY_THRESHOLD': 42
- },
- ],
- 'ORCHESTRATOR': {'MEM_CONFIG': None},
- })
-
- self.assertEqual(0, len(self.manager.registered_executors))
-
- @async_test
- def test_schedule_job_namespace_filter(self):
- EphemeralBuilderManager.EXECUTORS['test'] = TestExecutor
- self.manager.initialize({
- 'EXECUTORS': [{
- 'EXECUTOR': 'test',
- 'NAMESPACE_WHITELIST': ['something'],
- }],
- 'ORCHESTRATOR': {'MEM_CONFIG': None},
- })
-
- # Try with a build job in an invalid namespace.
- build_job = self._create_build_job(namespace='somethingelse')
- result = yield From(self.manager.schedule(build_job))
- self.assertFalse(result[0])
-
- # Try with a valid namespace.
- build_job = self._create_build_job(namespace='something')
- result = yield From(self.manager.schedule(build_job))
- self.assertTrue(result[0])
-
- @async_test
- def test_schedule_job_retries_filter(self):
- EphemeralBuilderManager.EXECUTORS['test'] = TestExecutor
- self.manager.initialize({
- 'EXECUTORS': [{
- 'EXECUTOR': 'test',
- 'MINIMUM_RETRY_THRESHOLD': 2,
- }],
- 'ORCHESTRATOR': {'MEM_CONFIG': None},
- })
-
- # Try with a build job that has too few retries.
- build_job = self._create_build_job(retries=1)
- result = yield From(self.manager.schedule(build_job))
- self.assertFalse(result[0])
-
- # Try with a valid job.
- build_job = self._create_build_job(retries=2)
- result = yield From(self.manager.schedule(build_job))
- self.assertTrue(result[0])
-
- @async_test
- def test_schedule_job_executor_fallback(self):
- EphemeralBuilderManager.EXECUTORS['primary'] = TestExecutor
- EphemeralBuilderManager.EXECUTORS['secondary'] = TestExecutor
-
- self.manager.initialize({
- 'EXECUTORS': [
- {
- 'NAME': 'primary',
- 'EXECUTOR': 'primary',
- 'NAMESPACE_WHITELIST': ['something'],
- 'MINIMUM_RETRY_THRESHOLD': 3,
- },
- {
- 'NAME': 'secondary',
- 'EXECUTOR': 'secondary',
- 'MINIMUM_RETRY_THRESHOLD': 2,
- },
- ],
- 'ALLOWED_WORKER_COUNT': 5,
- 'ORCHESTRATOR': {'MEM_CONFIG': None},
- })
-
- # Try a job not matching the primary's namespace filter. Should schedule on secondary.
- build_job = self._create_build_job(namespace='somethingelse')
- result = yield From(self.manager.schedule(build_job))
- self.assertTrue(result[0])
-
- self.assertIsNone(self.manager.registered_executors[0].job_started)
- self.assertIsNotNone(self.manager.registered_executors[1].job_started)
-
- self.manager.registered_executors[0].job_started = None
- self.manager.registered_executors[1].job_started = None
-
- # Try a job not matching the primary's retry minimum. Should schedule on secondary.
- build_job = self._create_build_job(namespace='something', retries=2)
- result = yield From(self.manager.schedule(build_job))
- self.assertTrue(result[0])
-
- self.assertIsNone(self.manager.registered_executors[0].job_started)
- self.assertIsNotNone(self.manager.registered_executors[1].job_started)
-
- self.manager.registered_executors[0].job_started = None
- self.manager.registered_executors[1].job_started = None
-
- # Try a job matching the primary. Should schedule on the primary.
- build_job = self._create_build_job(namespace='something', retries=3)
- result = yield From(self.manager.schedule(build_job))
- self.assertTrue(result[0])
-
- self.assertIsNotNone(self.manager.registered_executors[0].job_started)
- self.assertIsNone(self.manager.registered_executors[1].job_started)
-
- self.manager.registered_executors[0].job_started = None
- self.manager.registered_executors[1].job_started = None
-
- # Try a job not matching either's restrictions.
- build_job = self._create_build_job(namespace='somethingelse', retries=1)
- result = yield From(self.manager.schedule(build_job))
- self.assertFalse(result[0])
-
- self.assertIsNone(self.manager.registered_executors[0].job_started)
- self.assertIsNone(self.manager.registered_executors[1].job_started)
-
- self.manager.registered_executors[0].job_started = None
- self.manager.registered_executors[1].job_started = None
-
-
- @async_test
- def test_schedule_job_single_executor(self):
- EphemeralBuilderManager.EXECUTORS['test'] = TestExecutor
-
- self.manager.initialize({
- 'EXECUTOR': 'test',
- 'EXECUTOR_CONFIG': {},
- 'ALLOWED_WORKER_COUNT': 5,
- 'ORCHESTRATOR': {'MEM_CONFIG': None},
- })
-
- build_job = self._create_build_job(namespace='something', retries=3)
- result = yield From(self.manager.schedule(build_job))
- self.assertTrue(result[0])
-
- self.assertIsNotNone(self.manager.registered_executors[0].job_started)
- self.manager.registered_executors[0].job_started = None
-
-
- build_job = self._create_build_job(namespace='something', retries=0)
- result = yield From(self.manager.schedule(build_job))
- self.assertTrue(result[0])
-
- self.assertIsNotNone(self.manager.registered_executors[0].job_started)
- self.manager.registered_executors[0].job_started = None
-
- @async_test
- def test_executor_exception(self):
- EphemeralBuilderManager.EXECUTORS['bad'] = BadExecutor
-
- self.manager.initialize({
- 'EXECUTOR': 'bad',
- 'EXECUTOR_CONFIG': {},
- 'ORCHESTRATOR': {'MEM_CONFIG': None},
- })
-
- build_job = self._create_build_job(namespace='something', retries=3)
- result = yield From(self.manager.schedule(build_job))
- self.assertFalse(result[0])
-
- @async_test
- def test_schedule_and_stop(self):
- EphemeralBuilderManager.EXECUTORS['test'] = TestExecutor
-
- self.manager.initialize({
- 'EXECUTOR': 'test',
- 'EXECUTOR_CONFIG': {},
- 'ORCHESTRATOR': {'MEM_CONFIG': None},
- })
-
- # Start the build job.
- build_job = self._create_build_job(namespace='something', retries=3)
- result = yield From(self.manager.schedule(build_job))
- self.assertTrue(result[0])
-
- executor = self.manager.registered_executors[0]
- self.assertIsNotNone(executor.job_started)
-
- # Register the realm so the build information is added.
- yield From(self.manager._register_realm({
- 'realm': str(uuid.uuid4()),
- 'token': str(uuid.uuid4()),
- 'execution_id': executor.job_started,
- 'executor_name': 'TestExecutor',
- 'build_uuid': build_job.build_uuid,
- 'job_queue_item': build_job.job_item,
- }))
-
- # Stop the build job.
- yield From(self.manager.kill_builder_executor(build_job.build_uuid))
- self.assertEqual(executor.job_stopped, executor.job_started)
-
-
-if __name__ == '__main__':
- unittest.main()
+if __name__ == "__main__":
+ unittest.main()
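Every coroutine test above relies on the async_test helper defined at the top of this module, so a minimal sketch of the pattern may help when adding new cases. It assumes trollius is installed (as it is for this test suite) and that async_test is in scope; the body is a placeholder for real asynchronous work.

    # Sketch only: exercising the async_test helper from this module.
    from trollius import From, sleep

    @async_test
    def sample_case():
        # stand-in for real async work (scheduling a build, firing a callback, ...)
        yield From(sleep(0))

    sample_case()  # wraps the generator with trollius.coroutine and runs it on the event loop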
diff --git a/buildtrigger/__init__.py b/buildtrigger/__init__.py
index 8a794cf96..9c21d9025 100644
--- a/buildtrigger/__init__.py
+++ b/buildtrigger/__init__.py
@@ -2,4 +2,3 @@ import buildtrigger.bitbuckethandler
import buildtrigger.customhandler
import buildtrigger.githubhandler
import buildtrigger.gitlabhandler
-
diff --git a/buildtrigger/basehandler.py b/buildtrigger/basehandler.py
index 8d9b0f753..08bdb68ea 100644
--- a/buildtrigger/basehandler.py
+++ b/buildtrigger/basehandler.py
@@ -9,359 +9,360 @@ from data import model
from buildtrigger.triggerutil import get_trigger_config, InvalidServiceException
NAMESPACES_SCHEMA = {
- 'type': 'array',
- 'items': {
- 'type': 'object',
- 'properties': {
- 'personal': {
- 'type': 'boolean',
- 'description': 'True if the namespace is the user\'s personal namespace',
- },
- 'score': {
- 'type': 'number',
- 'description': 'Score of the relevance of the namespace',
- },
- 'avatar_url': {
- 'type': ['string', 'null'],
- 'description': 'URL of the avatar for this namespace',
- },
- 'url': {
- 'type': 'string',
- 'description': 'URL of the website to view the namespace',
- },
- 'id': {
- 'type': 'string',
- 'description': 'Trigger-internal ID of the namespace',
- },
- 'title': {
- 'type': 'string',
- 'description': 'Human-readable title of the namespace',
- },
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "personal": {
+ "type": "boolean",
+ "description": "True if the namespace is the user's personal namespace",
+ },
+ "score": {
+ "type": "number",
+ "description": "Score of the relevance of the namespace",
+ },
+ "avatar_url": {
+ "type": ["string", "null"],
+ "description": "URL of the avatar for this namespace",
+ },
+ "url": {
+ "type": "string",
+ "description": "URL of the website to view the namespace",
+ },
+ "id": {
+ "type": "string",
+ "description": "Trigger-internal ID of the namespace",
+ },
+ "title": {
+ "type": "string",
+ "description": "Human-readable title of the namespace",
+ },
+ },
+ "required": ["personal", "score", "avatar_url", "id", "title"],
},
- 'required': ['personal', 'score', 'avatar_url', 'id', 'title'],
- },
}
BUILD_SOURCES_SCHEMA = {
- 'type': 'array',
- 'items': {
- 'type': 'object',
- 'properties': {
- 'name': {
- 'type': 'string',
- 'description': 'The name of the repository, without its namespace',
- },
- 'full_name': {
- 'type': 'string',
- 'description': 'The name of the repository, with its namespace',
- },
- 'description': {
- 'type': 'string',
- 'description': 'The description of the repository. May be an empty string',
- },
- 'last_updated': {
- 'type': 'number',
- 'description': 'The date/time when the repository was last updated, since epoch in UTC',
- },
- 'url': {
- 'type': 'string',
- 'description': 'The URL at which to view the repository in the browser',
- },
- 'has_admin_permissions': {
- 'type': 'boolean',
- 'description': 'True if the current user has admin permissions on the repository',
- },
- 'private': {
- 'type': 'boolean',
- 'description': 'True if the repository is private',
- },
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "name": {
+ "type": "string",
+ "description": "The name of the repository, without its namespace",
+ },
+ "full_name": {
+ "type": "string",
+ "description": "The name of the repository, with its namespace",
+ },
+ "description": {
+ "type": "string",
+ "description": "The description of the repository. May be an empty string",
+ },
+ "last_updated": {
+ "type": "number",
+ "description": "The date/time when the repository was last updated, since epoch in UTC",
+ },
+ "url": {
+ "type": "string",
+ "description": "The URL at which to view the repository in the browser",
+ },
+ "has_admin_permissions": {
+ "type": "boolean",
+ "description": "True if the current user has admin permissions on the repository",
+ },
+ "private": {
+ "type": "boolean",
+ "description": "True if the repository is private",
+ },
+ },
+ "required": [
+ "name",
+ "full_name",
+ "description",
+ "last_updated",
+ "has_admin_permissions",
+ "private",
+ ],
},
- 'required': ['name', 'full_name', 'description', 'last_updated',
- 'has_admin_permissions', 'private'],
- },
}
METADATA_SCHEMA = {
- 'type': 'object',
- 'properties': {
- 'commit': {
- 'type': 'string',
- 'description': 'first 7 characters of the SHA-1 identifier for a git commit',
- 'pattern': '^([A-Fa-f0-9]{7,})$',
- },
- 'git_url': {
- 'type': 'string',
- 'description': 'The GIT url to use for the checkout',
- },
- 'ref': {
- 'type': 'string',
- 'description': 'git reference for a git commit',
- 'pattern': r'^refs\/(heads|tags|remotes)\/(.+)$',
- },
- 'default_branch': {
- 'type': 'string',
- 'description': 'default branch of the git repository',
- },
- 'commit_info': {
- 'type': 'object',
- 'description': 'metadata about a git commit',
- 'properties': {
- 'url': {
- 'type': 'string',
- 'description': 'URL to view a git commit',
+ "type": "object",
+ "properties": {
+ "commit": {
+ "type": "string",
+ "description": "first 7 characters of the SHA-1 identifier for a git commit",
+ "pattern": "^([A-Fa-f0-9]{7,})$",
},
- 'message': {
- 'type': 'string',
- 'description': 'git commit message',
+ "git_url": {
+ "type": "string",
+ "description": "The GIT url to use for the checkout",
},
- 'date': {
- 'type': 'string',
- 'description': 'timestamp for a git commit'
+ "ref": {
+ "type": "string",
+ "description": "git reference for a git commit",
+ "pattern": r"^refs\/(heads|tags|remotes)\/(.+)$",
},
- 'author': {
- 'type': 'object',
- 'description': 'metadata about the author of a git commit',
- 'properties': {
- 'username': {
- 'type': 'string',
- 'description': 'username of the author',
- },
- 'url': {
- 'type': 'string',
- 'description': 'URL to view the profile of the author',
- },
- 'avatar_url': {
- 'type': 'string',
- 'description': 'URL to view the avatar of the author',
- },
- },
- 'required': ['username'],
+ "default_branch": {
+ "type": "string",
+ "description": "default branch of the git repository",
},
- 'committer': {
- 'type': 'object',
- 'description': 'metadata about the committer of a git commit',
- 'properties': {
- 'username': {
- 'type': 'string',
- 'description': 'username of the committer',
+ "commit_info": {
+ "type": "object",
+ "description": "metadata about a git commit",
+ "properties": {
+ "url": {"type": "string", "description": "URL to view a git commit"},
+ "message": {"type": "string", "description": "git commit message"},
+ "date": {"type": "string", "description": "timestamp for a git commit"},
+ "author": {
+ "type": "object",
+ "description": "metadata about the author of a git commit",
+ "properties": {
+ "username": {
+ "type": "string",
+ "description": "username of the author",
+ },
+ "url": {
+ "type": "string",
+ "description": "URL to view the profile of the author",
+ },
+ "avatar_url": {
+ "type": "string",
+ "description": "URL to view the avatar of the author",
+ },
+ },
+ "required": ["username"],
+ },
+ "committer": {
+ "type": "object",
+ "description": "metadata about the committer of a git commit",
+ "properties": {
+ "username": {
+ "type": "string",
+ "description": "username of the committer",
+ },
+ "url": {
+ "type": "string",
+ "description": "URL to view the profile of the committer",
+ },
+ "avatar_url": {
+ "type": "string",
+ "description": "URL to view the avatar of the committer",
+ },
+ },
+ "required": ["username"],
+ },
},
- 'url': {
- 'type': 'string',
- 'description': 'URL to view the profile of the committer',
- },
- 'avatar_url': {
- 'type': 'string',
- 'description': 'URL to view the avatar of the committer',
- },
- },
- 'required': ['username'],
+ "required": ["message"],
},
- },
- 'required': ['message'],
},
- },
- 'required': ['commit', 'git_url'],
+ "required": ["commit", "git_url"],
}
@add_metaclass(ABCMeta)
class BuildTriggerHandler(object):
- def __init__(self, trigger, override_config=None):
- self.trigger = trigger
- self.config = override_config or get_trigger_config(trigger)
+ def __init__(self, trigger, override_config=None):
+ self.trigger = trigger
+ self.config = override_config or get_trigger_config(trigger)
- @property
- def auth_token(self):
- """ Returns the auth token for the trigger. """
- # NOTE: This check is for testing.
- if isinstance(self.trigger.auth_token, str):
- return self.trigger.auth_token
+ @property
+ def auth_token(self):
+ """ Returns the auth token for the trigger. """
+ # NOTE: This check is for testing.
+ if isinstance(self.trigger.auth_token, str):
+ return self.trigger.auth_token
- # TODO(remove-unenc): Remove legacy field.
- if self.trigger.secure_auth_token is not None:
- return self.trigger.secure_auth_token.decrypt()
+ # TODO(remove-unenc): Remove legacy field.
+ if self.trigger.secure_auth_token is not None:
+ return self.trigger.secure_auth_token.decrypt()
- if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
- return self.trigger.auth_token
+ if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
+ return self.trigger.auth_token
- return None
+ return None
- @abstractmethod
- def load_dockerfile_contents(self):
- """
+ @abstractmethod
+ def load_dockerfile_contents(self):
+ """
Loads the Dockerfile found for the trigger's config and returns them or None if none could
be found/loaded.
"""
- pass
+ pass
- @abstractmethod
- def list_build_source_namespaces(self):
- """
+ @abstractmethod
+ def list_build_source_namespaces(self):
+ """
Take the auth information for the specific trigger type and load the
list of namespaces that can contain build sources.
"""
- pass
+ pass
- @abstractmethod
- def list_build_sources_for_namespace(self, namespace):
- """
+ @abstractmethod
+ def list_build_sources_for_namespace(self, namespace):
+ """
Take the auth information for the specific trigger type and load the
list of repositories under the given namespace.
"""
- pass
+ pass
- @abstractmethod
- def list_build_subdirs(self):
- """
+ @abstractmethod
+ def list_build_subdirs(self):
+ """
Take the auth information and the specified config so far and list all of
the possible subdirs containing dockerfiles.
"""
- pass
+ pass
- @abstractmethod
- def handle_trigger_request(self, request):
- """
+ @abstractmethod
+ def handle_trigger_request(self, request):
+ """
Transform the incoming request data into a set of actions. Returns a PreparedBuild.
"""
- pass
+ pass
- @abstractmethod
- def is_active(self):
- """
+ @abstractmethod
+ def is_active(self):
+ """
Returns True if the current build trigger is active. Inactive means further
setup is needed.
"""
- pass
+ pass
- @abstractmethod
- def activate(self, standard_webhook_url):
- """
+ @abstractmethod
+ def activate(self, standard_webhook_url):
+ """
Activates the trigger for the service, with the given new configuration.
Returns new public and private config that should be stored if successful.
"""
- pass
+ pass
- @abstractmethod
- def deactivate(self):
- """
+ @abstractmethod
+ def deactivate(self):
+ """
Deactivates the trigger for the service, removing any hooks installed in
the remote service. Returns the new config that should be stored if this
trigger is going to be re-activated.
"""
- pass
+ pass
- @abstractmethod
- def manual_start(self, run_parameters=None):
- """
+ @abstractmethod
+ def manual_start(self, run_parameters=None):
+ """
Manually creates a repository build for this trigger. Returns a PreparedBuild.
"""
- pass
+ pass
- @abstractmethod
- def list_field_values(self, field_name, limit=None):
- """
+ @abstractmethod
+ def list_field_values(self, field_name, limit=None):
+ """
Lists all values for the given custom trigger field. For example, a trigger might have a
field named "branches", and this method would return all branches.
"""
- pass
+ pass
- @abstractmethod
- def get_repository_url(self):
- """ Returns the URL of the current trigger's repository. Note that this operation
+ @abstractmethod
+ def get_repository_url(self):
+ """ Returns the URL of the current trigger's repository. Note that this operation
can be called in a loop, so it should be as fast as possible. """
- pass
+ pass
- @classmethod
- def filename_is_dockerfile(cls, file_name):
- """ Returns whether the file is named Dockerfile or follows the convention .Dockerfile"""
- return file_name.endswith(".Dockerfile") or u"Dockerfile" == file_name
+ @classmethod
+ def filename_is_dockerfile(cls, file_name):
+ """ Returns whether the file is named Dockerfile or follows the convention .Dockerfile"""
+ return file_name.endswith(".Dockerfile") or u"Dockerfile" == file_name
- @classmethod
- def service_name(cls):
- """
+ @classmethod
+ def service_name(cls):
+ """
Particular service implemented by subclasses.
"""
- raise NotImplementedError
+ raise NotImplementedError
- @classmethod
- def get_handler(cls, trigger, override_config=None):
- for subc in cls.__subclasses__():
- if subc.service_name() == trigger.service.name:
- return subc(trigger, override_config)
+ @classmethod
+ def get_handler(cls, trigger, override_config=None):
+ for subc in cls.__subclasses__():
+ if subc.service_name() == trigger.service.name:
+ return subc(trigger, override_config)
- raise InvalidServiceException('Unable to find service: %s' % trigger.service.name)
+ raise InvalidServiceException(
+ "Unable to find service: %s" % trigger.service.name
+ )
- def put_config_key(self, key, value):
- """ Updates a config key in the trigger, saving it to the DB. """
- self.config[key] = value
- model.build.update_build_trigger(self.trigger, self.config)
+ def put_config_key(self, key, value):
+ """ Updates a config key in the trigger, saving it to the DB. """
+ self.config[key] = value
+ model.build.update_build_trigger(self.trigger, self.config)
- def set_auth_token(self, auth_token):
- """ Sets the auth token for the trigger, saving it to the DB. """
- model.build.update_build_trigger(self.trigger, self.config, auth_token=auth_token)
+ def set_auth_token(self, auth_token):
+ """ Sets the auth token for the trigger, saving it to the DB. """
+ model.build.update_build_trigger(
+ self.trigger, self.config, auth_token=auth_token
+ )
- def get_dockerfile_path(self):
- """ Returns the normalized path to the Dockerfile found in the subdirectory
+ def get_dockerfile_path(self):
+ """ Returns the normalized path to the Dockerfile found in the subdirectory
in the config. """
- dockerfile_path = self.config.get('dockerfile_path') or 'Dockerfile'
- if dockerfile_path[0] == '/':
- dockerfile_path = dockerfile_path[1:]
- return dockerfile_path
+ dockerfile_path = self.config.get("dockerfile_path") or "Dockerfile"
+ if dockerfile_path[0] == "/":
+ dockerfile_path = dockerfile_path[1:]
+ return dockerfile_path
- def prepare_build(self, metadata, is_manual=False):
- # Ensure that the metadata meets the scheme.
- validate(metadata, METADATA_SCHEMA)
+ def prepare_build(self, metadata, is_manual=False):
+ # Ensure that the metadata meets the schema.
+ validate(metadata, METADATA_SCHEMA)
- config = self.config
- ref = metadata.get('ref', None)
- commit_sha = metadata['commit']
- default_branch = metadata.get('default_branch', None)
- prepared = PreparedBuild(self.trigger)
- prepared.name_from_sha(commit_sha)
- prepared.subdirectory = config.get('dockerfile_path', None)
- prepared.context = config.get('context', None)
- prepared.is_manual = is_manual
- prepared.metadata = metadata
+ config = self.config
+ ref = metadata.get("ref", None)
+ commit_sha = metadata["commit"]
+ default_branch = metadata.get("default_branch", None)
+ prepared = PreparedBuild(self.trigger)
+ prepared.name_from_sha(commit_sha)
+ prepared.subdirectory = config.get("dockerfile_path", None)
+ prepared.context = config.get("context", None)
+ prepared.is_manual = is_manual
+ prepared.metadata = metadata
- if ref is not None:
- prepared.tags_from_ref(ref, default_branch)
- else:
- prepared.tags = [commit_sha[:7]]
+ if ref is not None:
+ prepared.tags_from_ref(ref, default_branch)
+ else:
+ prepared.tags = [commit_sha[:7]]
- return prepared
+ return prepared
- @classmethod
- def build_sources_response(cls, sources):
- validate(sources, BUILD_SOURCES_SCHEMA)
- return sources
+ @classmethod
+ def build_sources_response(cls, sources):
+ validate(sources, BUILD_SOURCES_SCHEMA)
+ return sources
- @classmethod
- def build_namespaces_response(cls, namespaces_dict):
- namespaces = list(namespaces_dict.values())
- validate(namespaces, NAMESPACES_SCHEMA)
- return namespaces
+ @classmethod
+ def build_namespaces_response(cls, namespaces_dict):
+ namespaces = list(namespaces_dict.values())
+ validate(namespaces, NAMESPACES_SCHEMA)
+ return namespaces
- @classmethod
- def get_parent_directory_mappings(cls, dockerfile_path, current_paths=None):
- """ Returns a map of dockerfile_paths to it's possible contexts. """
- if dockerfile_path == "":
- return {}
+ @classmethod
+ def get_parent_directory_mappings(cls, dockerfile_path, current_paths=None):
+ """ Returns a map of dockerfile_paths to it's possible contexts. """
+ if dockerfile_path == "":
+ return {}
- if dockerfile_path[0] != os.path.sep:
- dockerfile_path = os.path.sep + dockerfile_path
+ if dockerfile_path[0] != os.path.sep:
+ dockerfile_path = os.path.sep + dockerfile_path
- dockerfile_path = os.path.normpath(dockerfile_path)
- all_paths = set()
- path, _ = os.path.split(dockerfile_path)
- if path == "":
- path = os.path.sep
+ dockerfile_path = os.path.normpath(dockerfile_path)
+ all_paths = set()
+ path, _ = os.path.split(dockerfile_path)
+ if path == "":
+ path = os.path.sep
- all_paths.add(path)
- for i in range(1, len(path.split(os.path.sep))):
- path, _ = os.path.split(path)
- all_paths.add(path)
+ all_paths.add(path)
+ for i in range(1, len(path.split(os.path.sep))):
+ path, _ = os.path.split(path)
+ all_paths.add(path)
- if current_paths:
- return dict({dockerfile_path: list(all_paths)}, **current_paths)
+ if current_paths:
+ return dict({dockerfile_path: list(all_paths)}, **current_paths)
- return {dockerfile_path: list(all_paths)}
+ return {dockerfile_path: list(all_paths)}
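Editor's note: the hunk above is mechanical reformatting, but the dispatch in BuildTriggerHandler.get_handler is easy to miss in the noise: each concrete handler advertises a service_name(), and the base class walks __subclasses__() until one matches the trigger's service. The following is a minimal standalone sketch of that pattern, not part of the diff; the class names here are hypothetical.

class BaseHandler(object):
    @classmethod
    def service_name(cls):
        raise NotImplementedError

    @classmethod
    def get_handler(cls, service):
        # Walk direct subclasses and instantiate the first one whose
        # service_name() matches the requested service.
        for subc in cls.__subclasses__():
            if subc.service_name() == service:
                return subc()
        raise LookupError("Unable to find service: %s" % service)


class GitExampleHandler(BaseHandler):
    @classmethod
    def service_name(cls):
        return "custom-git"


assert isinstance(BaseHandler.get_handler("custom-git"), GitExampleHandler)
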
diff --git a/buildtrigger/bitbuckethandler.py b/buildtrigger/bitbuckethandler.py
index 9573f5c60..c74a06c0d 100644
--- a/buildtrigger/bitbuckethandler.py
+++ b/buildtrigger/bitbuckethandler.py
@@ -9,537 +9,551 @@ from jsonschema import validate
from app import app, get_app_url
from buildtrigger.basehandler import BuildTriggerHandler
-from buildtrigger.triggerutil import (RepositoryReadException, TriggerActivationException,
- TriggerDeactivationException, TriggerStartException,
- InvalidPayloadException, TriggerProviderException,
- SkipRequestException,
- determine_build_ref, raise_if_skipped_build,
- find_matching_branches)
+from buildtrigger.triggerutil import (
+ RepositoryReadException,
+ TriggerActivationException,
+ TriggerDeactivationException,
+ TriggerStartException,
+ InvalidPayloadException,
+ TriggerProviderException,
+ SkipRequestException,
+ determine_build_ref,
+ raise_if_skipped_build,
+ find_matching_branches,
+)
from util.dict_wrappers import JSONPathDict, SafeDictSetter
from util.security.ssh import generate_ssh_keypair
logger = logging.getLogger(__name__)
-_BITBUCKET_COMMIT_URL = 'https://bitbucket.org/%s/commits/%s'
-_RAW_AUTHOR_REGEX = re.compile(r'.*<(.+)>')
+_BITBUCKET_COMMIT_URL = "https://bitbucket.org/%s/commits/%s"
+_RAW_AUTHOR_REGEX = re.compile(r".*<(.+)>")
BITBUCKET_WEBHOOK_PAYLOAD_SCHEMA = {
- 'type': 'object',
- 'properties': {
- 'repository': {
- 'type': 'object',
- 'properties': {
- 'full_name': {
- 'type': 'string',
- },
- },
- 'required': ['full_name'],
- }, # /Repository
- 'push': {
- 'type': 'object',
- 'properties': {
- 'changes': {
- 'type': 'array',
- 'items': {
- 'type': 'object',
- 'properties': {
- 'new': {
- 'type': 'object',
- 'properties': {
- 'target': {
- 'type': 'object',
- 'properties': {
- 'hash': {
- 'type': 'string'
- },
- 'message': {
- 'type': 'string'
- },
- 'date': {
- 'type': 'string'
- },
- 'author': {
- 'type': 'object',
- 'properties': {
- 'user': {
- 'type': 'object',
- 'properties': {
- 'display_name': {
- 'type': 'string',
- },
- 'account_id': {
- 'type': 'string',
- },
- 'links': {
- 'type': 'object',
- 'properties': {
- 'avatar': {
- 'type': 'object',
- 'properties': {
- 'href': {
- 'type': 'string',
- },
- },
- 'required': ['href'],
- },
+ "type": "object",
+ "properties": {
+ "repository": {
+ "type": "object",
+ "properties": {"full_name": {"type": "string"}},
+ "required": ["full_name"],
+ }, # /Repository
+ "push": {
+ "type": "object",
+ "properties": {
+ "changes": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "new": {
+ "type": "object",
+ "properties": {
+ "target": {
+ "type": "object",
+ "properties": {
+ "hash": {"type": "string"},
+ "message": {"type": "string"},
+ "date": {"type": "string"},
+ "author": {
+ "type": "object",
+ "properties": {
+ "user": {
+ "type": "object",
+ "properties": {
+ "display_name": {
+ "type": "string"
+ },
+ "account_id": {
+ "type": "string"
+ },
+ "links": {
+ "type": "object",
+ "properties": {
+ "avatar": {
+ "type": "object",
+ "properties": {
+ "href": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "href"
+ ],
+ }
+ },
+ "required": ["avatar"],
+ }, # /User
+ },
+ } # /Author
+ },
+ },
+ },
+ "required": ["hash", "message", "date"],
+ } # /Target
},
- 'required': ['avatar'],
- }, # /User
- },
- }, # /Author
+ "required": ["name", "target"],
+ } # /New
},
- },
- },
- 'required': ['hash', 'message', 'date'],
- }, # /Target
- },
- 'required': ['name', 'target'],
- }, # /New
+ }, # /Changes item
+ } # /Changes
},
- }, # /Changes item
- }, # /Changes
- },
- 'required': ['changes'],
- }, # / Push
- },
- 'actor': {
- 'type': 'object',
- 'properties': {
- 'account_id': {
- 'type': 'string',
- },
- 'display_name': {
- 'type': 'string',
- },
- 'links': {
- 'type': 'object',
- 'properties': {
- 'avatar': {
- 'type': 'object',
- 'properties': {
- 'href': {
- 'type': 'string',
- },
- },
- 'required': ['href'],
- },
- },
- 'required': ['avatar'],
- },
+ "required": ["changes"],
+ }, # / Push
},
- }, # /Actor
- 'required': ['push', 'repository'],
-} # /Root
+ "actor": {
+ "type": "object",
+ "properties": {
+ "account_id": {"type": "string"},
+ "display_name": {"type": "string"},
+ "links": {
+ "type": "object",
+ "properties": {
+ "avatar": {
+ "type": "object",
+ "properties": {"href": {"type": "string"}},
+ "required": ["href"],
+ }
+ },
+ "required": ["avatar"],
+ },
+ },
+ }, # /Actor
+ "required": ["push", "repository"],
+} # /Root
BITBUCKET_COMMIT_INFO_SCHEMA = {
- 'type': 'object',
- 'properties': {
- 'node': {
- 'type': 'string',
+ "type": "object",
+ "properties": {
+ "node": {"type": "string"},
+ "message": {"type": "string"},
+ "timestamp": {"type": "string"},
+ "raw_author": {"type": "string"},
},
- 'message': {
- 'type': 'string',
- },
- 'timestamp': {
- 'type': 'string',
- },
- 'raw_author': {
- 'type': 'string',
- },
- },
- 'required': ['node', 'message', 'timestamp']
+ "required": ["node", "message", "timestamp"],
}
-def get_transformed_commit_info(bb_commit, ref, default_branch, repository_name, lookup_author):
- """ Returns the BitBucket commit information transformed into our own
+
+def get_transformed_commit_info(
+ bb_commit, ref, default_branch, repository_name, lookup_author
+):
+ """ Returns the BitBucket commit information transformed into our own
payload format.
"""
- try:
- validate(bb_commit, BITBUCKET_COMMIT_INFO_SCHEMA)
- except Exception as exc:
- logger.exception('Exception when validating Bitbucket commit information: %s from %s', exc.message, bb_commit)
- raise InvalidPayloadException(exc.message)
+ try:
+ validate(bb_commit, BITBUCKET_COMMIT_INFO_SCHEMA)
+ except Exception as exc:
+ logger.exception(
+ "Exception when validating Bitbucket commit information: %s from %s",
+ exc.message,
+ bb_commit,
+ )
+ raise InvalidPayloadException(exc.message)
- commit = JSONPathDict(bb_commit)
+ commit = JSONPathDict(bb_commit)
- config = SafeDictSetter()
- config['commit'] = commit['node']
- config['ref'] = ref
- config['default_branch'] = default_branch
- config['git_url'] = 'git@bitbucket.org:%s.git' % repository_name
+ config = SafeDictSetter()
+ config["commit"] = commit["node"]
+ config["ref"] = ref
+ config["default_branch"] = default_branch
+ config["git_url"] = "git@bitbucket.org:%s.git" % repository_name
- config['commit_info.url'] = _BITBUCKET_COMMIT_URL % (repository_name, commit['node'])
- config['commit_info.message'] = commit['message']
- config['commit_info.date'] = commit['timestamp']
+ config["commit_info.url"] = _BITBUCKET_COMMIT_URL % (
+ repository_name,
+ commit["node"],
+ )
+ config["commit_info.message"] = commit["message"]
+ config["commit_info.date"] = commit["timestamp"]
- match = _RAW_AUTHOR_REGEX.match(commit['raw_author'])
- if match:
- author = lookup_author(match.group(1))
- author_info = JSONPathDict(author) if author is not None else None
- if author_info:
- config['commit_info.author.username'] = author_info['user.display_name']
- config['commit_info.author.avatar_url'] = author_info['user.avatar']
+ match = _RAW_AUTHOR_REGEX.match(commit["raw_author"])
+ if match:
+ author = lookup_author(match.group(1))
+ author_info = JSONPathDict(author) if author is not None else None
+ if author_info:
+ config["commit_info.author.username"] = author_info["user.display_name"]
+ config["commit_info.author.avatar_url"] = author_info["user.avatar"]
- return config.dict_value()
+ return config.dict_value()
def get_transformed_webhook_payload(bb_payload, default_branch=None):
- """ Returns the BitBucket webhook JSON payload transformed into our own payload
+ """ Returns the BitBucket webhook JSON payload transformed into our own payload
format. If the bb_payload is not valid, returns None.
"""
- try:
- validate(bb_payload, BITBUCKET_WEBHOOK_PAYLOAD_SCHEMA)
- except Exception as exc:
- logger.exception('Exception when validating Bitbucket webhook payload: %s from %s', exc.message,
- bb_payload)
- raise InvalidPayloadException(exc.message)
+ try:
+ validate(bb_payload, BITBUCKET_WEBHOOK_PAYLOAD_SCHEMA)
+ except Exception as exc:
+ logger.exception(
+ "Exception when validating Bitbucket webhook payload: %s from %s",
+ exc.message,
+ bb_payload,
+ )
+ raise InvalidPayloadException(exc.message)
- payload = JSONPathDict(bb_payload)
- change = payload['push.changes[-1].new']
- if not change:
- raise SkipRequestException
+ payload = JSONPathDict(bb_payload)
+ change = payload["push.changes[-1].new"]
+ if not change:
+ raise SkipRequestException
- is_branch = change['type'] == 'branch'
- ref = 'refs/heads/' + change['name'] if is_branch else 'refs/tags/' + change['name']
+ is_branch = change["type"] == "branch"
+ ref = "refs/heads/" + change["name"] if is_branch else "refs/tags/" + change["name"]
- repository_name = payload['repository.full_name']
- target = change['target']
+ repository_name = payload["repository.full_name"]
+ target = change["target"]
- config = SafeDictSetter()
- config['commit'] = target['hash']
- config['ref'] = ref
- config['default_branch'] = default_branch
- config['git_url'] = 'git@bitbucket.org:%s.git' % repository_name
+ config = SafeDictSetter()
+ config["commit"] = target["hash"]
+ config["ref"] = ref
+ config["default_branch"] = default_branch
+ config["git_url"] = "git@bitbucket.org:%s.git" % repository_name
- config['commit_info.url'] = target['links.html.href'] or ''
- config['commit_info.message'] = target['message']
- config['commit_info.date'] = target['date']
+ config["commit_info.url"] = target["links.html.href"] or ""
+ config["commit_info.message"] = target["message"]
+ config["commit_info.date"] = target["date"]
- config['commit_info.author.username'] = target['author.user.display_name']
- config['commit_info.author.avatar_url'] = target['author.user.links.avatar.href']
+ config["commit_info.author.username"] = target["author.user.display_name"]
+ config["commit_info.author.avatar_url"] = target["author.user.links.avatar.href"]
- config['commit_info.committer.username'] = payload['actor.display_name']
- config['commit_info.committer.avatar_url'] = payload['actor.links.avatar.href']
- return config.dict_value()
+ config["commit_info.committer.username"] = payload["actor.display_name"]
+ config["commit_info.committer.avatar_url"] = payload["actor.links.avatar.href"]
+ return config.dict_value()
class BitbucketBuildTrigger(BuildTriggerHandler):
- """
+ """
BuildTrigger for Bitbucket.
"""
- @classmethod
- def service_name(cls):
- return 'bitbucket'
- def _get_client(self):
- """ Returns a BitBucket API client for this trigger's config. """
- key = app.config.get('BITBUCKET_TRIGGER_CONFIG', {}).get('CONSUMER_KEY', '')
- secret = app.config.get('BITBUCKET_TRIGGER_CONFIG', {}).get('CONSUMER_SECRET', '')
+ @classmethod
+ def service_name(cls):
+ return "bitbucket"
- trigger_uuid = self.trigger.uuid
- callback_url = '%s/oauth1/bitbucket/callback/trigger/%s' % (get_app_url(), trigger_uuid)
+ def _get_client(self):
+ """ Returns a BitBucket API client for this trigger's config. """
+ key = app.config.get("BITBUCKET_TRIGGER_CONFIG", {}).get("CONSUMER_KEY", "")
+ secret = app.config.get("BITBUCKET_TRIGGER_CONFIG", {}).get(
+ "CONSUMER_SECRET", ""
+ )
- return BitBucket(key, secret, callback_url, timeout=15)
+ trigger_uuid = self.trigger.uuid
+ callback_url = "%s/oauth1/bitbucket/callback/trigger/%s" % (
+ get_app_url(),
+ trigger_uuid,
+ )
- def _get_authorized_client(self):
- """ Returns an authorized API client. """
- base_client = self._get_client()
- auth_token = self.auth_token or 'invalid:invalid'
- token_parts = auth_token.split(':')
- if len(token_parts) != 2:
- token_parts = ['invalid', 'invalid']
+ return BitBucket(key, secret, callback_url, timeout=15)
- (access_token, access_token_secret) = token_parts
- return base_client.get_authorized_client(access_token, access_token_secret)
+ def _get_authorized_client(self):
+ """ Returns an authorized API client. """
+ base_client = self._get_client()
+ auth_token = self.auth_token or "invalid:invalid"
+ token_parts = auth_token.split(":")
+ if len(token_parts) != 2:
+ token_parts = ["invalid", "invalid"]
- def _get_repository_client(self):
- """ Returns an API client for working with this config's BB repository. """
- source = self.config['build_source']
- (namespace, name) = source.split('/')
- bitbucket_client = self._get_authorized_client()
- return bitbucket_client.for_namespace(namespace).repositories().get(name)
+ (access_token, access_token_secret) = token_parts
+ return base_client.get_authorized_client(access_token, access_token_secret)
- def _get_default_branch(self, repository, default_value='master'):
- """ Returns the default branch for the repository or the value given. """
- (result, data, _) = repository.get_main_branch()
- if result:
- return data['name']
+ def _get_repository_client(self):
+ """ Returns an API client for working with this config's BB repository. """
+ source = self.config["build_source"]
+ (namespace, name) = source.split("/")
+ bitbucket_client = self._get_authorized_client()
+ return bitbucket_client.for_namespace(namespace).repositories().get(name)
- return default_value
+ def _get_default_branch(self, repository, default_value="master"):
+ """ Returns the default branch for the repository or the value given. """
+ (result, data, _) = repository.get_main_branch()
+ if result:
+ return data["name"]
- def get_oauth_url(self):
- """ Returns the OAuth URL to authorize Bitbucket. """
- bitbucket_client = self._get_client()
- (result, data, err_msg) = bitbucket_client.get_authorization_url()
- if not result:
- raise TriggerProviderException(err_msg)
+ return default_value
- return data
+ def get_oauth_url(self):
+ """ Returns the OAuth URL to authorize Bitbucket. """
+ bitbucket_client = self._get_client()
+ (result, data, err_msg) = bitbucket_client.get_authorization_url()
+ if not result:
+ raise TriggerProviderException(err_msg)
- def exchange_verifier(self, verifier):
- """ Exchanges the given verifier token to setup this trigger. """
- bitbucket_client = self._get_client()
- access_token = self.config.get('access_token', '')
- access_token_secret = self.auth_token
+ return data
- # Exchange the verifier for a new access token.
- (result, data, _) = bitbucket_client.verify_token(access_token, access_token_secret, verifier)
- if not result:
- return False
+ def exchange_verifier(self, verifier):
+ """ Exchanges the given verifier token to setup this trigger. """
+ bitbucket_client = self._get_client()
+ access_token = self.config.get("access_token", "")
+ access_token_secret = self.auth_token
- # Save the updated access token and secret.
- self.set_auth_token(data[0] + ':' + data[1])
+ # Exchange the verifier for a new access token.
+ (result, data, _) = bitbucket_client.verify_token(
+ access_token, access_token_secret, verifier
+ )
+ if not result:
+ return False
- # Retrieve the current authorized user's information and store the username in the config.
- authorized_client = self._get_authorized_client()
- (result, data, _) = authorized_client.get_current_user()
- if not result:
- return False
+ # Save the updated access token and secret.
+ self.set_auth_token(data[0] + ":" + data[1])
- self.put_config_key('account_id', data['user']['account_id'])
- self.put_config_key('nickname', data['user']['nickname'])
- return True
+ # Retrieve the current authorized user's information and store the username in the config.
+ authorized_client = self._get_authorized_client()
+ (result, data, _) = authorized_client.get_current_user()
+ if not result:
+ return False
- def is_active(self):
- return 'webhook_id' in self.config
+ self.put_config_key("account_id", data["user"]["account_id"])
+ self.put_config_key("nickname", data["user"]["nickname"])
+ return True
- def activate(self, standard_webhook_url):
- config = self.config
+ def is_active(self):
+ return "webhook_id" in self.config
- # Add a deploy key to the repository.
- public_key, private_key = generate_ssh_keypair()
- config['credentials'] = [
- {
- 'name': 'SSH Public Key',
- 'value': public_key,
- },
- ]
+ def activate(self, standard_webhook_url):
+ config = self.config
- repository = self._get_repository_client()
- (result, created_deploykey, err_msg) = repository.deploykeys().create(
- app.config['REGISTRY_TITLE'] + ' webhook key', public_key)
+ # Add a deploy key to the repository.
+ public_key, private_key = generate_ssh_keypair()
+ config["credentials"] = [{"name": "SSH Public Key", "value": public_key}]
- if not result:
- msg = 'Unable to add deploy key to repository: %s' % err_msg
- raise TriggerActivationException(msg)
+ repository = self._get_repository_client()
+ (result, created_deploykey, err_msg) = repository.deploykeys().create(
+ app.config["REGISTRY_TITLE"] + " webhook key", public_key
+ )
- config['deploy_key_id'] = created_deploykey['pk']
+ if not result:
+ msg = "Unable to add deploy key to repository: %s" % err_msg
+ raise TriggerActivationException(msg)
- # Add a webhook callback.
- description = 'Webhook for invoking builds on %s' % app.config['REGISTRY_TITLE_SHORT']
- webhook_events = ['repo:push']
- (result, created_webhook, err_msg) = repository.webhooks().create(
- description, standard_webhook_url, webhook_events)
+ config["deploy_key_id"] = created_deploykey["pk"]
- if not result:
- msg = 'Unable to add webhook to repository: %s' % err_msg
- raise TriggerActivationException(msg)
+ # Add a webhook callback.
+ description = (
+ "Webhook for invoking builds on %s" % app.config["REGISTRY_TITLE_SHORT"]
+ )
+ webhook_events = ["repo:push"]
+ (result, created_webhook, err_msg) = repository.webhooks().create(
+ description, standard_webhook_url, webhook_events
+ )
- config['webhook_id'] = created_webhook['uuid']
- self.config = config
- return config, {'private_key': private_key}
+ if not result:
+ msg = "Unable to add webhook to repository: %s" % err_msg
+ raise TriggerActivationException(msg)
- def deactivate(self):
- config = self.config
+ config["webhook_id"] = created_webhook["uuid"]
+ self.config = config
+ return config, {"private_key": private_key}
- webhook_id = config.pop('webhook_id', None)
- deploy_key_id = config.pop('deploy_key_id', None)
- repository = self._get_repository_client()
+ def deactivate(self):
+ config = self.config
- # Remove the webhook.
- if webhook_id is not None:
- (result, _, err_msg) = repository.webhooks().delete(webhook_id)
- if not result:
- msg = 'Unable to remove webhook from repository: %s' % err_msg
- raise TriggerDeactivationException(msg)
+ webhook_id = config.pop("webhook_id", None)
+ deploy_key_id = config.pop("deploy_key_id", None)
+ repository = self._get_repository_client()
- # Remove the public key.
- if deploy_key_id is not None:
- (result, _, err_msg) = repository.deploykeys().delete(deploy_key_id)
- if not result:
- msg = 'Unable to remove deploy key from repository: %s' % err_msg
- raise TriggerDeactivationException(msg)
+ # Remove the webhook.
+ if webhook_id is not None:
+ (result, _, err_msg) = repository.webhooks().delete(webhook_id)
+ if not result:
+ msg = "Unable to remove webhook from repository: %s" % err_msg
+ raise TriggerDeactivationException(msg)
- return config
+ # Remove the public key.
+ if deploy_key_id is not None:
+ (result, _, err_msg) = repository.deploykeys().delete(deploy_key_id)
+ if not result:
+ msg = "Unable to remove deploy key from repository: %s" % err_msg
+ raise TriggerDeactivationException(msg)
- def list_build_source_namespaces(self):
- bitbucket_client = self._get_authorized_client()
- (result, data, err_msg) = bitbucket_client.get_visible_repositories()
- if not result:
- raise RepositoryReadException('Could not read repository list: ' + err_msg)
+ return config
- namespaces = {}
- for repo in data:
- owner = repo['owner']
+ def list_build_source_namespaces(self):
+ bitbucket_client = self._get_authorized_client()
+ (result, data, err_msg) = bitbucket_client.get_visible_repositories()
+ if not result:
+ raise RepositoryReadException("Could not read repository list: " + err_msg)
- if owner in namespaces:
- namespaces[owner]['score'] = namespaces[owner]['score'] + 1
- else:
- namespaces[owner] = {
- 'personal': owner == self.config.get('nickname', self.config.get('username')),
- 'id': owner,
- 'title': owner,
- 'avatar_url': repo['logo'],
- 'url': 'https://bitbucket.org/%s' % (owner),
- 'score': 1,
- }
+ namespaces = {}
+ for repo in data:
+ owner = repo["owner"]
- return BuildTriggerHandler.build_namespaces_response(namespaces)
+ if owner in namespaces:
+ namespaces[owner]["score"] = namespaces[owner]["score"] + 1
+ else:
+ namespaces[owner] = {
+ "personal": owner
+ == self.config.get("nickname", self.config.get("username")),
+ "id": owner,
+ "title": owner,
+ "avatar_url": repo["logo"],
+ "url": "https://bitbucket.org/%s" % (owner),
+ "score": 1,
+ }
- def list_build_sources_for_namespace(self, namespace):
- def repo_view(repo):
- last_modified = dateutil.parser.parse(repo['utc_last_updated'])
+ return BuildTriggerHandler.build_namespaces_response(namespaces)
- return {
- 'name': repo['slug'],
- 'full_name': '%s/%s' % (repo['owner'], repo['slug']),
- 'description': repo['description'] or '',
- 'last_updated': timegm(last_modified.utctimetuple()),
- 'url': 'https://bitbucket.org/%s/%s' % (repo['owner'], repo['slug']),
- 'has_admin_permissions': repo['read_only'] is False,
- 'private': repo['is_private'],
- }
+ def list_build_sources_for_namespace(self, namespace):
+ def repo_view(repo):
+ last_modified = dateutil.parser.parse(repo["utc_last_updated"])
- bitbucket_client = self._get_authorized_client()
- (result, data, err_msg) = bitbucket_client.get_visible_repositories()
- if not result:
- raise RepositoryReadException('Could not read repository list: ' + err_msg)
+ return {
+ "name": repo["slug"],
+ "full_name": "%s/%s" % (repo["owner"], repo["slug"]),
+ "description": repo["description"] or "",
+ "last_updated": timegm(last_modified.utctimetuple()),
+ "url": "https://bitbucket.org/%s/%s" % (repo["owner"], repo["slug"]),
+ "has_admin_permissions": repo["read_only"] is False,
+ "private": repo["is_private"],
+ }
- repos = [repo_view(repo) for repo in data if repo['owner'] == namespace]
- return BuildTriggerHandler.build_sources_response(repos)
+ bitbucket_client = self._get_authorized_client()
+ (result, data, err_msg) = bitbucket_client.get_visible_repositories()
+ if not result:
+ raise RepositoryReadException("Could not read repository list: " + err_msg)
- def list_build_subdirs(self):
- config = self.config
- repository = self._get_repository_client()
+ repos = [repo_view(repo) for repo in data if repo["owner"] == namespace]
+ return BuildTriggerHandler.build_sources_response(repos)
- # Find the first matching branch.
- repo_branches = self.list_field_values('branch_name') or []
- branches = find_matching_branches(config, repo_branches)
- if not branches:
- branches = [self._get_default_branch(repository)]
+ def list_build_subdirs(self):
+ config = self.config
+ repository = self._get_repository_client()
- (result, data, err_msg) = repository.get_path_contents('', revision=branches[0])
- if not result:
- raise RepositoryReadException(err_msg)
+ # Find the first matching branch.
+ repo_branches = self.list_field_values("branch_name") or []
+ branches = find_matching_branches(config, repo_branches)
+ if not branches:
+ branches = [self._get_default_branch(repository)]
- files = set([f['path'] for f in data['files']])
- return ["/" + file_path for file_path in files if self.filename_is_dockerfile(os.path.basename(file_path))]
+ (result, data, err_msg) = repository.get_path_contents("", revision=branches[0])
+ if not result:
+ raise RepositoryReadException(err_msg)
- def load_dockerfile_contents(self):
- repository = self._get_repository_client()
- path = self.get_dockerfile_path()
+ files = set([f["path"] for f in data["files"]])
+ return [
+ "/" + file_path
+ for file_path in files
+ if self.filename_is_dockerfile(os.path.basename(file_path))
+ ]
- (result, data, err_msg) = repository.get_raw_path_contents(path, revision='master')
- if not result:
- return None
+ def load_dockerfile_contents(self):
+ repository = self._get_repository_client()
+ path = self.get_dockerfile_path()
- return data
+ (result, data, err_msg) = repository.get_raw_path_contents(
+ path, revision="master"
+ )
+ if not result:
+ return None
- def list_field_values(self, field_name, limit=None):
- if 'build_source' not in self.config:
- return None
+ return data
- source = self.config['build_source']
- (namespace, name) = source.split('/')
+ def list_field_values(self, field_name, limit=None):
+ if "build_source" not in self.config:
+ return None
- bitbucket_client = self._get_authorized_client()
- repository = bitbucket_client.for_namespace(namespace).repositories().get(name)
+ source = self.config["build_source"]
+ (namespace, name) = source.split("/")
+
+ bitbucket_client = self._get_authorized_client()
+ repository = bitbucket_client.for_namespace(namespace).repositories().get(name)
+
+ if field_name == "refs":
+ (result, data, _) = repository.get_branches_and_tags()
+ if not result:
+ return None
+
+ branches = [b["name"] for b in data["branches"]]
+ tags = [t["name"] for t in data["tags"]]
+
+ return [{"kind": "branch", "name": b} for b in branches] + [
+ {"kind": "tag", "name": tag} for tag in tags
+ ]
+
+ if field_name == "tag_name":
+ (result, data, _) = repository.get_tags()
+ if not result:
+ return None
+
+ tags = list(data.keys())
+ if limit:
+ tags = tags[0:limit]
+
+ return tags
+
+ if field_name == "branch_name":
+ (result, data, _) = repository.get_branches()
+ if not result:
+ return None
+
+ branches = list(data.keys())
+ if limit:
+ branches = branches[0:limit]
+
+ return branches
- if field_name == 'refs':
- (result, data, _) = repository.get_branches_and_tags()
- if not result:
return None
- branches = [b['name'] for b in data['branches']]
- tags = [t['name'] for t in data['tags']]
+ def get_repository_url(self):
+ source = self.config["build_source"]
+ (namespace, name) = source.split("/")
+ return "https://bitbucket.org/%s/%s" % (namespace, name)
- return ([{'kind': 'branch', 'name': b} for b in branches] +
- [{'kind': 'tag', 'name': tag} for tag in tags])
+ def handle_trigger_request(self, request):
+ payload = request.get_json()
+ if payload is None:
+ raise InvalidPayloadException("Missing payload")
- if field_name == 'tag_name':
- (result, data, _) = repository.get_tags()
- if not result:
- return None
+ logger.debug("Got BitBucket request: %s", payload)
- tags = list(data.keys())
- if limit:
- tags = tags[0:limit]
+ repository = self._get_repository_client()
+ default_branch = self._get_default_branch(repository)
- return tags
+ metadata = get_transformed_webhook_payload(
+ payload, default_branch=default_branch
+ )
+ prepared = self.prepare_build(metadata)
- if field_name == 'branch_name':
- (result, data, _) = repository.get_branches()
- if not result:
- return None
+ # Check if we should skip this build.
+ raise_if_skipped_build(prepared, self.config)
+ return prepared
- branches = list(data.keys())
- if limit:
- branches = branches[0:limit]
+ def manual_start(self, run_parameters=None):
+ run_parameters = run_parameters or {}
+ repository = self._get_repository_client()
+ bitbucket_client = self._get_authorized_client()
- return branches
+ def get_branch_sha(branch_name):
+ # Lookup the commit SHA for the branch.
+ (result, data, _) = repository.get_branch(branch_name)
+ if not result:
+ raise TriggerStartException("Could not find branch in repository")
- return None
+ return data["target"]["hash"]
- def get_repository_url(self):
- source = self.config['build_source']
- (namespace, name) = source.split('/')
- return 'https://bitbucket.org/%s/%s' % (namespace, name)
+ def get_tag_sha(tag_name):
+ # Lookup the commit SHA for the tag.
+ (result, data, _) = repository.get_tag(tag_name)
+ if not result:
+ raise TriggerStartException("Could not find tag in repository")
- def handle_trigger_request(self, request):
- payload = request.get_json()
- if payload is None:
- raise InvalidPayloadException('Missing payload')
+ return data["target"]["hash"]
- logger.debug('Got BitBucket request: %s', payload)
+ def lookup_author(email_address):
+ (result, data, _) = bitbucket_client.accounts().get_profile(email_address)
+ return data if result else None
- repository = self._get_repository_client()
- default_branch = self._get_default_branch(repository)
+ # Find the branch or tag to build.
+ default_branch = self._get_default_branch(repository)
+ (commit_sha, ref) = determine_build_ref(
+ run_parameters, get_branch_sha, get_tag_sha, default_branch
+ )
- metadata = get_transformed_webhook_payload(payload, default_branch=default_branch)
- prepared = self.prepare_build(metadata)
+ # Lookup the commit SHA in BitBucket.
+ (result, commit_info, _) = repository.changesets().get(commit_sha)
+ if not result:
+ raise TriggerStartException("Could not lookup commit SHA")
- # Check if we should skip this build.
- raise_if_skipped_build(prepared, self.config)
- return prepared
+ # Return a prepared build for the commit.
+ repository_name = "%s/%s" % (repository.namespace, repository.repository_name)
+ metadata = get_transformed_commit_info(
+ commit_info, ref, default_branch, repository_name, lookup_author
+ )
- def manual_start(self, run_parameters=None):
- run_parameters = run_parameters or {}
- repository = self._get_repository_client()
- bitbucket_client = self._get_authorized_client()
-
- def get_branch_sha(branch_name):
- # Lookup the commit SHA for the branch.
- (result, data, _) = repository.get_branch(branch_name)
- if not result:
- raise TriggerStartException('Could not find branch in repository')
-
- return data['target']['hash']
-
- def get_tag_sha(tag_name):
- # Lookup the commit SHA for the tag.
- (result, data, _) = repository.get_tag(tag_name)
- if not result:
- raise TriggerStartException('Could not find tag in repository')
-
- return data['target']['hash']
-
- def lookup_author(email_address):
- (result, data, _) = bitbucket_client.accounts().get_profile(email_address)
- return data if result else None
-
- # Find the branch or tag to build.
- default_branch = self._get_default_branch(repository)
- (commit_sha, ref) = determine_build_ref(run_parameters, get_branch_sha, get_tag_sha,
- default_branch)
-
- # Lookup the commit SHA in BitBucket.
- (result, commit_info, _) = repository.changesets().get(commit_sha)
- if not result:
- raise TriggerStartException('Could not lookup commit SHA')
-
- # Return a prepared build for the commit.
- repository_name = '%s/%s' % (repository.namespace, repository.repository_name)
- metadata = get_transformed_commit_info(commit_info, ref, default_branch,
- repository_name, lookup_author)
-
- return self.prepare_build(metadata, is_manual=True)
+ return self.prepare_build(metadata, is_manual=True)
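Editor's note: one behavioral detail buried in the reflowed Bitbucket handler is that _RAW_AUTHOR_REGEX extracts only the e-mail address from Bitbucket's raw_author string, and that address is what gets passed to lookup_author. A quick standalone check of the pattern follows; the sample author string is invented for illustration.

import re

_RAW_AUTHOR_REGEX = re.compile(r".*<(.+)>")

# Bitbucket's raw_author looks like "Full Name <email>"; the single capture
# group keeps only the address between the angle brackets.
match = _RAW_AUTHOR_REGEX.match("Jane Doe <jane@example.com>")
assert match is not None
assert match.group(1) == "jane@example.com"
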
diff --git a/buildtrigger/customhandler.py b/buildtrigger/customhandler.py
index 193445ee2..6ed6e08c7 100644
--- a/buildtrigger/customhandler.py
+++ b/buildtrigger/customhandler.py
@@ -2,22 +2,33 @@ import logging
import json
from jsonschema import validate, ValidationError
-from buildtrigger.triggerutil import (RepositoryReadException, TriggerActivationException,
- TriggerStartException, ValidationRequestException,
- InvalidPayloadException,
- SkipRequestException, raise_if_skipped_build,
- find_matching_branches)
+from buildtrigger.triggerutil import (
+ RepositoryReadException,
+ TriggerActivationException,
+ TriggerStartException,
+ ValidationRequestException,
+ InvalidPayloadException,
+ SkipRequestException,
+ raise_if_skipped_build,
+ find_matching_branches,
+)
from buildtrigger.basehandler import BuildTriggerHandler
-from buildtrigger.bitbuckethandler import (BITBUCKET_WEBHOOK_PAYLOAD_SCHEMA as bb_schema,
- get_transformed_webhook_payload as bb_payload)
+from buildtrigger.bitbuckethandler import (
+ BITBUCKET_WEBHOOK_PAYLOAD_SCHEMA as bb_schema,
+ get_transformed_webhook_payload as bb_payload,
+)
-from buildtrigger.githubhandler import (GITHUB_WEBHOOK_PAYLOAD_SCHEMA as gh_schema,
- get_transformed_webhook_payload as gh_payload)
+from buildtrigger.githubhandler import (
+ GITHUB_WEBHOOK_PAYLOAD_SCHEMA as gh_schema,
+ get_transformed_webhook_payload as gh_payload,
+)
-from buildtrigger.gitlabhandler import (GITLAB_WEBHOOK_PAYLOAD_SCHEMA as gl_schema,
- get_transformed_webhook_payload as gl_payload)
+from buildtrigger.gitlabhandler import (
+ GITLAB_WEBHOOK_PAYLOAD_SCHEMA as gl_schema,
+ get_transformed_webhook_payload as gl_payload,
+)
from util.security.ssh import generate_ssh_keypair
@@ -27,203 +38,191 @@ logger = logging.getLogger(__name__)
# Defines an ordered set of tuples of the schemas and associated transformation functions
# for incoming webhook payloads.
SCHEMA_AND_HANDLERS = [
- (gh_schema, gh_payload),
- (bb_schema, bb_payload),
- (gl_schema, gl_payload),
+ (gh_schema, gh_payload),
+ (bb_schema, bb_payload),
+ (gl_schema, gl_payload),
]
def custom_trigger_payload(metadata, git_url):
- # First try the customhandler schema. If it matches, nothing more to do.
- custom_handler_validation_error = None
- try:
- validate(metadata, CustomBuildTrigger.payload_schema)
- except ValidationError as vex:
- custom_handler_validation_error = vex
-
- # Otherwise, try the defined schemas, in order, until we find a match.
- for schema, handler in SCHEMA_AND_HANDLERS:
+ # First try the customhandler schema. If it matches, nothing more to do.
+ custom_handler_validation_error = None
try:
- validate(metadata, schema)
- except ValidationError:
- continue
+ validate(metadata, CustomBuildTrigger.payload_schema)
+ except ValidationError as vex:
+ custom_handler_validation_error = vex
- result = handler(metadata)
- result['git_url'] = git_url
- return result
+ # Otherwise, try the defined schemas, in order, until we find a match.
+ for schema, handler in SCHEMA_AND_HANDLERS:
+ try:
+ validate(metadata, schema)
+ except ValidationError:
+ continue
- # If we have reached this point and no other schemas validated, then raise the error for the
- # custom schema.
- if custom_handler_validation_error is not None:
- raise InvalidPayloadException(custom_handler_validation_error.message)
+ result = handler(metadata)
+ result["git_url"] = git_url
+ return result
- metadata['git_url'] = git_url
- return metadata
+ # If we have reached this point and no other schemas validated, then raise the error for the
+ # custom schema.
+ if custom_handler_validation_error is not None:
+ raise InvalidPayloadException(custom_handler_validation_error.message)
+
+ metadata["git_url"] = git_url
+ return metadata
class CustomBuildTrigger(BuildTriggerHandler):
- payload_schema = {
- 'type': 'object',
- 'properties': {
- 'commit': {
- 'type': 'string',
- 'description': 'first 7 characters of the SHA-1 identifier for a git commit',
- 'pattern': '^([A-Fa-f0-9]{7,})$',
- },
- 'ref': {
- 'type': 'string',
- 'description': 'git reference for a git commit',
- 'pattern': '^refs\/(heads|tags|remotes)\/(.+)$',
- },
- 'default_branch': {
- 'type': 'string',
- 'description': 'default branch of the git repository',
- },
- 'commit_info': {
- 'type': 'object',
- 'description': 'metadata about a git commit',
- 'properties': {
- 'url': {
- 'type': 'string',
- 'description': 'URL to view a git commit',
- },
- 'message': {
- 'type': 'string',
- 'description': 'git commit message',
- },
- 'date': {
- 'type': 'string',
- 'description': 'timestamp for a git commit'
- },
- 'author': {
- 'type': 'object',
- 'description': 'metadata about the author of a git commit',
- 'properties': {
- 'username': {
- 'type': 'string',
- 'description': 'username of the author',
- },
- 'url': {
- 'type': 'string',
- 'description': 'URL to view the profile of the author',
- },
- 'avatar_url': {
- 'type': 'string',
- 'description': 'URL to view the avatar of the author',
- },
+ payload_schema = {
+ "type": "object",
+ "properties": {
+ "commit": {
+ "type": "string",
+ "description": "first 7 characters of the SHA-1 identifier for a git commit",
+ "pattern": "^([A-Fa-f0-9]{7,})$",
},
- 'required': ['username', 'url', 'avatar_url'],
- },
- 'committer': {
- 'type': 'object',
- 'description': 'metadata about the committer of a git commit',
- 'properties': {
- 'username': {
- 'type': 'string',
- 'description': 'username of the committer',
- },
- 'url': {
- 'type': 'string',
- 'description': 'URL to view the profile of the committer',
- },
- 'avatar_url': {
- 'type': 'string',
- 'description': 'URL to view the avatar of the committer',
- },
+ "ref": {
+ "type": "string",
+ "description": "git reference for a git commit",
+ "pattern": "^refs\/(heads|tags|remotes)\/(.+)$",
+ },
+ "default_branch": {
+ "type": "string",
+ "description": "default branch of the git repository",
+ },
+ "commit_info": {
+ "type": "object",
+ "description": "metadata about a git commit",
+ "properties": {
+ "url": {
+ "type": "string",
+ "description": "URL to view a git commit",
+ },
+ "message": {"type": "string", "description": "git commit message"},
+ "date": {
+ "type": "string",
+ "description": "timestamp for a git commit",
+ },
+ "author": {
+ "type": "object",
+ "description": "metadata about the author of a git commit",
+ "properties": {
+ "username": {
+ "type": "string",
+ "description": "username of the author",
+ },
+ "url": {
+ "type": "string",
+ "description": "URL to view the profile of the author",
+ },
+ "avatar_url": {
+ "type": "string",
+ "description": "URL to view the avatar of the author",
+ },
+ },
+ "required": ["username", "url", "avatar_url"],
+ },
+ "committer": {
+ "type": "object",
+ "description": "metadata about the committer of a git commit",
+ "properties": {
+ "username": {
+ "type": "string",
+ "description": "username of the committer",
+ },
+ "url": {
+ "type": "string",
+ "description": "URL to view the profile of the committer",
+ },
+ "avatar_url": {
+ "type": "string",
+ "description": "URL to view the avatar of the committer",
+ },
+ },
+ "required": ["username", "url", "avatar_url"],
+ },
+ },
+ "required": ["url", "message", "date"],
},
- 'required': ['username', 'url', 'avatar_url'],
- },
},
- 'required': ['url', 'message', 'date'],
- },
- },
- 'required': ['commit', 'ref', 'default_branch'],
- }
-
- @classmethod
- def service_name(cls):
- return 'custom-git'
-
- def is_active(self):
- return self.config.has_key('credentials')
-
- def _metadata_from_payload(self, payload, git_url):
- # Parse the JSON payload.
- try:
- metadata = json.loads(payload)
- except ValueError as vex:
- raise InvalidPayloadException(vex.message)
-
- return custom_trigger_payload(metadata, git_url)
-
- def handle_trigger_request(self, request):
- payload = request.data
- if not payload:
- raise InvalidPayloadException('Missing expected payload')
-
- logger.debug('Payload %s', payload)
-
- metadata = self._metadata_from_payload(payload, self.config['build_source'])
- prepared = self.prepare_build(metadata)
-
- # Check if we should skip this build.
- raise_if_skipped_build(prepared, self.config)
-
- return prepared
-
- def manual_start(self, run_parameters=None):
- # commit_sha is the only required parameter
- commit_sha = run_parameters.get('commit_sha')
- if commit_sha is None:
- raise TriggerStartException('missing required parameter')
-
- config = self.config
- metadata = {
- 'commit': commit_sha,
- 'git_url': config['build_source'],
+ "required": ["commit", "ref", "default_branch"],
}
- try:
- return self.prepare_build(metadata, is_manual=True)
- except ValidationError as ve:
- raise TriggerStartException(ve.message)
+ @classmethod
+ def service_name(cls):
+ return "custom-git"
- def activate(self, standard_webhook_url):
- config = self.config
- public_key, private_key = generate_ssh_keypair()
- config['credentials'] = [
- {
- 'name': 'SSH Public Key',
- 'value': public_key,
- },
- {
- 'name': 'Webhook Endpoint URL',
- 'value': standard_webhook_url,
- },
- ]
- self.config = config
- return config, {'private_key': private_key}
+ def is_active(self):
+ return self.config.has_key("credentials")
- def deactivate(self):
- config = self.config
- config.pop('credentials', None)
- self.config = config
- return config
+ def _metadata_from_payload(self, payload, git_url):
+ # Parse the JSON payload.
+ try:
+ metadata = json.loads(payload)
+ except ValueError as vex:
+ raise InvalidPayloadException(vex.message)
- def get_repository_url(self):
- return None
+ return custom_trigger_payload(metadata, git_url)
- def list_build_source_namespaces(self):
- raise NotImplementedError
+ def handle_trigger_request(self, request):
+ payload = request.data
+ if not payload:
+ raise InvalidPayloadException("Missing expected payload")
- def list_build_sources_for_namespace(self, namespace):
- raise NotImplementedError
+ logger.debug("Payload %s", payload)
- def list_build_subdirs(self):
- raise NotImplementedError
+ metadata = self._metadata_from_payload(payload, self.config["build_source"])
+ prepared = self.prepare_build(metadata)
- def list_field_values(self, field_name, limit=None):
- raise NotImplementedError
+ # Check if we should skip this build.
+ raise_if_skipped_build(prepared, self.config)
- def load_dockerfile_contents(self):
- raise NotImplementedError
+ return prepared
+
+ def manual_start(self, run_parameters=None):
+ # commit_sha is the only required parameter
+ commit_sha = run_parameters.get("commit_sha")
+ if commit_sha is None:
+ raise TriggerStartException("missing required parameter")
+
+ config = self.config
+ metadata = {"commit": commit_sha, "git_url": config["build_source"]}
+
+ try:
+ return self.prepare_build(metadata, is_manual=True)
+ except ValidationError as ve:
+ raise TriggerStartException(ve.message)
+
+ def activate(self, standard_webhook_url):
+ config = self.config
+ public_key, private_key = generate_ssh_keypair()
+ config["credentials"] = [
+ {"name": "SSH Public Key", "value": public_key},
+ {"name": "Webhook Endpoint URL", "value": standard_webhook_url},
+ ]
+ self.config = config
+ return config, {"private_key": private_key}
+
+ def deactivate(self):
+ config = self.config
+ config.pop("credentials", None)
+ self.config = config
+ return config
+
+ def get_repository_url(self):
+ return None
+
+ def list_build_source_namespaces(self):
+ raise NotImplementedError
+
+ def list_build_sources_for_namespace(self, namespace):
+ raise NotImplementedError
+
+ def list_build_subdirs(self):
+ raise NotImplementedError
+
+ def list_field_values(self, field_name, limit=None):
+ raise NotImplementedError
+
+ def load_dockerfile_contents(self):
+ raise NotImplementedError
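Editor's note: for orientation, custom_trigger_payload above validates incoming metadata against the custom schema first and then falls back through SCHEMA_AND_HANDLERS in order, using the first schema that accepts the payload. Below is a trimmed-down, standalone sketch of that fallback; the schemas and transforms are invented for illustration and are not the handler's real ones.

from jsonschema import validate, ValidationError

SCHEMA_AND_HANDLERS = [
    # (schema, transform) pairs, tried in order; the first accepting schema wins.
    ({"type": "object", "required": ["ref"]}, lambda d: dict(d, kind="git-push")),
    ({"type": "object", "required": ["commit"]}, lambda d: dict(d, kind="commit-only")),
]


def transform(payload):
    for schema, handler in SCHEMA_AND_HANDLERS:
        try:
            validate(payload, schema)
        except ValidationError:
            continue
        return handler(payload)
    return None


print(transform({"ref": "refs/heads/main"})["kind"])  # prints "git-push"
print(transform({"commit": "abc1234"})["kind"])       # prints "commit-only"
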
diff --git a/buildtrigger/githubhandler.py b/buildtrigger/githubhandler.py
index bc40f993c..7fc12135b 100644
--- a/buildtrigger/githubhandler.py
+++ b/buildtrigger/githubhandler.py
@@ -7,18 +7,29 @@ from calendar import timegm
from functools import wraps
from ssl import SSLError
-from github import (Github, UnknownObjectException, GithubException,
- BadCredentialsException as GitHubBadCredentialsException)
+from github import (
+ Github,
+ UnknownObjectException,
+ GithubException,
+ BadCredentialsException as GitHubBadCredentialsException,
+)
from jsonschema import validate
from app import app, github_trigger
-from buildtrigger.triggerutil import (RepositoryReadException, TriggerActivationException,
- TriggerDeactivationException, TriggerStartException,
- EmptyRepositoryException, ValidationRequestException,
- SkipRequestException, InvalidPayloadException,
- determine_build_ref, raise_if_skipped_build,
- find_matching_branches)
+from buildtrigger.triggerutil import (
+ RepositoryReadException,
+ TriggerActivationException,
+ TriggerDeactivationException,
+ TriggerStartException,
+ EmptyRepositoryException,
+ ValidationRequestException,
+ SkipRequestException,
+ InvalidPayloadException,
+ determine_build_ref,
+ raise_if_skipped_build,
+ find_matching_branches,
+)
from buildtrigger.basehandler import BuildTriggerHandler
from endpoints.exception import ExternalServiceError
from util.security.ssh import generate_ssh_keypair
@@ -27,561 +38,576 @@ from util.dict_wrappers import JSONPathDict, SafeDictSetter
logger = logging.getLogger(__name__)
GITHUB_WEBHOOK_PAYLOAD_SCHEMA = {
- 'type': 'object',
- 'properties': {
- 'ref': {
- 'type': 'string',
+ "type": "object",
+ "properties": {
+ "ref": {"type": "string"},
+ "head_commit": {
+ "type": ["object", "null"],
+ "properties": {
+ "id": {"type": "string"},
+ "url": {"type": "string"},
+ "message": {"type": "string"},
+ "timestamp": {"type": "string"},
+ "author": {
+ "type": "object",
+ "properties": {
+ "username": {"type": "string"},
+ "html_url": {"type": "string"},
+ "avatar_url": {"type": "string"},
+ },
+ },
+ "committer": {
+ "type": "object",
+ "properties": {
+ "username": {"type": "string"},
+ "html_url": {"type": "string"},
+ "avatar_url": {"type": "string"},
+ },
+ },
+ },
+ "required": ["id", "url", "message", "timestamp"],
+ },
+ "repository": {
+ "type": "object",
+ "properties": {"ssh_url": {"type": "string"}},
+ "required": ["ssh_url"],
+ },
},
- 'head_commit': {
- 'type': ['object', 'null'],
- 'properties': {
- 'id': {
- 'type': 'string',
- },
- 'url': {
- 'type': 'string',
- },
- 'message': {
- 'type': 'string',
- },
- 'timestamp': {
- 'type': 'string',
- },
- 'author': {
- 'type': 'object',
- 'properties': {
- 'username': {
- 'type': 'string'
- },
- 'html_url': {
- 'type': 'string'
- },
- 'avatar_url': {
- 'type': 'string'
- },
- },
- },
- 'committer': {
- 'type': 'object',
- 'properties': {
- 'username': {
- 'type': 'string'
- },
- 'html_url': {
- 'type': 'string'
- },
- 'avatar_url': {
- 'type': 'string'
- },
- },
- },
- },
- 'required': ['id', 'url', 'message', 'timestamp'],
- },
- 'repository': {
- 'type': 'object',
- 'properties': {
- 'ssh_url': {
- 'type': 'string',
- },
- },
- 'required': ['ssh_url'],
- },
- },
- 'required': ['ref', 'head_commit', 'repository'],
+ "required": ["ref", "head_commit", "repository"],
}
+
def get_transformed_webhook_payload(gh_payload, default_branch=None, lookup_user=None):
- """ Returns the GitHub webhook JSON payload transformed into our own payload
+ """ Returns the GitHub webhook JSON payload transformed into our own payload
format. If the gh_payload is not valid, returns None.
"""
- try:
- validate(gh_payload, GITHUB_WEBHOOK_PAYLOAD_SCHEMA)
- except Exception as exc:
- raise InvalidPayloadException(exc.message)
+ try:
+ validate(gh_payload, GITHUB_WEBHOOK_PAYLOAD_SCHEMA)
+ except Exception as exc:
+ raise InvalidPayloadException(exc.message)
- payload = JSONPathDict(gh_payload)
+ payload = JSONPathDict(gh_payload)
- if payload['head_commit'] is None:
- raise SkipRequestException
+ if payload["head_commit"] is None:
+ raise SkipRequestException
- config = SafeDictSetter()
- config['commit'] = payload['head_commit.id']
- config['ref'] = payload['ref']
- config['default_branch'] = payload['repository.default_branch'] or default_branch
- config['git_url'] = payload['repository.ssh_url']
+ config = SafeDictSetter()
+ config["commit"] = payload["head_commit.id"]
+ config["ref"] = payload["ref"]
+ config["default_branch"] = payload["repository.default_branch"] or default_branch
+ config["git_url"] = payload["repository.ssh_url"]
- config['commit_info.url'] = payload['head_commit.url']
- config['commit_info.message'] = payload['head_commit.message']
- config['commit_info.date'] = payload['head_commit.timestamp']
+ config["commit_info.url"] = payload["head_commit.url"]
+ config["commit_info.message"] = payload["head_commit.message"]
+ config["commit_info.date"] = payload["head_commit.timestamp"]
- config['commit_info.author.username'] = payload['head_commit.author.username']
- config['commit_info.author.url'] = payload.get('head_commit.author.html_url')
- config['commit_info.author.avatar_url'] = payload.get('head_commit.author.avatar_url')
+ config["commit_info.author.username"] = payload["head_commit.author.username"]
+ config["commit_info.author.url"] = payload.get("head_commit.author.html_url")
+ config["commit_info.author.avatar_url"] = payload.get(
+ "head_commit.author.avatar_url"
+ )
- config['commit_info.committer.username'] = payload.get('head_commit.committer.username')
- config['commit_info.committer.url'] = payload.get('head_commit.committer.html_url')
- config['commit_info.committer.avatar_url'] = payload.get('head_commit.committer.avatar_url')
+ config["commit_info.committer.username"] = payload.get(
+ "head_commit.committer.username"
+ )
+ config["commit_info.committer.url"] = payload.get("head_commit.committer.html_url")
+ config["commit_info.committer.avatar_url"] = payload.get(
+ "head_commit.committer.avatar_url"
+ )
- # Note: GitHub doesn't always return the extra information for users, so we do the lookup
- # manually if possible.
- if (lookup_user and not payload.get('head_commit.author.html_url') and
- payload.get('head_commit.author.username')):
- author_info = lookup_user(payload['head_commit.author.username'])
- if author_info:
- config['commit_info.author.url'] = author_info['html_url']
- config['commit_info.author.avatar_url'] = author_info['avatar_url']
+ # Note: GitHub doesn't always return the extra information for users, so we do the lookup
+ # manually if possible.
+ if (
+ lookup_user
+ and not payload.get("head_commit.author.html_url")
+ and payload.get("head_commit.author.username")
+ ):
+ author_info = lookup_user(payload["head_commit.author.username"])
+ if author_info:
+ config["commit_info.author.url"] = author_info["html_url"]
+ config["commit_info.author.avatar_url"] = author_info["avatar_url"]
- if (lookup_user and
- payload.get('head_commit.committer.username') and
- not payload.get('head_commit.committer.html_url')):
- committer_info = lookup_user(payload['head_commit.committer.username'])
- if committer_info:
- config['commit_info.committer.url'] = committer_info['html_url']
- config['commit_info.committer.avatar_url'] = committer_info['avatar_url']
+ if (
+ lookup_user
+ and payload.get("head_commit.committer.username")
+ and not payload.get("head_commit.committer.html_url")
+ ):
+ committer_info = lookup_user(payload["head_commit.committer.username"])
+ if committer_info:
+ config["commit_info.committer.url"] = committer_info["html_url"]
+ config["commit_info.committer.avatar_url"] = committer_info["avatar_url"]
- return config.dict_value()
+ return config.dict_value()
def _catch_ssl_errors(func):
- @wraps(func)
- def wrapper(*args, **kwargs):
- try:
- return func(*args, **kwargs)
- except SSLError as se:
- msg = 'Request to the GitHub API failed: %s' % se.message
- logger.exception(msg)
- raise ExternalServiceError(msg)
- return wrapper
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ try:
+ return func(*args, **kwargs)
+ except SSLError as se:
+ msg = "Request to the GitHub API failed: %s" % se.message
+ logger.exception(msg)
+ raise ExternalServiceError(msg)
+
+ return wrapper
class GithubBuildTrigger(BuildTriggerHandler):
- """
+ """
BuildTrigger for GitHub that uses the archive API and buildpacks.
"""
- def _get_client(self):
- """ Returns an authenticated client for talking to the GitHub API. """
- return Github(self.auth_token,
- base_url=github_trigger.api_endpoint(),
- client_id=github_trigger.client_id(),
- client_secret=github_trigger.client_secret(),
- timeout=5)
- @classmethod
- def service_name(cls):
- return 'github'
+ def _get_client(self):
+ """ Returns an authenticated client for talking to the GitHub API. """
+ return Github(
+ self.auth_token,
+ base_url=github_trigger.api_endpoint(),
+ client_id=github_trigger.client_id(),
+ client_secret=github_trigger.client_secret(),
+ timeout=5,
+ )
- def is_active(self):
- return 'hook_id' in self.config
+ @classmethod
+ def service_name(cls):
+ return "github"
- def get_repository_url(self):
- source = self.config['build_source']
- return github_trigger.get_public_url(source)
+ def is_active(self):
+ return "hook_id" in self.config
- @staticmethod
- def _get_error_message(ghe, default_msg):
- if ghe.data.get('errors') and ghe.data['errors'][0].get('message'):
- return ghe.data['errors'][0]['message']
+ def get_repository_url(self):
+ source = self.config["build_source"]
+ return github_trigger.get_public_url(source)
- return default_msg
+ @staticmethod
+ def _get_error_message(ghe, default_msg):
+ if ghe.data.get("errors") and ghe.data["errors"][0].get("message"):
+ return ghe.data["errors"][0]["message"]
- @_catch_ssl_errors
- def activate(self, standard_webhook_url):
- config = self.config
- new_build_source = config['build_source']
- gh_client = self._get_client()
+ return default_msg
- # Find the GitHub repository.
- try:
- gh_repo = gh_client.get_repo(new_build_source)
- except UnknownObjectException:
- msg = 'Unable to find GitHub repository for source: %s' % new_build_source
- raise TriggerActivationException(msg)
-
- # Add a deploy key to the GitHub repository.
- public_key, private_key = generate_ssh_keypair()
- config['credentials'] = [
- {
- 'name': 'SSH Public Key',
- 'value': public_key,
- },
- ]
-
- try:
- deploy_key = gh_repo.create_key('%s Builder' % app.config['REGISTRY_TITLE'],
- public_key)
- config['deploy_key_id'] = deploy_key.id
- except GithubException as ghe:
- default_msg = 'Unable to add deploy key to repository: %s' % new_build_source
- msg = GithubBuildTrigger._get_error_message(ghe, default_msg)
- raise TriggerActivationException(msg)
-
- # Add the webhook to the GitHub repository.
- webhook_config = {
- 'url': standard_webhook_url,
- 'content_type': 'json',
- }
-
- try:
- hook = gh_repo.create_hook('web', webhook_config)
- config['hook_id'] = hook.id
- config['master_branch'] = gh_repo.default_branch
- except GithubException as ghe:
- default_msg = 'Unable to create webhook on repository: %s' % new_build_source
- msg = GithubBuildTrigger._get_error_message(ghe, default_msg)
- raise TriggerActivationException(msg)
-
- return config, {'private_key': private_key}
-
- @_catch_ssl_errors
- def deactivate(self):
- config = self.config
- gh_client = self._get_client()
-
- # Find the GitHub repository.
- try:
- repo = gh_client.get_repo(config['build_source'])
- except UnknownObjectException:
- msg = 'Unable to find GitHub repository for source: %s' % config['build_source']
- raise TriggerDeactivationException(msg)
- except GitHubBadCredentialsException:
- msg = 'Unable to access repository to disable trigger'
- raise TriggerDeactivationException(msg)
-
- # If the trigger uses a deploy key, remove it.
- try:
- if config['deploy_key_id']:
- deploy_key = repo.get_key(config['deploy_key_id'])
- deploy_key.delete()
- except KeyError:
- # There was no config['deploy_key_id'], thus this is an old trigger without a deploy key.
- pass
- except GithubException as ghe:
- default_msg = 'Unable to remove deploy key: %s' % config['deploy_key_id']
- msg = GithubBuildTrigger._get_error_message(ghe, default_msg)
- raise TriggerDeactivationException(msg)
-
- # Remove the webhook.
- if 'hook_id' in config:
- try:
- hook = repo.get_hook(config['hook_id'])
- hook.delete()
- except GithubException as ghe:
- default_msg = 'Unable to remove hook: %s' % config['hook_id']
- msg = GithubBuildTrigger._get_error_message(ghe, default_msg)
- raise TriggerDeactivationException(msg)
-
- config.pop('hook_id', None)
- self.config = config
- return config
-
- @_catch_ssl_errors
- def list_build_source_namespaces(self):
- gh_client = self._get_client()
- usr = gh_client.get_user()
-
- # Build the full set of namespaces for the user, starting with their own.
- namespaces = {}
- namespaces[usr.login] = {
- 'personal': True,
- 'id': usr.login,
- 'title': usr.name or usr.login,
- 'avatar_url': usr.avatar_url,
- 'url': usr.html_url,
- 'score': usr.plan.private_repos if usr.plan else 0,
- }
-
- for org in usr.get_orgs():
- organization = org.login if org.login else org.name
-
- # NOTE: We don't load the organization's html_url nor its plan, because doing
- # so requires loading *each organization* via its own API call in this tight
- # loop, which was massively slowing down the load time for users when setting
- # up triggers.
- namespaces[organization] = {
- 'personal': False,
- 'id': organization,
- 'title': organization,
- 'avatar_url': org.avatar_url,
- 'url': '',
- 'score': 0,
- }
-
- return BuildTriggerHandler.build_namespaces_response(namespaces)
-
- @_catch_ssl_errors
- def list_build_sources_for_namespace(self, namespace):
- def repo_view(repo):
- return {
- 'name': repo.name,
- 'full_name': repo.full_name,
- 'description': repo.description or '',
- 'last_updated': timegm(repo.pushed_at.utctimetuple()) if repo.pushed_at else 0,
- 'url': repo.html_url,
- 'has_admin_permissions': repo.permissions.admin,
- 'private': repo.private,
- }
-
- gh_client = self._get_client()
- usr = gh_client.get_user()
- if namespace == usr.login:
- repos = [repo_view(repo) for repo in usr.get_repos(type='owner', sort='updated')]
- return BuildTriggerHandler.build_sources_response(repos)
-
- try:
- org = gh_client.get_organization(namespace)
- if org is None:
- return []
- except GithubException:
- return []
-
- repos = [repo_view(repo) for repo in org.get_repos(type='member')]
- return BuildTriggerHandler.build_sources_response(repos)
-
-
- @_catch_ssl_errors
- def list_build_subdirs(self):
- config = self.config
- gh_client = self._get_client()
- source = config['build_source']
-
- try:
- repo = gh_client.get_repo(source)
-
- # Find the first matching branch.
- repo_branches = self.list_field_values('branch_name') or []
- branches = find_matching_branches(config, repo_branches)
- branches = branches or [repo.default_branch or 'master']
- default_commit = repo.get_branch(branches[0]).commit
- commit_tree = repo.get_git_tree(default_commit.sha, recursive=True)
-
- return [elem.path for elem in commit_tree.tree
- if (elem.type == u'blob' and self.filename_is_dockerfile(os.path.basename(elem.path)))]
- except GithubException as ghe:
- message = ghe.data.get('message', 'Unable to list contents of repository: %s' % source)
- if message == 'Branch not found':
- raise EmptyRepositoryException()
-
- raise RepositoryReadException(message)
-
- @_catch_ssl_errors
- def load_dockerfile_contents(self):
- config = self.config
- gh_client = self._get_client()
- source = config['build_source']
-
- try:
- repo = gh_client.get_repo(source)
- except GithubException as ghe:
- message = ghe.data.get('message', 'Unable to list contents of repository: %s' % source)
- raise RepositoryReadException(message)
-
- path = self.get_dockerfile_path()
- if not path:
- return None
-
- try:
- file_info = repo.get_contents(path)
- # TypeError is needed because directory inputs cause a TypeError
- except (GithubException, TypeError) as ghe:
- logger.error("got error from trying to find github file %s" % ghe)
- return None
-
- if file_info is None:
- return None
-
- if isinstance(file_info, list):
- return None
-
- content = file_info.content
- if file_info.encoding == 'base64':
- content = base64.b64decode(content)
- return content
-
- @_catch_ssl_errors
- def list_field_values(self, field_name, limit=None):
- if field_name == 'refs':
- branches = self.list_field_values('branch_name')
- tags = self.list_field_values('tag_name')
-
- return ([{'kind': 'branch', 'name': b} for b in branches] +
- [{'kind': 'tag', 'name': tag} for tag in tags])
-
- config = self.config
- source = config.get('build_source')
- if source is None:
- return []
-
- if field_name == 'tag_name':
- try:
+ @_catch_ssl_errors
+ def activate(self, standard_webhook_url):
+ config = self.config
+ new_build_source = config["build_source"]
gh_client = self._get_client()
- repo = gh_client.get_repo(source)
- gh_tags = repo.get_tags()
- if limit:
- gh_tags = repo.get_tags()[0:limit]
- return [tag.name for tag in gh_tags]
- except GitHubBadCredentialsException:
- return []
- except GithubException:
- logger.exception("Got GitHub Exception when trying to list tags for trigger %s",
- self.trigger.id)
- return []
+ # Find the GitHub repository.
+ try:
+ gh_repo = gh_client.get_repo(new_build_source)
+ except UnknownObjectException:
+ msg = "Unable to find GitHub repository for source: %s" % new_build_source
+ raise TriggerActivationException(msg)
- if field_name == 'branch_name':
- try:
+ # Add a deploy key to the GitHub repository.
+ public_key, private_key = generate_ssh_keypair()
+ config["credentials"] = [{"name": "SSH Public Key", "value": public_key}]
+
+ try:
+ deploy_key = gh_repo.create_key(
+ "%s Builder" % app.config["REGISTRY_TITLE"], public_key
+ )
+ config["deploy_key_id"] = deploy_key.id
+ except GithubException as ghe:
+ default_msg = (
+ "Unable to add deploy key to repository: %s" % new_build_source
+ )
+ msg = GithubBuildTrigger._get_error_message(ghe, default_msg)
+ raise TriggerActivationException(msg)
+
+ # Add the webhook to the GitHub repository.
+ webhook_config = {"url": standard_webhook_url, "content_type": "json"}
+
+ try:
+ hook = gh_repo.create_hook("web", webhook_config)
+ config["hook_id"] = hook.id
+ config["master_branch"] = gh_repo.default_branch
+ except GithubException as ghe:
+ default_msg = (
+ "Unable to create webhook on repository: %s" % new_build_source
+ )
+ msg = GithubBuildTrigger._get_error_message(ghe, default_msg)
+ raise TriggerActivationException(msg)
+
+ return config, {"private_key": private_key}
+
+ @_catch_ssl_errors
+ def deactivate(self):
+ config = self.config
gh_client = self._get_client()
- repo = gh_client.get_repo(source)
- gh_branches = repo.get_branches()
- if limit:
- gh_branches = repo.get_branches()[0:limit]
- branches = [branch.name for branch in gh_branches]
+ # Find the GitHub repository.
+ try:
+ repo = gh_client.get_repo(config["build_source"])
+ except UnknownObjectException:
+ msg = (
+ "Unable to find GitHub repository for source: %s"
+ % config["build_source"]
+ )
+ raise TriggerDeactivationException(msg)
+ except GitHubBadCredentialsException:
+ msg = "Unable to access repository to disable trigger"
+ raise TriggerDeactivationException(msg)
- if not repo.default_branch in branches:
- branches.insert(0, repo.default_branch)
+ # If the trigger uses a deploy key, remove it.
+ try:
+ if config["deploy_key_id"]:
+ deploy_key = repo.get_key(config["deploy_key_id"])
+ deploy_key.delete()
+ except KeyError:
+ # There was no config['deploy_key_id'], thus this is an old trigger without a deploy key.
+ pass
+ except GithubException as ghe:
+ default_msg = "Unable to remove deploy key: %s" % config["deploy_key_id"]
+ msg = GithubBuildTrigger._get_error_message(ghe, default_msg)
+ raise TriggerDeactivationException(msg)
- if branches[0] != repo.default_branch:
- branches.remove(repo.default_branch)
- branches.insert(0, repo.default_branch)
+ # Remove the webhook.
+ if "hook_id" in config:
+ try:
+ hook = repo.get_hook(config["hook_id"])
+ hook.delete()
+ except GithubException as ghe:
+ default_msg = "Unable to remove hook: %s" % config["hook_id"]
+ msg = GithubBuildTrigger._get_error_message(ghe, default_msg)
+ raise TriggerDeactivationException(msg)
- return branches
- except GitHubBadCredentialsException:
- return ['master']
- except GithubException:
- logger.exception("Got GitHub Exception when trying to list branches for trigger %s",
- self.trigger.id)
- return ['master']
+ config.pop("hook_id", None)
+ self.config = config
+ return config
- return None
+ @_catch_ssl_errors
+ def list_build_source_namespaces(self):
+ gh_client = self._get_client()
+ usr = gh_client.get_user()
- @classmethod
- def _build_metadata_for_commit(cls, commit_sha, ref, repo):
- try:
- commit = repo.get_commit(commit_sha)
- except GithubException:
- logger.exception('Could not load commit information from GitHub')
- return None
+ # Build the full set of namespaces for the user, starting with their own.
+ namespaces = {}
+ namespaces[usr.login] = {
+ "personal": True,
+ "id": usr.login,
+ "title": usr.name or usr.login,
+ "avatar_url": usr.avatar_url,
+ "url": usr.html_url,
+ "score": usr.plan.private_repos if usr.plan else 0,
+ }
- commit_info = {
- 'url': commit.html_url,
- 'message': commit.commit.message,
- 'date': commit.last_modified
- }
+ for org in usr.get_orgs():
+ organization = org.login if org.login else org.name
- if commit.author:
- commit_info['author'] = {
- 'username': commit.author.login,
- 'avatar_url': commit.author.avatar_url,
- 'url': commit.author.html_url
- }
+ # NOTE: We don't load the organization's html_url nor its plan, because doing
+ # so requires loading *each organization* via its own API call in this tight
+ # loop, which was massively slowing down the load time for users when setting
+ # up triggers.
+ namespaces[organization] = {
+ "personal": False,
+ "id": organization,
+ "title": organization,
+ "avatar_url": org.avatar_url,
+ "url": "",
+ "score": 0,
+ }
- if commit.committer:
- commit_info['committer'] = {
- 'username': commit.committer.login,
- 'avatar_url': commit.committer.avatar_url,
- 'url': commit.committer.html_url
- }
+ return BuildTriggerHandler.build_namespaces_response(namespaces)
- return {
- 'commit': commit_sha,
- 'ref': ref,
- 'default_branch': repo.default_branch,
- 'git_url': repo.ssh_url,
- 'commit_info': commit_info
- }
+ @_catch_ssl_errors
+ def list_build_sources_for_namespace(self, namespace):
+ def repo_view(repo):
+ return {
+ "name": repo.name,
+ "full_name": repo.full_name,
+ "description": repo.description or "",
+ "last_updated": timegm(repo.pushed_at.utctimetuple())
+ if repo.pushed_at
+ else 0,
+ "url": repo.html_url,
+ "has_admin_permissions": repo.permissions.admin,
+ "private": repo.private,
+ }
- @_catch_ssl_errors
- def manual_start(self, run_parameters=None):
- config = self.config
- source = config['build_source']
+ gh_client = self._get_client()
+ usr = gh_client.get_user()
+ if namespace == usr.login:
+ repos = [
+ repo_view(repo) for repo in usr.get_repos(type="owner", sort="updated")
+ ]
+ return BuildTriggerHandler.build_sources_response(repos)
- try:
- gh_client = self._get_client()
- repo = gh_client.get_repo(source)
- default_branch = repo.default_branch
- except GithubException as ghe:
- msg = GithubBuildTrigger._get_error_message(ghe, 'Unable to start build trigger')
- raise TriggerStartException(msg)
+ try:
+ org = gh_client.get_organization(namespace)
+ if org is None:
+ return []
+ except GithubException:
+ return []
- def get_branch_sha(branch_name):
- try:
- branch = repo.get_branch(branch_name)
- return branch.commit.sha
- except GithubException:
- raise TriggerStartException('Could not find branch in repository')
+ repos = [repo_view(repo) for repo in org.get_repos(type="member")]
+ return BuildTriggerHandler.build_sources_response(repos)
- def get_tag_sha(tag_name):
- tags = {tag.name: tag for tag in repo.get_tags()}
- if not tag_name in tags:
- raise TriggerStartException('Could not find tag in repository')
+ @_catch_ssl_errors
+ def list_build_subdirs(self):
+ config = self.config
+ gh_client = self._get_client()
+ source = config["build_source"]
- return tags[tag_name].commit.sha
+ try:
+ repo = gh_client.get_repo(source)
- # Find the branch or tag to build.
- (commit_sha, ref) = determine_build_ref(run_parameters, get_branch_sha, get_tag_sha,
- default_branch)
+ # Find the first matching branch.
+ repo_branches = self.list_field_values("branch_name") or []
+ branches = find_matching_branches(config, repo_branches)
+ branches = branches or [repo.default_branch or "master"]
+ default_commit = repo.get_branch(branches[0]).commit
+ commit_tree = repo.get_git_tree(default_commit.sha, recursive=True)
- metadata = GithubBuildTrigger._build_metadata_for_commit(commit_sha, ref, repo)
- return self.prepare_build(metadata, is_manual=True)
+ return [
+ elem.path
+ for elem in commit_tree.tree
+ if (
+ elem.type == u"blob"
+ and self.filename_is_dockerfile(os.path.basename(elem.path))
+ )
+ ]
+ except GithubException as ghe:
+ message = ghe.data.get(
+ "message", "Unable to list contents of repository: %s" % source
+ )
+ if message == "Branch not found":
+ raise EmptyRepositoryException()
- @_catch_ssl_errors
- def lookup_user(self, username):
- try:
- gh_client = self._get_client()
- user = gh_client.get_user(username)
- return {
- 'html_url': user.html_url,
- 'avatar_url': user.avatar_url
- }
- except GithubException:
- return None
+ raise RepositoryReadException(message)
- @_catch_ssl_errors
- def handle_trigger_request(self, request):
- # Check the payload to see if we should skip it based on the lack of a head_commit.
- payload = request.get_json()
- if payload is None:
- raise InvalidPayloadException('Missing payload')
+ @_catch_ssl_errors
+ def load_dockerfile_contents(self):
+ config = self.config
+ gh_client = self._get_client()
+ source = config["build_source"]
- # This is for GitHub's probing/testing.
- if 'zen' in payload:
- raise SkipRequestException()
+ try:
+ repo = gh_client.get_repo(source)
+ except GithubException as ghe:
+ message = ghe.data.get(
+ "message", "Unable to list contents of repository: %s" % source
+ )
+ raise RepositoryReadException(message)
- # Lookup the default branch for the repository.
- if 'repository' not in payload:
- raise InvalidPayloadException("Missing 'repository' on request")
+ path = self.get_dockerfile_path()
+ if not path:
+ return None
- if 'owner' not in payload['repository']:
- raise InvalidPayloadException("Missing 'owner' on repository")
+ try:
+ file_info = repo.get_contents(path)
+ # TypeError is needed because directory inputs cause a TypeError
+ except (GithubException, TypeError) as ghe:
+ logger.error("got error from trying to find github file %s" % ghe)
+ return None
- if 'name' not in payload['repository']['owner']:
- raise InvalidPayloadException("Missing owner 'name' on repository")
+ if file_info is None:
+ return None
- if 'name' not in payload['repository']:
- raise InvalidPayloadException("Missing 'name' on repository")
+ if isinstance(file_info, list):
+ return None
- default_branch = None
- lookup_user = None
- try:
- repo_full_name = '%s/%s' % (payload['repository']['owner']['name'],
- payload['repository']['name'])
+ content = file_info.content
+ if file_info.encoding == "base64":
+ content = base64.b64decode(content)
+ return content
- gh_client = self._get_client()
- repo = gh_client.get_repo(repo_full_name)
- default_branch = repo.default_branch
- lookup_user = self.lookup_user
- except GitHubBadCredentialsException:
- logger.exception('Got GitHub Credentials Exception; Cannot lookup default branch')
- except GithubException:
- logger.exception("Got GitHub Exception when trying to start trigger %s", self.trigger.id)
- raise SkipRequestException()
+ @_catch_ssl_errors
+ def list_field_values(self, field_name, limit=None):
+ if field_name == "refs":
+ branches = self.list_field_values("branch_name")
+ tags = self.list_field_values("tag_name")
- logger.debug('GitHub trigger payload %s', payload)
- metadata = get_transformed_webhook_payload(payload, default_branch=default_branch,
- lookup_user=lookup_user)
- prepared = self.prepare_build(metadata)
+ return [{"kind": "branch", "name": b} for b in branches] + [
+ {"kind": "tag", "name": tag} for tag in tags
+ ]
- # Check if we should skip this build.
- raise_if_skipped_build(prepared, self.config)
- return prepared
+ config = self.config
+ source = config.get("build_source")
+ if source is None:
+ return []
+
+ if field_name == "tag_name":
+ try:
+ gh_client = self._get_client()
+ repo = gh_client.get_repo(source)
+ gh_tags = repo.get_tags()
+ if limit:
+ gh_tags = repo.get_tags()[0:limit]
+
+ return [tag.name for tag in gh_tags]
+ except GitHubBadCredentialsException:
+ return []
+ except GithubException:
+ logger.exception(
+ "Got GitHub Exception when trying to list tags for trigger %s",
+ self.trigger.id,
+ )
+ return []
+
+ if field_name == "branch_name":
+ try:
+ gh_client = self._get_client()
+ repo = gh_client.get_repo(source)
+ gh_branches = repo.get_branches()
+ if limit:
+ gh_branches = repo.get_branches()[0:limit]
+
+ branches = [branch.name for branch in gh_branches]
+
+                if repo.default_branch not in branches:
+ branches.insert(0, repo.default_branch)
+
+ if branches[0] != repo.default_branch:
+ branches.remove(repo.default_branch)
+ branches.insert(0, repo.default_branch)
+
+ return branches
+ except GitHubBadCredentialsException:
+ return ["master"]
+ except GithubException:
+ logger.exception(
+ "Got GitHub Exception when trying to list branches for trigger %s",
+ self.trigger.id,
+ )
+ return ["master"]
+
+ return None
+
+ @classmethod
+ def _build_metadata_for_commit(cls, commit_sha, ref, repo):
+ try:
+ commit = repo.get_commit(commit_sha)
+ except GithubException:
+ logger.exception("Could not load commit information from GitHub")
+ return None
+
+ commit_info = {
+ "url": commit.html_url,
+ "message": commit.commit.message,
+ "date": commit.last_modified,
+ }
+
+ if commit.author:
+ commit_info["author"] = {
+ "username": commit.author.login,
+ "avatar_url": commit.author.avatar_url,
+ "url": commit.author.html_url,
+ }
+
+ if commit.committer:
+ commit_info["committer"] = {
+ "username": commit.committer.login,
+ "avatar_url": commit.committer.avatar_url,
+ "url": commit.committer.html_url,
+ }
+
+ return {
+ "commit": commit_sha,
+ "ref": ref,
+ "default_branch": repo.default_branch,
+ "git_url": repo.ssh_url,
+ "commit_info": commit_info,
+ }
+
+ @_catch_ssl_errors
+ def manual_start(self, run_parameters=None):
+ config = self.config
+ source = config["build_source"]
+
+ try:
+ gh_client = self._get_client()
+ repo = gh_client.get_repo(source)
+ default_branch = repo.default_branch
+ except GithubException as ghe:
+ msg = GithubBuildTrigger._get_error_message(
+ ghe, "Unable to start build trigger"
+ )
+ raise TriggerStartException(msg)
+
+ def get_branch_sha(branch_name):
+ try:
+ branch = repo.get_branch(branch_name)
+ return branch.commit.sha
+ except GithubException:
+ raise TriggerStartException("Could not find branch in repository")
+
+ def get_tag_sha(tag_name):
+ tags = {tag.name: tag for tag in repo.get_tags()}
+            if tag_name not in tags:
+ raise TriggerStartException("Could not find tag in repository")
+
+ return tags[tag_name].commit.sha
+
+ # Find the branch or tag to build.
+ (commit_sha, ref) = determine_build_ref(
+ run_parameters, get_branch_sha, get_tag_sha, default_branch
+ )
+
+ metadata = GithubBuildTrigger._build_metadata_for_commit(commit_sha, ref, repo)
+ return self.prepare_build(metadata, is_manual=True)
+
+ @_catch_ssl_errors
+ def lookup_user(self, username):
+ try:
+ gh_client = self._get_client()
+ user = gh_client.get_user(username)
+ return {"html_url": user.html_url, "avatar_url": user.avatar_url}
+ except GithubException:
+ return None
+
+ @_catch_ssl_errors
+ def handle_trigger_request(self, request):
+ # Check the payload to see if we should skip it based on the lack of a head_commit.
+ payload = request.get_json()
+ if payload is None:
+ raise InvalidPayloadException("Missing payload")
+
+ # This is for GitHub's probing/testing.
+ if "zen" in payload:
+ raise SkipRequestException()
+
+ # Lookup the default branch for the repository.
+ if "repository" not in payload:
+ raise InvalidPayloadException("Missing 'repository' on request")
+
+ if "owner" not in payload["repository"]:
+ raise InvalidPayloadException("Missing 'owner' on repository")
+
+ if "name" not in payload["repository"]["owner"]:
+ raise InvalidPayloadException("Missing owner 'name' on repository")
+
+ if "name" not in payload["repository"]:
+ raise InvalidPayloadException("Missing 'name' on repository")
+
+ default_branch = None
+ lookup_user = None
+ try:
+ repo_full_name = "%s/%s" % (
+ payload["repository"]["owner"]["name"],
+ payload["repository"]["name"],
+ )
+
+ gh_client = self._get_client()
+ repo = gh_client.get_repo(repo_full_name)
+ default_branch = repo.default_branch
+ lookup_user = self.lookup_user
+ except GitHubBadCredentialsException:
+ logger.exception(
+ "Got GitHub Credentials Exception; Cannot lookup default branch"
+ )
+ except GithubException:
+ logger.exception(
+ "Got GitHub Exception when trying to start trigger %s", self.trigger.id
+ )
+ raise SkipRequestException()
+
+ logger.debug("GitHub trigger payload %s", payload)
+ metadata = get_transformed_webhook_payload(
+ payload, default_branch=default_branch, lookup_user=lookup_user
+ )
+ prepared = self.prepare_build(metadata)
+
+ # Check if we should skip this build.
+ raise_if_skipped_build(prepared, self.config)
+ return prepared
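
For reference, the payload validation that the reformatted handle_trigger_request above performs can be exercised on its own. The following is a minimal, self-contained sketch rather than the module's actual entry point: the two exception classes are stand-ins for the ones imported from buildtrigger.triggerutil, and the sample payload at the end is hypothetical.

class InvalidPayloadException(Exception):
    """Stand-in for the exception imported from buildtrigger.triggerutil."""


class SkipRequestException(Exception):
    """Stand-in for the exception imported from buildtrigger.triggerutil."""


def validate_github_push_payload(payload):
    # GitHub sends a "zen" field on its ping/probe requests; those are skipped.
    if payload is None:
        raise InvalidPayloadException("Missing payload")

    if "zen" in payload:
        raise SkipRequestException()

    # The handler needs repository, repository.owner.name and repository.name
    # before it can compute the full repository name for the API lookup.
    if "repository" not in payload:
        raise InvalidPayloadException("Missing 'repository' on request")

    if "owner" not in payload["repository"]:
        raise InvalidPayloadException("Missing 'owner' on repository")

    if "name" not in payload["repository"]["owner"]:
        raise InvalidPayloadException("Missing owner 'name' on repository")

    if "name" not in payload["repository"]:
        raise InvalidPayloadException("Missing 'name' on repository")

    return "%s/%s" % (
        payload["repository"]["owner"]["name"],
        payload["repository"]["name"],
    )


# Hypothetical usage with a made-up repository payload:
payload = {"repository": {"name": "bar", "owner": {"name": "foo"}}}
assert validate_github_push_payload(payload) == "foo/bar"
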
diff --git a/buildtrigger/gitlabhandler.py b/buildtrigger/gitlabhandler.py
index 9ed3e91d0..fbb0afa63 100644
--- a/buildtrigger/gitlabhandler.py
+++ b/buildtrigger/gitlabhandler.py
@@ -11,12 +11,18 @@ import requests
from jsonschema import validate
from app import app, gitlab_trigger
-from buildtrigger.triggerutil import (RepositoryReadException, TriggerActivationException,
- TriggerDeactivationException, TriggerStartException,
- SkipRequestException, InvalidPayloadException,
- TriggerAuthException,
- determine_build_ref, raise_if_skipped_build,
- find_matching_branches)
+from buildtrigger.triggerutil import (
+ RepositoryReadException,
+ TriggerActivationException,
+ TriggerDeactivationException,
+ TriggerStartException,
+ SkipRequestException,
+ InvalidPayloadException,
+ TriggerAuthException,
+ determine_build_ref,
+ raise_if_skipped_build,
+ find_matching_branches,
+)
from buildtrigger.basehandler import BuildTriggerHandler
from endpoints.exception import ExternalServiceError
from util.security.ssh import generate_ssh_keypair
@@ -25,597 +31,616 @@ from util.dict_wrappers import JSONPathDict, SafeDictSetter
logger = logging.getLogger(__name__)
GITLAB_WEBHOOK_PAYLOAD_SCHEMA = {
- 'type': 'object',
- 'properties': {
- 'ref': {
- 'type': 'string',
- },
- 'checkout_sha': {
- 'type': ['string', 'null'],
- },
- 'repository': {
- 'type': 'object',
- 'properties': {
- 'git_ssh_url': {
- 'type': 'string',
+ "type": "object",
+ "properties": {
+ "ref": {"type": "string"},
+ "checkout_sha": {"type": ["string", "null"]},
+ "repository": {
+ "type": "object",
+ "properties": {"git_ssh_url": {"type": "string"}},
+ "required": ["git_ssh_url"],
},
- },
- 'required': ['git_ssh_url'],
- },
- 'commits': {
- 'type': 'array',
- 'items': {
- 'type': 'object',
- 'properties': {
- 'id': {
- 'type': 'string',
- },
- 'url': {
- 'type': ['string', 'null'],
- },
- 'message': {
- 'type': 'string',
- },
- 'timestamp': {
- 'type': 'string',
- },
- 'author': {
- 'type': 'object',
- 'properties': {
- 'email': {
- 'type': 'string',
- },
+ "commits": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "id": {"type": "string"},
+ "url": {"type": ["string", "null"]},
+ "message": {"type": "string"},
+ "timestamp": {"type": "string"},
+ "author": {
+ "type": "object",
+ "properties": {"email": {"type": "string"}},
+ "required": ["email"],
+ },
+ },
+ "required": ["id", "message", "timestamp"],
},
- 'required': ['email'],
- },
},
- 'required': ['id', 'message', 'timestamp'],
- },
},
- },
- 'required': ['ref', 'checkout_sha', 'repository'],
+ "required": ["ref", "checkout_sha", "repository"],
}
_ACCESS_LEVEL_MAP = {
- 50: ("owner", True),
- 40: ("master", True),
- 30: ("developer", False),
- 20: ("reporter", False),
- 10: ("guest", False),
+ 50: ("owner", True),
+ 40: ("master", True),
+ 30: ("developer", False),
+ 20: ("reporter", False),
+ 10: ("guest", False),
}
_PER_PAGE_COUNT = 20
def _catch_timeouts_and_errors(func):
- @wraps(func)
- def wrapper(*args, **kwargs):
- try:
- return func(*args, **kwargs)
- except requests.exceptions.Timeout:
- msg = 'Request to the GitLab API timed out'
- logger.exception(msg)
- raise ExternalServiceError(msg)
- except gitlab.GitlabError:
- msg = 'GitLab API error. Please contact support.'
- logger.exception(msg)
- raise ExternalServiceError(msg)
- return wrapper
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ try:
+ return func(*args, **kwargs)
+ except requests.exceptions.Timeout:
+ msg = "Request to the GitLab API timed out"
+ logger.exception(msg)
+ raise ExternalServiceError(msg)
+ except gitlab.GitlabError:
+ msg = "GitLab API error. Please contact support."
+ logger.exception(msg)
+ raise ExternalServiceError(msg)
+
+ return wrapper
def _paginated_iterator(func, exc, **kwargs):
- """ Returns an iterator over invocations of the given function, automatically handling
+ """ Returns an iterator over invocations of the given function, automatically handling
pagination.
"""
- page = 1
- while True:
- result = func(page=page, per_page=_PER_PAGE_COUNT, **kwargs)
- if result is None or result is False:
- raise exc
+ page = 1
+ while True:
+ result = func(page=page, per_page=_PER_PAGE_COUNT, **kwargs)
+ if result is None or result is False:
+ raise exc
- counter = 0
- for item in result:
- yield item
- counter = counter + 1
+ counter = 0
+ for item in result:
+ yield item
+ counter = counter + 1
- if counter < _PER_PAGE_COUNT:
- break
+ if counter < _PER_PAGE_COUNT:
+ break
- page = page + 1
+ page = page + 1
-def get_transformed_webhook_payload(gl_payload, default_branch=None, lookup_user=None,
- lookup_commit=None):
- """ Returns the Gitlab webhook JSON payload transformed into our own payload
+def get_transformed_webhook_payload(
+ gl_payload, default_branch=None, lookup_user=None, lookup_commit=None
+):
+ """ Returns the Gitlab webhook JSON payload transformed into our own payload
format. If the gl_payload is not valid, returns None.
"""
- try:
- validate(gl_payload, GITLAB_WEBHOOK_PAYLOAD_SCHEMA)
- except Exception as exc:
- raise InvalidPayloadException(exc.message)
+ try:
+ validate(gl_payload, GITLAB_WEBHOOK_PAYLOAD_SCHEMA)
+ except Exception as exc:
+ raise InvalidPayloadException(exc.message)
- payload = JSONPathDict(gl_payload)
+ payload = JSONPathDict(gl_payload)
- if payload['object_kind'] != 'push' and payload['object_kind'] != 'tag_push':
- # Unknown kind of webhook.
- raise SkipRequestException
+ if payload["object_kind"] != "push" and payload["object_kind"] != "tag_push":
+ # Unknown kind of webhook.
+ raise SkipRequestException
- # Check for empty commits. The commits list will be empty if the branch is deleted.
- commits = payload['commits']
- if payload['object_kind'] == 'push' and not commits:
- raise SkipRequestException
+ # Check for empty commits. The commits list will be empty if the branch is deleted.
+ commits = payload["commits"]
+ if payload["object_kind"] == "push" and not commits:
+ raise SkipRequestException
- # Check for missing commit information.
- commit_sha = payload['checkout_sha'] or payload['after']
- if commit_sha is None or commit_sha == '0000000000000000000000000000000000000000':
- raise SkipRequestException
+ # Check for missing commit information.
+ commit_sha = payload["checkout_sha"] or payload["after"]
+ if commit_sha is None or commit_sha == "0000000000000000000000000000000000000000":
+ raise SkipRequestException
- config = SafeDictSetter()
- config['commit'] = commit_sha
- config['ref'] = payload['ref']
- config['default_branch'] = default_branch
- config['git_url'] = payload['repository.git_ssh_url']
+ config = SafeDictSetter()
+ config["commit"] = commit_sha
+ config["ref"] = payload["ref"]
+ config["default_branch"] = default_branch
+ config["git_url"] = payload["repository.git_ssh_url"]
- found_commit = JSONPathDict({})
- if payload['object_kind'] == 'push' or payload['object_kind'] == 'tag_push':
- # Find the commit associated with the checkout_sha. Gitlab doesn't (necessary) send this in
- # any order, so we cannot simply index into the commits list.
- found_commit = None
- if commits is not None:
- for commit in commits:
- if commit['id'] == payload['checkout_sha']:
- found_commit = JSONPathDict(commit)
- break
+ found_commit = JSONPathDict({})
+ if payload["object_kind"] == "push" or payload["object_kind"] == "tag_push":
+        # Find the commit associated with the checkout_sha. GitLab doesn't necessarily send
+        # these in order, so we cannot simply index into the commits list.
+ found_commit = None
+ if commits is not None:
+ for commit in commits:
+ if commit["id"] == payload["checkout_sha"]:
+ found_commit = JSONPathDict(commit)
+ break
- if found_commit is None and lookup_commit:
- checkout_sha = payload['checkout_sha'] or payload['after']
- found_commit_info = lookup_commit(payload['project_id'], checkout_sha)
- found_commit = JSONPathDict(dict(found_commit_info) if found_commit_info else {})
+ if found_commit is None and lookup_commit:
+ checkout_sha = payload["checkout_sha"] or payload["after"]
+ found_commit_info = lookup_commit(payload["project_id"], checkout_sha)
+ found_commit = JSONPathDict(
+ dict(found_commit_info) if found_commit_info else {}
+ )
- if found_commit is None:
- raise SkipRequestException
+ if found_commit is None:
+ raise SkipRequestException
- config['commit_info.url'] = found_commit['url']
- config['commit_info.message'] = found_commit['message']
- config['commit_info.date'] = found_commit['timestamp']
+ config["commit_info.url"] = found_commit["url"]
+ config["commit_info.message"] = found_commit["message"]
+ config["commit_info.date"] = found_commit["timestamp"]
- # Note: Gitlab does not send full user information with the payload, so we have to
- # (optionally) look it up.
- author_email = found_commit['author.email'] or found_commit['author_email']
- if lookup_user and author_email:
- author_info = lookup_user(author_email)
- if author_info:
- config['commit_info.author.username'] = author_info['username']
- config['commit_info.author.url'] = author_info['html_url']
- config['commit_info.author.avatar_url'] = author_info['avatar_url']
+ # Note: Gitlab does not send full user information with the payload, so we have to
+ # (optionally) look it up.
+ author_email = found_commit["author.email"] or found_commit["author_email"]
+ if lookup_user and author_email:
+ author_info = lookup_user(author_email)
+ if author_info:
+ config["commit_info.author.username"] = author_info["username"]
+ config["commit_info.author.url"] = author_info["html_url"]
+ config["commit_info.author.avatar_url"] = author_info["avatar_url"]
- return config.dict_value()
+ return config.dict_value()
class GitLabBuildTrigger(BuildTriggerHandler):
- """
+ """
BuildTrigger for GitLab.
"""
- @classmethod
- def service_name(cls):
- return 'gitlab'
- def _get_authorized_client(self):
- auth_token = self.auth_token or 'invalid'
- api_version = self.config.get('API_VERSION', '4')
- client = gitlab.Gitlab(gitlab_trigger.api_endpoint(), oauth_token=auth_token, timeout=20,
- api_version=api_version)
- try:
- client.auth()
- except gitlab.GitlabGetError as ex:
- raise TriggerAuthException(ex.message)
+ @classmethod
+ def service_name(cls):
+ return "gitlab"
- return client
+ def _get_authorized_client(self):
+ auth_token = self.auth_token or "invalid"
+ api_version = self.config.get("API_VERSION", "4")
+ client = gitlab.Gitlab(
+ gitlab_trigger.api_endpoint(),
+ oauth_token=auth_token,
+ timeout=20,
+ api_version=api_version,
+ )
+ try:
+ client.auth()
+ except gitlab.GitlabGetError as ex:
+ raise TriggerAuthException(ex.message)
- def is_active(self):
- return 'hook_id' in self.config
+ return client
- @_catch_timeouts_and_errors
- def activate(self, standard_webhook_url):
- config = self.config
- new_build_source = config['build_source']
- gl_client = self._get_authorized_client()
+ def is_active(self):
+ return "hook_id" in self.config
- # Find the GitLab repository.
- gl_project = gl_client.projects.get(new_build_source)
- if not gl_project:
- msg = 'Unable to find GitLab repository for source: %s' % new_build_source
- raise TriggerActivationException(msg)
+ @_catch_timeouts_and_errors
+ def activate(self, standard_webhook_url):
+ config = self.config
+ new_build_source = config["build_source"]
+ gl_client = self._get_authorized_client()
- # Add a deploy key to the repository.
- public_key, private_key = generate_ssh_keypair()
- config['credentials'] = [
- {
- 'name': 'SSH Public Key',
- 'value': public_key,
- },
- ]
+ # Find the GitLab repository.
+ gl_project = gl_client.projects.get(new_build_source)
+ if not gl_project:
+ msg = "Unable to find GitLab repository for source: %s" % new_build_source
+ raise TriggerActivationException(msg)
- key = gl_project.keys.create({
- 'title': '%s Builder' % app.config['REGISTRY_TITLE'],
- 'key': public_key,
- })
+ # Add a deploy key to the repository.
+ public_key, private_key = generate_ssh_keypair()
+ config["credentials"] = [{"name": "SSH Public Key", "value": public_key}]
- if not key:
- msg = 'Unable to add deploy key to repository: %s' % new_build_source
- raise TriggerActivationException(msg)
+ key = gl_project.keys.create(
+ {"title": "%s Builder" % app.config["REGISTRY_TITLE"], "key": public_key}
+ )
- config['key_id'] = key.get_id()
+ if not key:
+ msg = "Unable to add deploy key to repository: %s" % new_build_source
+ raise TriggerActivationException(msg)
- # Add the webhook to the GitLab repository.
- hook = gl_project.hooks.create({
- 'url': standard_webhook_url,
- 'push': True,
- 'tag_push': True,
- 'push_events': True,
- 'tag_push_events': True,
- })
- if not hook:
- msg = 'Unable to create webhook on repository: %s' % new_build_source
- raise TriggerActivationException(msg)
+ config["key_id"] = key.get_id()
- config['hook_id'] = hook.get_id()
- self.config = config
- return config, {'private_key': private_key}
+ # Add the webhook to the GitLab repository.
+ hook = gl_project.hooks.create(
+ {
+ "url": standard_webhook_url,
+ "push": True,
+ "tag_push": True,
+ "push_events": True,
+ "tag_push_events": True,
+ }
+ )
+ if not hook:
+ msg = "Unable to create webhook on repository: %s" % new_build_source
+ raise TriggerActivationException(msg)
- def deactivate(self):
- config = self.config
- gl_client = self._get_authorized_client()
+ config["hook_id"] = hook.get_id()
+ self.config = config
+ return config, {"private_key": private_key}
+
+ def deactivate(self):
+ config = self.config
+ gl_client = self._get_authorized_client()
+
+ # Find the GitLab repository.
+ try:
+ gl_project = gl_client.projects.get(config["build_source"])
+ if not gl_project:
+ config.pop("key_id", None)
+ config.pop("hook_id", None)
+ self.config = config
+ return config
+ except gitlab.GitlabGetError as ex:
+ if ex.response_code != 404:
+ raise
+
+ # Remove the webhook.
+ try:
+ gl_project.hooks.delete(config["hook_id"])
+ except gitlab.GitlabDeleteError as ex:
+ if ex.response_code != 404:
+ raise
+
+ config.pop("hook_id", None)
+
+ # Remove the key
+ try:
+ gl_project.keys.delete(config["key_id"])
+ except gitlab.GitlabDeleteError as ex:
+ if ex.response_code != 404:
+ raise
+
+ config.pop("key_id", None)
- # Find the GitLab repository.
- try:
- gl_project = gl_client.projects.get(config['build_source'])
- if not gl_project:
- config.pop('key_id', None)
- config.pop('hook_id', None)
self.config = config
return config
- except gitlab.GitlabGetError as ex:
- if ex.response_code != 404:
- raise
- # Remove the webhook.
- try:
- gl_project.hooks.delete(config['hook_id'])
- except gitlab.GitlabDeleteError as ex:
- if ex.response_code != 404:
- raise
+ @_catch_timeouts_and_errors
+ def list_build_source_namespaces(self):
+ gl_client = self._get_authorized_client()
+ current_user = gl_client.user
+ if not current_user:
+ raise RepositoryReadException("Unable to get current user")
- config.pop('hook_id', None)
+ namespaces = {}
+ for namespace in _paginated_iterator(
+ gl_client.namespaces.list, RepositoryReadException
+ ):
+ namespace_id = namespace.get_id()
+ if namespace_id in namespaces:
+ namespaces[namespace_id]["score"] = (
+ namespaces[namespace_id]["score"] + 1
+ )
+ else:
+ owner = namespace.attributes["name"]
+ namespaces[namespace_id] = {
+ "personal": namespace.attributes["kind"] == "user",
+ "id": str(namespace_id),
+ "title": namespace.attributes["name"],
+ "avatar_url": namespace.attributes.get("avatar_url"),
+ "score": 1,
+ "url": namespace.attributes.get("web_url") or "",
+ }
- # Remove the key
- try:
- gl_project.keys.delete(config['key_id'])
- except gitlab.GitlabDeleteError as ex:
- if ex.response_code != 404:
- raise
+ return BuildTriggerHandler.build_namespaces_response(namespaces)
- config.pop('key_id', None)
+ def _get_namespace(self, gl_client, gl_namespace, lazy=False):
+ try:
+ if gl_namespace.attributes["kind"] == "group":
+ return gl_client.groups.get(gl_namespace.attributes["id"], lazy=lazy)
- self.config = config
- return config
+ if gl_namespace.attributes["kind"] == "user":
+ return gl_client.users.get(gl_client.user.attributes["id"], lazy=lazy)
- @_catch_timeouts_and_errors
- def list_build_source_namespaces(self):
- gl_client = self._get_authorized_client()
- current_user = gl_client.user
- if not current_user:
- raise RepositoryReadException('Unable to get current user')
+ # Note: This doesn't seem to work for IDs retrieved via the namespaces API; the IDs are
+ # different.
+ return gl_client.users.get(gl_namespace.attributes["id"], lazy=lazy)
+ except gitlab.GitlabGetError:
+ return None
- namespaces = {}
- for namespace in _paginated_iterator(gl_client.namespaces.list, RepositoryReadException):
- namespace_id = namespace.get_id()
- if namespace_id in namespaces:
- namespaces[namespace_id]['score'] = namespaces[namespace_id]['score'] + 1
- else:
- owner = namespace.attributes['name']
- namespaces[namespace_id] = {
- 'personal': namespace.attributes['kind'] == 'user',
- 'id': str(namespace_id),
- 'title': namespace.attributes['name'],
- 'avatar_url': namespace.attributes.get('avatar_url'),
- 'score': 1,
- 'url': namespace.attributes.get('web_url') or '',
+ @_catch_timeouts_and_errors
+ def list_build_sources_for_namespace(self, namespace_id):
+ if not namespace_id:
+ return []
+
+ def repo_view(repo):
+ # Because *anything* can be None in GitLab API!
+ permissions = repo.attributes.get("permissions") or {}
+ group_access = permissions.get("group_access") or {}
+ project_access = permissions.get("project_access") or {}
+
+ missing_group_access = permissions.get("group_access") is None
+ missing_project_access = permissions.get("project_access") is None
+
+ access_level = max(
+ group_access.get("access_level") or 0,
+ project_access.get("access_level") or 0,
+ )
+
+ has_admin_permission = _ACCESS_LEVEL_MAP.get(access_level, ("", False))[1]
+ if missing_group_access or missing_project_access:
+                # Default to having permission if we cannot check the permissions. This will allow our users
+ # to select the repository and then GitLab's own checks will ensure that the webhook is
+ # added only if allowed.
+ # TODO: Do we want to display this differently in the UI?
+ has_admin_permission = True
+
+ view = {
+ "name": repo.attributes["path"],
+ "full_name": repo.attributes["path_with_namespace"],
+ "description": repo.attributes.get("description") or "",
+ "url": repo.attributes.get("web_url"),
+ "has_admin_permissions": has_admin_permission,
+ "private": repo.attributes.get("visibility") == "private",
+ }
+
+ if repo.attributes.get("last_activity_at"):
+ try:
+ last_modified = dateutil.parser.parse(
+ repo.attributes["last_activity_at"]
+ )
+ view["last_updated"] = timegm(last_modified.utctimetuple())
+ except ValueError:
+ logger.exception(
+ "Gitlab gave us an invalid last_activity_at: %s", last_modified
+ )
+
+ return view
+
+ gl_client = self._get_authorized_client()
+
+ try:
+ gl_namespace = gl_client.namespaces.get(namespace_id)
+ except gitlab.GitlabGetError:
+ return []
+
+ namespace_obj = self._get_namespace(gl_client, gl_namespace, lazy=True)
+ repositories = _paginated_iterator(
+ namespace_obj.projects.list, RepositoryReadException
+ )
+
+ try:
+ return BuildTriggerHandler.build_sources_response(
+ [repo_view(repo) for repo in repositories]
+ )
+ except gitlab.GitlabGetError:
+ return []
+
+ @_catch_timeouts_and_errors
+ def list_build_subdirs(self):
+ config = self.config
+ gl_client = self._get_authorized_client()
+ new_build_source = config["build_source"]
+
+ gl_project = gl_client.projects.get(new_build_source)
+ if not gl_project:
+ msg = "Unable to find GitLab repository for source: %s" % new_build_source
+ raise RepositoryReadException(msg)
+
+ repo_branches = gl_project.branches.list()
+ if not repo_branches:
+ msg = "Unable to find GitLab branches for source: %s" % new_build_source
+ raise RepositoryReadException(msg)
+
+ branches = [branch.attributes["name"] for branch in repo_branches]
+ branches = find_matching_branches(config, branches)
+ branches = branches or [gl_project.attributes["default_branch"] or "master"]
+
+ repo_tree = gl_project.repository_tree(ref=branches[0])
+ if not repo_tree:
+ msg = (
+ "Unable to find GitLab repository tree for source: %s"
+ % new_build_source
+ )
+ raise RepositoryReadException(msg)
+
+ return [
+ node["name"]
+ for node in repo_tree
+ if self.filename_is_dockerfile(node["name"])
+ ]
+
+ @_catch_timeouts_and_errors
+ def load_dockerfile_contents(self):
+ gl_client = self._get_authorized_client()
+ path = self.get_dockerfile_path()
+
+ gl_project = gl_client.projects.get(self.config["build_source"])
+ if not gl_project:
+ return None
+
+ branches = self.list_field_values("branch_name")
+ branches = find_matching_branches(self.config, branches)
+ if branches == []:
+ return None
+
+ branch_name = branches[0]
+ if gl_project.attributes["default_branch"] in branches:
+ branch_name = gl_project.attributes["default_branch"]
+
+ try:
+ return gl_project.files.get(path, branch_name).decode()
+ except gitlab.GitlabGetError:
+ return None
+
+ @_catch_timeouts_and_errors
+ def list_field_values(self, field_name, limit=None):
+ if field_name == "refs":
+ branches = self.list_field_values("branch_name")
+ tags = self.list_field_values("tag_name")
+
+ return [{"kind": "branch", "name": b} for b in branches] + [
+ {"kind": "tag", "name": t} for t in tags
+ ]
+
+ gl_client = self._get_authorized_client()
+ gl_project = gl_client.projects.get(self.config["build_source"])
+ if not gl_project:
+ return []
+
+ if field_name == "tag_name":
+ tags = gl_project.tags.list()
+ if not tags:
+ return []
+
+ if limit:
+ tags = tags[0:limit]
+
+ return [tag.attributes["name"] for tag in tags]
+
+ if field_name == "branch_name":
+ branches = gl_project.branches.list()
+ if not branches:
+ return []
+
+ if limit:
+ branches = branches[0:limit]
+
+ return [branch.attributes["name"] for branch in branches]
+
+ return None
+
+ def get_repository_url(self):
+ return gitlab_trigger.get_public_url(self.config["build_source"])
+
+ @_catch_timeouts_and_errors
+ def lookup_commit(self, repo_id, commit_sha):
+ if repo_id is None:
+ return None
+
+ gl_client = self._get_authorized_client()
+ gl_project = gl_client.projects.get(self.config["build_source"], lazy=True)
+ commit = gl_project.commits.get(commit_sha)
+ if not commit:
+ return None
+
+ return commit
+
+ @_catch_timeouts_and_errors
+ def lookup_user(self, email):
+ gl_client = self._get_authorized_client()
+ try:
+ result = gl_client.users.list(search=email)
+ if not result:
+ return None
+
+ [user] = result
+ return {
+ "username": user.attributes["username"],
+ "html_url": user.attributes["web_url"],
+ "avatar_url": user.attributes["avatar_url"],
+ }
+ except ValueError:
+ return None
+
+ @_catch_timeouts_and_errors
+ def get_metadata_for_commit(self, commit_sha, ref, repo):
+ commit = self.lookup_commit(repo.get_id(), commit_sha)
+ if commit is None:
+ return None
+
+ metadata = {
+ "commit": commit.attributes["id"],
+ "ref": ref,
+ "default_branch": repo.attributes["default_branch"],
+ "git_url": repo.attributes["ssh_url_to_repo"],
+ "commit_info": {
+ "url": os.path.join(
+ repo.attributes["web_url"], "commit", commit.attributes["id"]
+ ),
+ "message": commit.attributes["message"],
+ "date": commit.attributes["committed_date"],
+ },
}
- return BuildTriggerHandler.build_namespaces_response(namespaces)
+ committer = None
+ if "committer_email" in commit.attributes:
+ committer = self.lookup_user(commit.attributes["committer_email"])
- def _get_namespace(self, gl_client, gl_namespace, lazy=False):
- try:
- if gl_namespace.attributes['kind'] == 'group':
- return gl_client.groups.get(gl_namespace.attributes['id'], lazy=lazy)
+ author = None
+ if "author_email" in commit.attributes:
+ author = self.lookup_user(commit.attributes["author_email"])
- if gl_namespace.attributes['kind'] == 'user':
- return gl_client.users.get(gl_client.user.attributes['id'], lazy=lazy)
+ if committer is not None:
+ metadata["commit_info"]["committer"] = {
+ "username": committer["username"],
+ "avatar_url": committer["avatar_url"],
+ "url": committer.get("http_url", ""),
+ }
- # Note: This doesn't seem to work for IDs retrieved via the namespaces API; the IDs are
- # different.
- return gl_client.users.get(gl_namespace.attributes['id'], lazy=lazy)
- except gitlab.GitlabGetError:
- return None
+ if author is not None:
+ metadata["commit_info"]["author"] = {
+ "username": author["username"],
+ "avatar_url": author["avatar_url"],
+ "url": author.get("http_url", ""),
+ }
- @_catch_timeouts_and_errors
- def list_build_sources_for_namespace(self, namespace_id):
- if not namespace_id:
- return []
+ return metadata
- def repo_view(repo):
- # Because *anything* can be None in GitLab API!
- permissions = repo.attributes.get('permissions') or {}
- group_access = permissions.get('group_access') or {}
- project_access = permissions.get('project_access') or {}
+ @_catch_timeouts_and_errors
+ def manual_start(self, run_parameters=None):
+ gl_client = self._get_authorized_client()
+ gl_project = gl_client.projects.get(self.config["build_source"])
+ if not gl_project:
+ raise TriggerStartException("Could not find repository")
- missing_group_access = permissions.get('group_access') is None
- missing_project_access = permissions.get('project_access') is None
+ def get_tag_sha(tag_name):
+ try:
+ tag = gl_project.tags.get(tag_name)
+ except gitlab.GitlabGetError:
+ raise TriggerStartException("Could not find tag in repository")
- access_level = max(group_access.get('access_level') or 0,
- project_access.get('access_level') or 0)
+ return tag.attributes["commit"]["id"]
- has_admin_permission = _ACCESS_LEVEL_MAP.get(access_level, ("", False))[1]
- if missing_group_access or missing_project_access:
- # Default to has permission if we cannot check the permissions. This will allow our users
- # to select the repository and then GitLab's own checks will ensure that the webhook is
- # added only if allowed.
- # TODO: Do we want to display this differently in the UI?
- has_admin_permission = True
+ def get_branch_sha(branch_name):
+ try:
+ branch = gl_project.branches.get(branch_name)
+ except gitlab.GitlabGetError:
+ raise TriggerStartException("Could not find branch in repository")
- view = {
- 'name': repo.attributes['path'],
- 'full_name': repo.attributes['path_with_namespace'],
- 'description': repo.attributes.get('description') or '',
- 'url': repo.attributes.get('web_url'),
- 'has_admin_permissions': has_admin_permission,
- 'private': repo.attributes.get('visibility') == 'private',
- }
+ return branch.attributes["commit"]["id"]
- if repo.attributes.get('last_activity_at'):
- try:
- last_modified = dateutil.parser.parse(repo.attributes['last_activity_at'])
- view['last_updated'] = timegm(last_modified.utctimetuple())
- except ValueError:
- logger.exception('Gitlab gave us an invalid last_activity_at: %s', last_modified)
+ # Find the branch or tag to build.
+ (commit_sha, ref) = determine_build_ref(
+ run_parameters,
+ get_branch_sha,
+ get_tag_sha,
+ gl_project.attributes["default_branch"],
+ )
- return view
+ metadata = self.get_metadata_for_commit(commit_sha, ref, gl_project)
+ return self.prepare_build(metadata, is_manual=True)
- gl_client = self._get_authorized_client()
+ @_catch_timeouts_and_errors
+ def handle_trigger_request(self, request):
+ payload = request.get_json()
+ if not payload:
+ raise InvalidPayloadException()
- try:
- gl_namespace = gl_client.namespaces.get(namespace_id)
- except gitlab.GitlabGetError:
- return []
+ logger.debug("GitLab trigger payload %s", payload)
- namespace_obj = self._get_namespace(gl_client, gl_namespace, lazy=True)
- repositories = _paginated_iterator(namespace_obj.projects.list, RepositoryReadException)
+ # Lookup the default branch.
+ gl_client = self._get_authorized_client()
+ gl_project = gl_client.projects.get(self.config["build_source"])
+ if not gl_project:
+ logger.debug(
+ "Skipping GitLab build; project %s not found",
+ self.config["build_source"],
+ )
+ raise InvalidPayloadException()
- try:
- return BuildTriggerHandler.build_sources_response([repo_view(repo) for repo in repositories])
- except gitlab.GitlabGetError:
- return []
+ def lookup_commit(repo_id, commit_sha):
+ commit = self.lookup_commit(repo_id, commit_sha)
+ if commit is None:
+ return None
- @_catch_timeouts_and_errors
- def list_build_subdirs(self):
- config = self.config
- gl_client = self._get_authorized_client()
- new_build_source = config['build_source']
+ return dict(commit.attributes)
- gl_project = gl_client.projects.get(new_build_source)
- if not gl_project:
- msg = 'Unable to find GitLab repository for source: %s' % new_build_source
- raise RepositoryReadException(msg)
+ default_branch = gl_project.attributes["default_branch"]
+ metadata = get_transformed_webhook_payload(
+ payload,
+ default_branch=default_branch,
+ lookup_user=self.lookup_user,
+ lookup_commit=lookup_commit,
+ )
+ prepared = self.prepare_build(metadata)
- repo_branches = gl_project.branches.list()
- if not repo_branches:
- msg = 'Unable to find GitLab branches for source: %s' % new_build_source
- raise RepositoryReadException(msg)
-
- branches = [branch.attributes['name'] for branch in repo_branches]
- branches = find_matching_branches(config, branches)
- branches = branches or [gl_project.attributes['default_branch'] or 'master']
-
- repo_tree = gl_project.repository_tree(ref=branches[0])
- if not repo_tree:
- msg = 'Unable to find GitLab repository tree for source: %s' % new_build_source
- raise RepositoryReadException(msg)
-
- return [node['name'] for node in repo_tree if self.filename_is_dockerfile(node['name'])]
-
- @_catch_timeouts_and_errors
- def load_dockerfile_contents(self):
- gl_client = self._get_authorized_client()
- path = self.get_dockerfile_path()
-
- gl_project = gl_client.projects.get(self.config['build_source'])
- if not gl_project:
- return None
-
- branches = self.list_field_values('branch_name')
- branches = find_matching_branches(self.config, branches)
- if branches == []:
- return None
-
- branch_name = branches[0]
- if gl_project.attributes['default_branch'] in branches:
- branch_name = gl_project.attributes['default_branch']
-
- try:
- return gl_project.files.get(path, branch_name).decode()
- except gitlab.GitlabGetError:
- return None
-
- @_catch_timeouts_and_errors
- def list_field_values(self, field_name, limit=None):
- if field_name == 'refs':
- branches = self.list_field_values('branch_name')
- tags = self.list_field_values('tag_name')
-
- return ([{'kind': 'branch', 'name': b} for b in branches] +
- [{'kind': 'tag', 'name': t} for t in tags])
-
- gl_client = self._get_authorized_client()
- gl_project = gl_client.projects.get(self.config['build_source'])
- if not gl_project:
- return []
-
- if field_name == 'tag_name':
- tags = gl_project.tags.list()
- if not tags:
- return []
-
- if limit:
- tags = tags[0:limit]
-
- return [tag.attributes['name'] for tag in tags]
-
- if field_name == 'branch_name':
- branches = gl_project.branches.list()
- if not branches:
- return []
-
- if limit:
- branches = branches[0:limit]
-
- return [branch.attributes['name'] for branch in branches]
-
- return None
-
- def get_repository_url(self):
- return gitlab_trigger.get_public_url(self.config['build_source'])
-
- @_catch_timeouts_and_errors
- def lookup_commit(self, repo_id, commit_sha):
- if repo_id is None:
- return None
-
- gl_client = self._get_authorized_client()
- gl_project = gl_client.projects.get(self.config['build_source'], lazy=True)
- commit = gl_project.commits.get(commit_sha)
- if not commit:
- return None
-
- return commit
-
- @_catch_timeouts_and_errors
- def lookup_user(self, email):
- gl_client = self._get_authorized_client()
- try:
- result = gl_client.users.list(search=email)
- if not result:
- return None
-
- [user] = result
- return {
- 'username': user.attributes['username'],
- 'html_url': user.attributes['web_url'],
- 'avatar_url': user.attributes['avatar_url']
- }
- except ValueError:
- return None
-
- @_catch_timeouts_and_errors
- def get_metadata_for_commit(self, commit_sha, ref, repo):
- commit = self.lookup_commit(repo.get_id(), commit_sha)
- if commit is None:
- return None
-
- metadata = {
- 'commit': commit.attributes['id'],
- 'ref': ref,
- 'default_branch': repo.attributes['default_branch'],
- 'git_url': repo.attributes['ssh_url_to_repo'],
- 'commit_info': {
- 'url': os.path.join(repo.attributes['web_url'], 'commit', commit.attributes['id']),
- 'message': commit.attributes['message'],
- 'date': commit.attributes['committed_date'],
- },
- }
-
- committer = None
- if 'committer_email' in commit.attributes:
- committer = self.lookup_user(commit.attributes['committer_email'])
-
- author = None
- if 'author_email' in commit.attributes:
- author = self.lookup_user(commit.attributes['author_email'])
-
- if committer is not None:
- metadata['commit_info']['committer'] = {
- 'username': committer['username'],
- 'avatar_url': committer['avatar_url'],
- 'url': committer.get('http_url', ''),
- }
-
- if author is not None:
- metadata['commit_info']['author'] = {
- 'username': author['username'],
- 'avatar_url': author['avatar_url'],
- 'url': author.get('http_url', ''),
- }
-
- return metadata
-
- @_catch_timeouts_and_errors
- def manual_start(self, run_parameters=None):
- gl_client = self._get_authorized_client()
- gl_project = gl_client.projects.get(self.config['build_source'])
- if not gl_project:
- raise TriggerStartException('Could not find repository')
-
- def get_tag_sha(tag_name):
- try:
- tag = gl_project.tags.get(tag_name)
- except gitlab.GitlabGetError:
- raise TriggerStartException('Could not find tag in repository')
-
- return tag.attributes['commit']['id']
-
- def get_branch_sha(branch_name):
- try:
- branch = gl_project.branches.get(branch_name)
- except gitlab.GitlabGetError:
- raise TriggerStartException('Could not find branch in repository')
-
- return branch.attributes['commit']['id']
-
- # Find the branch or tag to build.
- (commit_sha, ref) = determine_build_ref(run_parameters, get_branch_sha, get_tag_sha,
- gl_project.attributes['default_branch'])
-
- metadata = self.get_metadata_for_commit(commit_sha, ref, gl_project)
- return self.prepare_build(metadata, is_manual=True)
-
- @_catch_timeouts_and_errors
- def handle_trigger_request(self, request):
- payload = request.get_json()
- if not payload:
- raise InvalidPayloadException()
-
- logger.debug('GitLab trigger payload %s', payload)
-
- # Lookup the default branch.
- gl_client = self._get_authorized_client()
- gl_project = gl_client.projects.get(self.config['build_source'])
- if not gl_project:
- logger.debug('Skipping GitLab build; project %s not found', self.config['build_source'])
- raise InvalidPayloadException()
-
- def lookup_commit(repo_id, commit_sha):
- commit = self.lookup_commit(repo_id, commit_sha)
- if commit is None:
- return None
-
- return dict(commit.attributes)
-
- default_branch = gl_project.attributes['default_branch']
- metadata = get_transformed_webhook_payload(payload, default_branch=default_branch,
- lookup_user=self.lookup_user,
- lookup_commit=lookup_commit)
- prepared = self.prepare_build(metadata)
-
- # Check if we should skip this build.
- raise_if_skipped_build(prepared, self.config)
- return prepared
+ # Check if we should skip this build.
+ raise_if_skipped_build(prepared, self.config)
+ return prepared
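
As an aside, the _paginated_iterator helper reformatted above captures the whole GitLab pagination contract and is small enough to demonstrate in isolation. The sketch below restates it next to a fake page-fetching function; the fake fetcher, the 45-item count, and the RuntimeError sentinel are illustrative and not part of the module.

_PER_PAGE_COUNT = 20  # mirrors the module-level constant above


def _paginated_iterator(func, exc, **kwargs):
    """ Yields items from successive pages until a short page signals the end. """
    page = 1
    while True:
        result = func(page=page, per_page=_PER_PAGE_COUNT, **kwargs)
        if result is None or result is False:
            raise exc

        counter = 0
        for item in result:
            yield item
            counter = counter + 1

        # A page shorter than the page size means there is nothing left to fetch.
        if counter < _PER_PAGE_COUNT:
            break

        page = page + 1


# Hypothetical fetcher over 45 items, returned 20 at a time like the GitLab API lists.
def fake_list(page, per_page):
    items = list(range(45))
    return items[(page - 1) * per_page : page * per_page]


assert len(list(_paginated_iterator(fake_list, RuntimeError))) == 45
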
diff --git a/buildtrigger/test/bitbucketmock.py b/buildtrigger/test/bitbucketmock.py
index 0e5cad97f..b6cf2b3b8 100644
--- a/buildtrigger/test/bitbucketmock.py
+++ b/buildtrigger/test/bitbucketmock.py
@@ -4,156 +4,168 @@ from mock import Mock
from buildtrigger.bitbuckethandler import BitbucketBuildTrigger
from util.morecollections import AttrDict
-def get_bitbucket_trigger(dockerfile_path=''):
- trigger_obj = AttrDict(dict(auth_token='foobar', id='sometrigger'))
- trigger = BitbucketBuildTrigger(trigger_obj, {
- 'build_source': 'foo/bar',
- 'dockerfile_path': dockerfile_path,
- 'nickname': 'knownuser',
- 'account_id': 'foo',
- })
- trigger._get_client = get_mock_bitbucket
- return trigger
+def get_bitbucket_trigger(dockerfile_path=""):
+ trigger_obj = AttrDict(dict(auth_token="foobar", id="sometrigger"))
+ trigger = BitbucketBuildTrigger(
+ trigger_obj,
+ {
+ "build_source": "foo/bar",
+ "dockerfile_path": dockerfile_path,
+ "nickname": "knownuser",
+ "account_id": "foo",
+ },
+ )
+
+ trigger._get_client = get_mock_bitbucket
+ return trigger
+
def get_repo_path_contents(path, revision):
- data = {
- 'files': [{'path': 'Dockerfile'}],
- }
+ data = {"files": [{"path": "Dockerfile"}]}
+
+ return (True, data, None)
- return (True, data, None)
def get_raw_path_contents(path, revision):
- if path == 'Dockerfile':
- return (True, 'hello world', None)
+ if path == "Dockerfile":
+ return (True, "hello world", None)
- if path == 'somesubdir/Dockerfile':
- return (True, 'hi universe', None)
+ if path == "somesubdir/Dockerfile":
+ return (True, "hi universe", None)
+
+ return (False, None, None)
- return (False, None, None)
def get_branches_and_tags():
- data = {
- 'branches': [{'name': 'master'}, {'name': 'otherbranch'}],
- 'tags': [{'name': 'sometag'}, {'name': 'someothertag'}],
- }
- return (True, data, None)
+ data = {
+ "branches": [{"name": "master"}, {"name": "otherbranch"}],
+ "tags": [{"name": "sometag"}, {"name": "someothertag"}],
+ }
+ return (True, data, None)
+
def get_branches():
- return (True, {'master': {}, 'otherbranch': {}}, None)
+ return (True, {"master": {}, "otherbranch": {}}, None)
+
def get_tags():
- return (True, {'sometag': {}, 'someothertag': {}}, None)
+ return (True, {"sometag": {}, "someothertag": {}}, None)
+
def get_branch(branch_name):
- if branch_name != 'master':
- return (False, None, None)
+ if branch_name != "master":
+ return (False, None, None)
- data = {
- 'target': {
- 'hash': 'aaaaaaa',
- },
- }
+ data = {"target": {"hash": "aaaaaaa"}}
+
+ return (True, data, None)
- return (True, data, None)
def get_tag(tag_name):
- if tag_name != 'sometag':
- return (False, None, None)
+ if tag_name != "sometag":
+ return (False, None, None)
- data = {
- 'target': {
- 'hash': 'aaaaaaa',
- },
- }
+ data = {"target": {"hash": "aaaaaaa"}}
+
+ return (True, data, None)
- return (True, data, None)
def get_changeset_mock(commit_sha):
- if commit_sha != 'aaaaaaa':
- return (False, None, 'Not found')
+ if commit_sha != "aaaaaaa":
+ return (False, None, "Not found")
- data = {
- 'node': 'aaaaaaa',
- 'message': 'some message',
- 'timestamp': 'now',
- 'raw_author': 'foo@bar.com',
- }
+ data = {
+ "node": "aaaaaaa",
+ "message": "some message",
+ "timestamp": "now",
+ "raw_author": "foo@bar.com",
+ }
+
+ return (True, data, None)
- return (True, data, None)
def get_changesets():
- changesets_mock = Mock()
- changesets_mock.get = Mock(side_effect=get_changeset_mock)
- return changesets_mock
+ changesets_mock = Mock()
+ changesets_mock.get = Mock(side_effect=get_changeset_mock)
+ return changesets_mock
+
def get_deploykeys():
- deploykeys_mock = Mock()
- deploykeys_mock.create = Mock(return_value=(True, {'pk': 'someprivatekey'}, None))
- deploykeys_mock.delete = Mock(return_value=(True, {}, None))
- return deploykeys_mock
+ deploykeys_mock = Mock()
+ deploykeys_mock.create = Mock(return_value=(True, {"pk": "someprivatekey"}, None))
+ deploykeys_mock.delete = Mock(return_value=(True, {}, None))
+ return deploykeys_mock
+
def get_webhooks():
- webhooks_mock = Mock()
- webhooks_mock.create = Mock(return_value=(True, {'uuid': 'someuuid'}, None))
- webhooks_mock.delete = Mock(return_value=(True, {}, None))
- return webhooks_mock
+ webhooks_mock = Mock()
+ webhooks_mock.create = Mock(return_value=(True, {"uuid": "someuuid"}, None))
+ webhooks_mock.delete = Mock(return_value=(True, {}, None))
+ return webhooks_mock
+
def get_repo_mock(name):
- if name != 'bar':
- return None
+ if name != "bar":
+ return None
- repo_mock = Mock()
- repo_mock.get_main_branch = Mock(return_value=(True, {'name': 'master'}, None))
- repo_mock.get_path_contents = Mock(side_effect=get_repo_path_contents)
- repo_mock.get_raw_path_contents = Mock(side_effect=get_raw_path_contents)
- repo_mock.get_branches_and_tags = Mock(side_effect=get_branches_and_tags)
- repo_mock.get_branches = Mock(side_effect=get_branches)
- repo_mock.get_tags = Mock(side_effect=get_tags)
- repo_mock.get_branch = Mock(side_effect=get_branch)
- repo_mock.get_tag = Mock(side_effect=get_tag)
+ repo_mock = Mock()
+ repo_mock.get_main_branch = Mock(return_value=(True, {"name": "master"}, None))
+ repo_mock.get_path_contents = Mock(side_effect=get_repo_path_contents)
+ repo_mock.get_raw_path_contents = Mock(side_effect=get_raw_path_contents)
+ repo_mock.get_branches_and_tags = Mock(side_effect=get_branches_and_tags)
+ repo_mock.get_branches = Mock(side_effect=get_branches)
+ repo_mock.get_tags = Mock(side_effect=get_tags)
+ repo_mock.get_branch = Mock(side_effect=get_branch)
+ repo_mock.get_tag = Mock(side_effect=get_tag)
+
+ repo_mock.changesets = Mock(side_effect=get_changesets)
+ repo_mock.deploykeys = Mock(side_effect=get_deploykeys)
+ repo_mock.webhooks = Mock(side_effect=get_webhooks)
+ return repo_mock
- repo_mock.changesets = Mock(side_effect=get_changesets)
- repo_mock.deploykeys = Mock(side_effect=get_deploykeys)
- repo_mock.webhooks = Mock(side_effect=get_webhooks)
- return repo_mock
def get_repositories_mock():
- repos_mock = Mock()
- repos_mock.get = Mock(side_effect=get_repo_mock)
- return repos_mock
+ repos_mock = Mock()
+ repos_mock.get = Mock(side_effect=get_repo_mock)
+ return repos_mock
+
def get_namespace_mock(namespace):
- namespace_mock = Mock()
- namespace_mock.repositories = Mock(side_effect=get_repositories_mock)
- return namespace_mock
+ namespace_mock = Mock()
+ namespace_mock.repositories = Mock(side_effect=get_repositories_mock)
+ return namespace_mock
+
def get_repo(namespace, name):
- return {
- 'owner': namespace,
- 'logo': 'avatarurl',
- 'slug': name,
- 'description': 'some %s repo' % (name),
- 'utc_last_updated': str(datetime.utcfromtimestamp(0)),
- 'read_only': namespace != 'knownuser',
- 'is_private': name == 'somerepo',
- }
+ return {
+ "owner": namespace,
+ "logo": "avatarurl",
+ "slug": name,
+ "description": "some %s repo" % (name),
+ "utc_last_updated": str(datetime.utcfromtimestamp(0)),
+ "read_only": namespace != "knownuser",
+ "is_private": name == "somerepo",
+ }
+
def get_visible_repos():
- repos = [
- get_repo('knownuser', 'somerepo'),
- get_repo('someorg', 'somerepo'),
- get_repo('someorg', 'anotherrepo'),
- ]
- return (True, repos, None)
+ repos = [
+ get_repo("knownuser", "somerepo"),
+ get_repo("someorg", "somerepo"),
+ get_repo("someorg", "anotherrepo"),
+ ]
+ return (True, repos, None)
+
def get_authed_mock(token, secret):
- authed_mock = Mock()
- authed_mock.for_namespace = Mock(side_effect=get_namespace_mock)
- authed_mock.get_visible_repositories = Mock(side_effect=get_visible_repos)
- return authed_mock
+ authed_mock = Mock()
+ authed_mock.for_namespace = Mock(side_effect=get_namespace_mock)
+ authed_mock.get_visible_repositories = Mock(side_effect=get_visible_repos)
+ return authed_mock
+
def get_mock_bitbucket():
- bitbucket_mock = Mock()
- bitbucket_mock.get_authorized_client = Mock(side_effect=get_authed_mock)
- return bitbucket_mock
+ bitbucket_mock = Mock()
+ bitbucket_mock.get_authorized_client = Mock(side_effect=get_authed_mock)
+ return bitbucket_mock
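A quick usage sketch, illustration only and not applied by this patch: the Bitbucket mock above is consumed by the handler tests later in this diff roughly as follows.

from buildtrigger.test.bitbucketmock import get_bitbucket_trigger

# The default trigger points at the repository root Dockerfile.
trigger = get_bitbucket_trigger()
assert trigger.list_build_subdirs() == ["/Dockerfile"]

# A trigger configured with a subdirectory path reads that file instead.
subdir_trigger = get_bitbucket_trigger("somesubdir/Dockerfile")
assert subdir_trigger.load_dockerfile_contents() == "hi universe"
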
diff --git a/buildtrigger/test/githubmock.py b/buildtrigger/test/githubmock.py
index e0f8daffc..c8fcbe73f 100644
--- a/buildtrigger/test/githubmock.py
+++ b/buildtrigger/test/githubmock.py
@@ -6,173 +6,178 @@ from github import GithubException
from buildtrigger.githubhandler import GithubBuildTrigger
from util.morecollections import AttrDict
-def get_github_trigger(dockerfile_path=''):
- trigger_obj = AttrDict(dict(auth_token='foobar', id='sometrigger'))
- trigger = GithubBuildTrigger(trigger_obj, {'build_source': 'foo', 'dockerfile_path': dockerfile_path})
- trigger._get_client = get_mock_github
- return trigger
+
+def get_github_trigger(dockerfile_path=""):
+ trigger_obj = AttrDict(dict(auth_token="foobar", id="sometrigger"))
+ trigger = GithubBuildTrigger(
+ trigger_obj, {"build_source": "foo", "dockerfile_path": dockerfile_path}
+ )
+ trigger._get_client = get_mock_github
+ return trigger
+
def get_mock_github():
- def get_commit_mock(commit_sha):
- if commit_sha == 'aaaaaaa':
- commit_mock = Mock()
- commit_mock.sha = commit_sha
- commit_mock.html_url = 'http://url/to/commit'
- commit_mock.last_modified = 'now'
+ def get_commit_mock(commit_sha):
+ if commit_sha == "aaaaaaa":
+ commit_mock = Mock()
+ commit_mock.sha = commit_sha
+ commit_mock.html_url = "http://url/to/commit"
+ commit_mock.last_modified = "now"
- commit_mock.commit = Mock()
- commit_mock.commit.message = 'some cool message'
+ commit_mock.commit = Mock()
+ commit_mock.commit.message = "some cool message"
- commit_mock.committer = Mock()
- commit_mock.committer.login = 'someuser'
- commit_mock.committer.avatar_url = 'avatarurl'
- commit_mock.committer.html_url = 'htmlurl'
+ commit_mock.committer = Mock()
+ commit_mock.committer.login = "someuser"
+ commit_mock.committer.avatar_url = "avatarurl"
+ commit_mock.committer.html_url = "htmlurl"
- commit_mock.author = Mock()
- commit_mock.author.login = 'someuser'
- commit_mock.author.avatar_url = 'avatarurl'
- commit_mock.author.html_url = 'htmlurl'
- return commit_mock
+ commit_mock.author = Mock()
+ commit_mock.author.login = "someuser"
+ commit_mock.author.avatar_url = "avatarurl"
+ commit_mock.author.html_url = "htmlurl"
+ return commit_mock
- raise GithubException(None, None)
+ raise GithubException(None, None)
- def get_branch_mock(branch_name):
- if branch_name == 'master':
- branch_mock = Mock()
- branch_mock.commit = Mock()
- branch_mock.commit.sha = 'aaaaaaa'
- return branch_mock
+ def get_branch_mock(branch_name):
+ if branch_name == "master":
+ branch_mock = Mock()
+ branch_mock.commit = Mock()
+ branch_mock.commit.sha = "aaaaaaa"
+ return branch_mock
- raise GithubException(None, None)
+ raise GithubException(None, None)
+
+ def get_repo_mock(namespace, name):
+ repo_mock = Mock()
+ repo_mock.owner = Mock()
+ repo_mock.owner.login = namespace
+
+ repo_mock.full_name = "%s/%s" % (namespace, name)
+ repo_mock.name = name
+ repo_mock.description = "some %s repo" % (name)
+
+ if name != "anotherrepo":
+ repo_mock.pushed_at = datetime.utcfromtimestamp(0)
+ else:
+ repo_mock.pushed_at = None
+
+ repo_mock.html_url = "https://bitbucket.org/%s/%s" % (namespace, name)
+ repo_mock.private = name == "somerepo"
+ repo_mock.permissions = Mock()
+ repo_mock.permissions.admin = namespace == "knownuser"
+ return repo_mock
+
+ def get_user_repos_mock(type="all", sort="created"):
+ return [get_repo_mock("knownuser", "somerepo")]
+
+ def get_org_repos_mock(type="all"):
+ return [
+ get_repo_mock("someorg", "somerepo"),
+ get_repo_mock("someorg", "anotherrepo"),
+ ]
+
+ def get_orgs_mock():
+ return [get_org_mock("someorg")]
+
+ def get_user_mock(username="knownuser"):
+ if username == "knownuser":
+ user_mock = Mock()
+ user_mock.name = username
+ user_mock.plan = Mock()
+ user_mock.plan.private_repos = 1
+ user_mock.login = username
+ user_mock.html_url = "https://bitbucket.org/%s" % (username)
+ user_mock.avatar_url = "avatarurl"
+ user_mock.get_repos = Mock(side_effect=get_user_repos_mock)
+ user_mock.get_orgs = Mock(side_effect=get_orgs_mock)
+ return user_mock
+
+ raise GithubException(None, None)
+
+ def get_org_mock(namespace):
+ if namespace == "someorg":
+ org_mock = Mock()
+ org_mock.get_repos = Mock(side_effect=get_org_repos_mock)
+ org_mock.login = namespace
+ org_mock.html_url = "https://bitbucket.org/%s" % (namespace)
+ org_mock.avatar_url = "avatarurl"
+ org_mock.name = namespace
+ org_mock.plan = Mock()
+ org_mock.plan.private_repos = 2
+ return org_mock
+
+ raise GithubException(None, None)
+
+ def get_tags_mock():
+ sometag = Mock()
+ sometag.name = "sometag"
+ sometag.commit = get_commit_mock("aaaaaaa")
+
+ someothertag = Mock()
+ someothertag.name = "someothertag"
+ someothertag.commit = get_commit_mock("aaaaaaa")
+ return [sometag, someothertag]
+
+ def get_branches_mock():
+ master = Mock()
+ master.name = "master"
+ master.commit = get_commit_mock("aaaaaaa")
+
+ otherbranch = Mock()
+ otherbranch.name = "otherbranch"
+ otherbranch.commit = get_commit_mock("aaaaaaa")
+ return [master, otherbranch]
+
+ def get_contents_mock(filepath):
+ if filepath == "Dockerfile":
+ m = Mock()
+ m.content = "hello world"
+ return m
+
+ if filepath == "somesubdir/Dockerfile":
+ m = Mock()
+ m.content = "hi universe"
+ return m
+
+ raise GithubException(None, None)
+
+ def get_git_tree_mock(commit_sha, recursive=False):
+ first_file = Mock()
+ first_file.type = "blob"
+ first_file.path = "Dockerfile"
+
+ second_file = Mock()
+ second_file.type = "other"
+ second_file.path = "/some/Dockerfile"
+
+ third_file = Mock()
+ third_file.type = "blob"
+ third_file.path = "somesubdir/Dockerfile"
+
+ t = Mock()
+
+ if commit_sha == "aaaaaaa":
+ t.tree = [first_file, second_file, third_file]
+ else:
+ t.tree = []
+
+ return t
- def get_repo_mock(namespace, name):
repo_mock = Mock()
- repo_mock.owner = Mock()
- repo_mock.owner.login = namespace
+ repo_mock.default_branch = "master"
+ repo_mock.ssh_url = "ssh_url"
- repo_mock.full_name = '%s/%s' % (namespace, name)
- repo_mock.name = name
- repo_mock.description = 'some %s repo' % (name)
+ repo_mock.get_branch = Mock(side_effect=get_branch_mock)
+ repo_mock.get_tags = Mock(side_effect=get_tags_mock)
+ repo_mock.get_branches = Mock(side_effect=get_branches_mock)
+ repo_mock.get_commit = Mock(side_effect=get_commit_mock)
+ repo_mock.get_contents = Mock(side_effect=get_contents_mock)
+ repo_mock.get_git_tree = Mock(side_effect=get_git_tree_mock)
- if name != 'anotherrepo':
- repo_mock.pushed_at = datetime.utcfromtimestamp(0)
- else:
- repo_mock.pushed_at = None
-
- repo_mock.html_url = 'https://bitbucket.org/%s/%s' % (namespace, name)
- repo_mock.private = name == 'somerepo'
- repo_mock.permissions = Mock()
- repo_mock.permissions.admin = namespace == 'knownuser'
- return repo_mock
-
- def get_user_repos_mock(type='all', sort='created'):
- return [get_repo_mock('knownuser', 'somerepo')]
-
- def get_org_repos_mock(type='all'):
- return [get_repo_mock('someorg', 'somerepo'), get_repo_mock('someorg', 'anotherrepo')]
-
- def get_orgs_mock():
- return [get_org_mock('someorg')]
-
- def get_user_mock(username='knownuser'):
- if username == 'knownuser':
- user_mock = Mock()
- user_mock.name = username
- user_mock.plan = Mock()
- user_mock.plan.private_repos = 1
- user_mock.login = username
- user_mock.html_url = 'https://bitbucket.org/%s' % (username)
- user_mock.avatar_url = 'avatarurl'
- user_mock.get_repos = Mock(side_effect=get_user_repos_mock)
- user_mock.get_orgs = Mock(side_effect=get_orgs_mock)
- return user_mock
-
- raise GithubException(None, None)
-
- def get_org_mock(namespace):
- if namespace == 'someorg':
- org_mock = Mock()
- org_mock.get_repos = Mock(side_effect=get_org_repos_mock)
- org_mock.login = namespace
- org_mock.html_url = 'https://bitbucket.org/%s' % (namespace)
- org_mock.avatar_url = 'avatarurl'
- org_mock.name = namespace
- org_mock.plan = Mock()
- org_mock.plan.private_repos = 2
- return org_mock
-
- raise GithubException(None, None)
-
- def get_tags_mock():
- sometag = Mock()
- sometag.name = 'sometag'
- sometag.commit = get_commit_mock('aaaaaaa')
-
- someothertag = Mock()
- someothertag.name = 'someothertag'
- someothertag.commit = get_commit_mock('aaaaaaa')
- return [sometag, someothertag]
-
- def get_branches_mock():
- master = Mock()
- master.name = 'master'
- master.commit = get_commit_mock('aaaaaaa')
-
- otherbranch = Mock()
- otherbranch.name = 'otherbranch'
- otherbranch.commit = get_commit_mock('aaaaaaa')
- return [master, otherbranch]
-
- def get_contents_mock(filepath):
- if filepath == 'Dockerfile':
- m = Mock()
- m.content = 'hello world'
- return m
-
- if filepath == 'somesubdir/Dockerfile':
- m = Mock()
- m.content = 'hi universe'
- return m
-
- raise GithubException(None, None)
-
- def get_git_tree_mock(commit_sha, recursive=False):
- first_file = Mock()
- first_file.type = 'blob'
- first_file.path = 'Dockerfile'
-
- second_file = Mock()
- second_file.type = 'other'
- second_file.path = '/some/Dockerfile'
-
- third_file = Mock()
- third_file.type = 'blob'
- third_file.path = 'somesubdir/Dockerfile'
-
- t = Mock()
-
- if commit_sha == 'aaaaaaa':
- t.tree = [
- first_file, second_file, third_file,
- ]
- else:
- t.tree = []
-
- return t
-
- repo_mock = Mock()
- repo_mock.default_branch = 'master'
- repo_mock.ssh_url = 'ssh_url'
-
- repo_mock.get_branch = Mock(side_effect=get_branch_mock)
- repo_mock.get_tags = Mock(side_effect=get_tags_mock)
- repo_mock.get_branches = Mock(side_effect=get_branches_mock)
- repo_mock.get_commit = Mock(side_effect=get_commit_mock)
- repo_mock.get_contents = Mock(side_effect=get_contents_mock)
- repo_mock.get_git_tree = Mock(side_effect=get_git_tree_mock)
-
- gh_mock = Mock()
- gh_mock.get_repo = Mock(return_value=repo_mock)
- gh_mock.get_user = Mock(side_effect=get_user_mock)
- gh_mock.get_organization = Mock(side_effect=get_org_mock)
- return gh_mock
+ gh_mock = Mock()
+ gh_mock.get_repo = Mock(return_value=repo_mock)
+ gh_mock.get_user = Mock(side_effect=get_user_mock)
+ gh_mock.get_organization = Mock(side_effect=get_org_mock)
+ return gh_mock
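The GitHub fixture above relies on the Mock(side_effect=...) pattern throughout: each attribute delegates to a plain function, so unknown inputs can raise (for example GithubException) exactly as the real client would. A minimal, self-contained sketch of that pattern follows; the names in it are invented for illustration.

from mock import Mock

def lookup(name):
    # Delegate: real return value for a known input, an exception otherwise.
    if name == "known":
        return {"name": name}
    raise KeyError(name)

client = Mock()
client.get = Mock(side_effect=lookup)

assert client.get("known") == {"name": "known"}  # side_effect's return value is surfaced
client.get.assert_called_once_with("known")
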
diff --git a/buildtrigger/test/gitlabmock.py b/buildtrigger/test/gitlabmock.py
index cd864241e..90c983a40 100644
--- a/buildtrigger/test/gitlabmock.py
+++ b/buildtrigger/test/gitlabmock.py
@@ -11,588 +11,598 @@ from buildtrigger.gitlabhandler import GitLabBuildTrigger
from util.morecollections import AttrDict
-@urlmatch(netloc=r'fakegitlab')
+@urlmatch(netloc=r"fakegitlab")
def catchall_handler(url, request):
- return {'status_code': 404}
+ return {"status_code": 404}
-@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/users$')
+@urlmatch(netloc=r"fakegitlab", path=r"/api/v4/users$")
def users_handler(url, request):
- if not request.headers.get('Authorization') == 'Bearer foobar':
- return {'status_code': 401}
+ if not request.headers.get("Authorization") == "Bearer foobar":
+ return {"status_code": 401}
+
+ if url.query.find("knownuser") < 0:
+ return {
+ "status_code": 200,
+ "headers": {"Content-Type": "application/json"},
+ "content": json.dumps([]),
+ }
- if url.query.find('knownuser') < 0:
return {
- 'status_code': 200,
- 'headers': {
- 'Content-Type': 'application/json',
- },
- 'content': json.dumps([]),
+ "status_code": 200,
+ "headers": {"Content-Type": "application/json"},
+ "content": json.dumps(
+ [
+ {
+ "id": 1,
+ "username": "knownuser",
+ "name": "Known User",
+ "state": "active",
+ "avatar_url": "avatarurl",
+ "web_url": "https://bitbucket.org/knownuser",
+ }
+ ]
+ ),
}
- return {
- 'status_code': 200,
- 'headers': {
- 'Content-Type': 'application/json',
- },
- 'content': json.dumps([
- {
- "id": 1,
- "username": "knownuser",
- "name": "Known User",
- "state": "active",
- "avatar_url": "avatarurl",
- "web_url": "https://bitbucket.org/knownuser",
- },
- ]),
- }
-
-@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/user$')
+@urlmatch(netloc=r"fakegitlab", path=r"/api/v4/user$")
def user_handler(_, request):
- if not request.headers.get('Authorization') == 'Bearer foobar':
- return {'status_code': 401}
+ if not request.headers.get("Authorization") == "Bearer foobar":
+ return {"status_code": 401}
- return {
- 'status_code': 200,
- 'headers': {
- 'Content-Type': 'application/json',
- },
- 'content': json.dumps({
- "id": 1,
- "username": "john_smith",
- "email": "john@example.com",
- "name": "John Smith",
- "state": "active",
- }),
- }
+ return {
+ "status_code": 200,
+ "headers": {"Content-Type": "application/json"},
+ "content": json.dumps(
+ {
+ "id": 1,
+ "username": "john_smith",
+ "email": "john@example.com",
+ "name": "John Smith",
+ "state": "active",
+ }
+ ),
+ }
-@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/projects/foo%2Fbar$')
+@urlmatch(netloc=r"fakegitlab", path=r"/api/v4/projects/foo%2Fbar$")
def project_handler(_, request):
- if not request.headers.get('Authorization') == 'Bearer foobar':
- return {'status_code': 401}
+ if not request.headers.get("Authorization") == "Bearer foobar":
+ return {"status_code": 401}
- return {
- 'status_code': 200,
- 'headers': {
- 'Content-Type': 'application/json',
- },
- 'content': json.dumps({
- "id": 4,
- "description": None,
- "default_branch": "master",
- "visibility": "private",
- "path_with_namespace": "someorg/somerepo",
- "ssh_url_to_repo": "git@example.com:someorg/somerepo.git",
- "web_url": "http://example.com/someorg/somerepo",
- }),
- }
+ return {
+ "status_code": 200,
+ "headers": {"Content-Type": "application/json"},
+ "content": json.dumps(
+ {
+ "id": 4,
+ "description": None,
+ "default_branch": "master",
+ "visibility": "private",
+ "path_with_namespace": "someorg/somerepo",
+ "ssh_url_to_repo": "git@example.com:someorg/somerepo.git",
+ "web_url": "http://example.com/someorg/somerepo",
+ }
+ ),
+ }
-@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/projects/4/repository/tree$')
+@urlmatch(netloc=r"fakegitlab", path=r"/api/v4/projects/4/repository/tree$")
def project_tree_handler(_, request):
- if not request.headers.get('Authorization') == 'Bearer foobar':
- return {'status_code': 401}
+ if not request.headers.get("Authorization") == "Bearer foobar":
+ return {"status_code": 401}
- return {
- 'status_code': 200,
- 'headers': {
- 'Content-Type': 'application/json',
- },
- 'content': json.dumps([
- {
- "id": "a1e8f8d745cc87e3a9248358d9352bb7f9a0aeba",
- "name": "Dockerfile",
- "type": "tree",
- "path": "files/Dockerfile",
- "mode": "040000",
- },
- ]),
- }
+ return {
+ "status_code": 200,
+ "headers": {"Content-Type": "application/json"},
+ "content": json.dumps(
+ [
+ {
+ "id": "a1e8f8d745cc87e3a9248358d9352bb7f9a0aeba",
+ "name": "Dockerfile",
+ "type": "tree",
+ "path": "files/Dockerfile",
+ "mode": "040000",
+ }
+ ]
+ ),
+ }
-@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/projects/4/repository/tags$')
+@urlmatch(netloc=r"fakegitlab", path=r"/api/v4/projects/4/repository/tags$")
def project_tags_handler(_, request):
- if not request.headers.get('Authorization') == 'Bearer foobar':
- return {'status_code': 401}
+ if not request.headers.get("Authorization") == "Bearer foobar":
+ return {"status_code": 401}
- return {
- 'status_code': 200,
- 'headers': {
- 'Content-Type': 'application/json',
- },
- 'content': json.dumps([
- {
- 'name': 'sometag',
- 'commit': {
- 'id': '60a8ff033665e1207714d6670fcd7b65304ec02f',
- },
- },
- {
- 'name': 'someothertag',
- 'commit': {
- 'id': '60a8ff033665e1207714d6670fcd7b65304ec02f',
- },
- },
- ]),
- }
+ return {
+ "status_code": 200,
+ "headers": {"Content-Type": "application/json"},
+ "content": json.dumps(
+ [
+ {
+ "name": "sometag",
+ "commit": {"id": "60a8ff033665e1207714d6670fcd7b65304ec02f"},
+ },
+ {
+ "name": "someothertag",
+ "commit": {"id": "60a8ff033665e1207714d6670fcd7b65304ec02f"},
+ },
+ ]
+ ),
+ }
-@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/projects/4/repository/branches$')
+@urlmatch(netloc=r"fakegitlab", path=r"/api/v4/projects/4/repository/branches$")
def project_branches_handler(_, request):
- if not request.headers.get('Authorization') == 'Bearer foobar':
- return {'status_code': 401}
+ if not request.headers.get("Authorization") == "Bearer foobar":
+ return {"status_code": 401}
- return {
- 'status_code': 200,
- 'headers': {
- 'Content-Type': 'application/json',
- },
- 'content': json.dumps([
- {
- 'name': 'master',
- 'commit': {
- 'id': '60a8ff033665e1207714d6670fcd7b65304ec02f',
- },
- },
- {
- 'name': 'otherbranch',
- 'commit': {
- 'id': '60a8ff033665e1207714d6670fcd7b65304ec02f',
- },
- },
- ]),
- }
+ return {
+ "status_code": 200,
+ "headers": {"Content-Type": "application/json"},
+ "content": json.dumps(
+ [
+ {
+ "name": "master",
+ "commit": {"id": "60a8ff033665e1207714d6670fcd7b65304ec02f"},
+ },
+ {
+ "name": "otherbranch",
+ "commit": {"id": "60a8ff033665e1207714d6670fcd7b65304ec02f"},
+ },
+ ]
+ ),
+ }
-@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/projects/4/repository/branches/master$')
+@urlmatch(netloc=r"fakegitlab", path=r"/api/v4/projects/4/repository/branches/master$")
def project_branch_handler(_, request):
- if not request.headers.get('Authorization') == 'Bearer foobar':
- return {'status_code': 401}
+ if not request.headers.get("Authorization") == "Bearer foobar":
+ return {"status_code": 401}
- return {
- 'status_code': 200,
- 'headers': {
- 'Content-Type': 'application/json',
- },
- 'content': json.dumps({
- "name": "master",
- "merged": True,
- "protected": True,
- "developers_can_push": False,
- "developers_can_merge": False,
- "commit": {
- "author_email": "john@example.com",
- "author_name": "John Smith",
- "authored_date": "2012-06-27T05:51:39-07:00",
- "committed_date": "2012-06-28T03:44:20-07:00",
- "committer_email": "john@example.com",
- "committer_name": "John Smith",
- "id": "60a8ff033665e1207714d6670fcd7b65304ec02f",
- "short_id": "7b5c3cc",
- "title": "add projects API",
- "message": "add projects API",
- "parent_ids": [
- "4ad91d3c1144c406e50c7b33bae684bd6837faf8",
- ],
- },
- }),
- }
+ return {
+ "status_code": 200,
+ "headers": {"Content-Type": "application/json"},
+ "content": json.dumps(
+ {
+ "name": "master",
+ "merged": True,
+ "protected": True,
+ "developers_can_push": False,
+ "developers_can_merge": False,
+ "commit": {
+ "author_email": "john@example.com",
+ "author_name": "John Smith",
+ "authored_date": "2012-06-27T05:51:39-07:00",
+ "committed_date": "2012-06-28T03:44:20-07:00",
+ "committer_email": "john@example.com",
+ "committer_name": "John Smith",
+ "id": "60a8ff033665e1207714d6670fcd7b65304ec02f",
+ "short_id": "7b5c3cc",
+ "title": "add projects API",
+ "message": "add projects API",
+ "parent_ids": ["4ad91d3c1144c406e50c7b33bae684bd6837faf8"],
+ },
+ }
+ ),
+ }
-@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/namespaces/someorg$')
+@urlmatch(netloc=r"fakegitlab", path=r"/api/v4/namespaces/someorg$")
def namespace_handler(_, request):
- if not request.headers.get('Authorization') == 'Bearer foobar':
- return {'status_code': 401}
+ if not request.headers.get("Authorization") == "Bearer foobar":
+ return {"status_code": 401}
- return {
- 'status_code': 200,
- 'headers': {
- 'Content-Type': 'application/json',
- },
- 'content': json.dumps({
- "id": 2,
- "name": "someorg",
- "path": "someorg",
- "kind": "group",
- "full_path": "someorg",
- "parent_id": None,
- "members_count_with_descendants": 2
- }),
- }
+ return {
+ "status_code": 200,
+ "headers": {"Content-Type": "application/json"},
+ "content": json.dumps(
+ {
+ "id": 2,
+ "name": "someorg",
+ "path": "someorg",
+ "kind": "group",
+ "full_path": "someorg",
+ "parent_id": None,
+ "members_count_with_descendants": 2,
+ }
+ ),
+ }
-@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/namespaces/knownuser$')
+@urlmatch(netloc=r"fakegitlab", path=r"/api/v4/namespaces/knownuser$")
def user_namespace_handler(_, request):
- if not request.headers.get('Authorization') == 'Bearer foobar':
- return {'status_code': 401}
+ if not request.headers.get("Authorization") == "Bearer foobar":
+ return {"status_code": 401}
- return {
- 'status_code': 200,
- 'headers': {
- 'Content-Type': 'application/json',
- },
- 'content': json.dumps({
- "id": 1,
- "name": "knownuser",
- "path": "knownuser",
- "kind": "user",
- "full_path": "knownuser",
- "parent_id": None,
- "members_count_with_descendants": 2
- }),
- }
+ return {
+ "status_code": 200,
+ "headers": {"Content-Type": "application/json"},
+ "content": json.dumps(
+ {
+ "id": 1,
+ "name": "knownuser",
+ "path": "knownuser",
+ "kind": "user",
+ "full_path": "knownuser",
+ "parent_id": None,
+ "members_count_with_descendants": 2,
+ }
+ ),
+ }
-@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/namespaces(/)?$')
+@urlmatch(netloc=r"fakegitlab", path=r"/api/v4/namespaces(/)?$")
def namespaces_handler(_, request):
- if not request.headers.get('Authorization') == 'Bearer foobar':
- return {'status_code': 401}
+ if not request.headers.get("Authorization") == "Bearer foobar":
+ return {"status_code": 401}
- return {
- 'status_code': 200,
- 'headers': {
- 'Content-Type': 'application/json',
- },
- 'content': json.dumps([{
- "id": 2,
- "name": "someorg",
- "path": "someorg",
- "kind": "group",
- "full_path": "someorg",
- "parent_id": None,
- "web_url": "http://gitlab.com/groups/someorg",
- "members_count_with_descendants": 2
- }]),
- }
+ return {
+ "status_code": 200,
+ "headers": {"Content-Type": "application/json"},
+ "content": json.dumps(
+ [
+ {
+ "id": 2,
+ "name": "someorg",
+ "path": "someorg",
+ "kind": "group",
+ "full_path": "someorg",
+ "parent_id": None,
+ "web_url": "http://gitlab.com/groups/someorg",
+ "members_count_with_descendants": 2,
+ }
+ ]
+ ),
+ }
def get_projects_handler(add_permissions_block):
- @urlmatch(netloc=r'fakegitlab', path=r'/api/v4/groups/2/projects$')
- def projects_handler(_, request):
- if not request.headers.get('Authorization') == 'Bearer foobar':
- return {'status_code': 401}
+ @urlmatch(netloc=r"fakegitlab", path=r"/api/v4/groups/2/projects$")
+ def projects_handler(_, request):
+ if not request.headers.get("Authorization") == "Bearer foobar":
+ return {"status_code": 401}
- permissions_block = {
- "project_access": {
- "access_level": 10,
- "notification_level": 3
- },
- "group_access": {
- "access_level": 20,
- "notification_level": 3
- },
- }
+ permissions_block = {
+ "project_access": {"access_level": 10, "notification_level": 3},
+ "group_access": {"access_level": 20, "notification_level": 3},
+ }
- return {
- 'status_code': 200,
- 'headers': {
- 'Content-Type': 'application/json',
- },
- 'content': json.dumps([{
- "id": 4,
- "name": "Some project",
- "description": None,
- "default_branch": "master",
- "visibility": "private",
- "path": "someproject",
- "path_with_namespace": "someorg/someproject",
- "last_activity_at": "2013-09-30T13:46:02Z",
- "web_url": "http://example.com/someorg/someproject",
- "permissions": permissions_block if add_permissions_block else None,
- },
- {
- "id": 5,
- "name": "Another project",
- "description": None,
- "default_branch": "master",
- "visibility": "public",
- "path": "anotherproject",
- "path_with_namespace": "someorg/anotherproject",
- "last_activity_at": "2013-09-30T13:46:02Z",
- "web_url": "http://example.com/someorg/anotherproject",
- }]),
- }
- return projects_handler
+ return {
+ "status_code": 200,
+ "headers": {"Content-Type": "application/json"},
+ "content": json.dumps(
+ [
+ {
+ "id": 4,
+ "name": "Some project",
+ "description": None,
+ "default_branch": "master",
+ "visibility": "private",
+ "path": "someproject",
+ "path_with_namespace": "someorg/someproject",
+ "last_activity_at": "2013-09-30T13:46:02Z",
+ "web_url": "http://example.com/someorg/someproject",
+ "permissions": permissions_block
+ if add_permissions_block
+ else None,
+ },
+ {
+ "id": 5,
+ "name": "Another project",
+ "description": None,
+ "default_branch": "master",
+ "visibility": "public",
+ "path": "anotherproject",
+ "path_with_namespace": "someorg/anotherproject",
+ "last_activity_at": "2013-09-30T13:46:02Z",
+ "web_url": "http://example.com/someorg/anotherproject",
+ },
+ ]
+ ),
+ }
+
+ return projects_handler
def get_group_handler(null_avatar):
- @urlmatch(netloc=r'fakegitlab', path=r'/api/v4/groups/2$')
- def group_handler(_, request):
- if not request.headers.get('Authorization') == 'Bearer foobar':
- return {'status_code': 401}
+ @urlmatch(netloc=r"fakegitlab", path=r"/api/v4/groups/2$")
+ def group_handler(_, request):
+ if not request.headers.get("Authorization") == "Bearer foobar":
+ return {"status_code": 401}
+
+ return {
+ "status_code": 200,
+ "headers": {"Content-Type": "application/json"},
+ "content": json.dumps(
+ {
+ "id": 1,
+ "name": "SomeOrg Group",
+ "path": "someorg",
+ "description": "An interesting group",
+ "visibility": "public",
+ "lfs_enabled": True,
+ "avatar_url": "avatar_url" if not null_avatar else None,
+ "web_url": "http://gitlab.com/groups/someorg",
+ "request_access_enabled": False,
+ "full_name": "SomeOrg Group",
+ "full_path": "someorg",
+ "parent_id": None,
+ }
+ ),
+ }
+
+ return group_handler
+
+
+@urlmatch(netloc=r"fakegitlab", path=r"/api/v4/projects/4/repository/files/Dockerfile$")
+def dockerfile_handler(_, request):
+ if not request.headers.get("Authorization") == "Bearer foobar":
+ return {"status_code": 401}
return {
- 'status_code': 200,
- 'headers': {
- 'Content-Type': 'application/json',
- },
- 'content': json.dumps({
- "id": 1,
- "name": "SomeOrg Group",
- "path": "someorg",
- "description": "An interesting group",
- "visibility": "public",
- "lfs_enabled": True,
- "avatar_url": 'avatar_url' if not null_avatar else None,
- "web_url": "http://gitlab.com/groups/someorg",
- "request_access_enabled": False,
- "full_name": "SomeOrg Group",
- "full_path": "someorg",
- "parent_id": None,
- }),
+ "status_code": 200,
+ "headers": {"Content-Type": "application/json"},
+ "content": json.dumps(
+ {
+ "file_name": "Dockerfile",
+ "file_path": "Dockerfile",
+ "size": 10,
+ "encoding": "base64",
+ "content": base64.b64encode("hello world"),
+ "ref": "master",
+ "blob_id": "79f7bbd25901e8334750839545a9bd021f0e4c83",
+ "commit_id": "d5a3ff139356ce33e37e73add446f16869741b50",
+ "last_commit_id": "570e7b2abdd848b95f2f578043fc23bd6f6fd24d",
+ }
+ ),
}
- return group_handler
-@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/projects/4/repository/files/Dockerfile$')
-def dockerfile_handler(_, request):
- if not request.headers.get('Authorization') == 'Bearer foobar':
- return {'status_code': 401}
-
- return {
- 'status_code': 200,
- 'headers': {
- 'Content-Type': 'application/json',
- },
- 'content': json.dumps({
- "file_name": "Dockerfile",
- "file_path": "Dockerfile",
- "size": 10,
- "encoding": "base64",
- "content": base64.b64encode('hello world'),
- "ref": "master",
- "blob_id": "79f7bbd25901e8334750839545a9bd021f0e4c83",
- "commit_id": "d5a3ff139356ce33e37e73add446f16869741b50",
- "last_commit_id": "570e7b2abdd848b95f2f578043fc23bd6f6fd24d"
- }),
- }
-
-
-@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/projects/4/repository/files/somesubdir%2FDockerfile$')
+@urlmatch(
+ netloc=r"fakegitlab",
+ path=r"/api/v4/projects/4/repository/files/somesubdir%2FDockerfile$",
+)
def sub_dockerfile_handler(_, request):
- if not request.headers.get('Authorization') == 'Bearer foobar':
- return {'status_code': 401}
+ if not request.headers.get("Authorization") == "Bearer foobar":
+ return {"status_code": 401}
- return {
- 'status_code': 200,
- 'headers': {
- 'Content-Type': 'application/json',
- },
- 'content': json.dumps({
- "file_name": "Dockerfile",
- "file_path": "somesubdir/Dockerfile",
- "size": 10,
- "encoding": "base64",
- "content": base64.b64encode('hi universe'),
- "ref": "master",
- "blob_id": "79f7bbd25901e8334750839545a9bd021f0e4c83",
- "commit_id": "d5a3ff139356ce33e37e73add446f16869741b50",
- "last_commit_id": "570e7b2abdd848b95f2f578043fc23bd6f6fd24d"
- }),
- }
+ return {
+ "status_code": 200,
+ "headers": {"Content-Type": "application/json"},
+ "content": json.dumps(
+ {
+ "file_name": "Dockerfile",
+ "file_path": "somesubdir/Dockerfile",
+ "size": 10,
+ "encoding": "base64",
+ "content": base64.b64encode("hi universe"),
+ "ref": "master",
+ "blob_id": "79f7bbd25901e8334750839545a9bd021f0e4c83",
+ "commit_id": "d5a3ff139356ce33e37e73add446f16869741b50",
+ "last_commit_id": "570e7b2abdd848b95f2f578043fc23bd6f6fd24d",
+ }
+ ),
+ }
-@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/projects/4/repository/tags/sometag$')
+@urlmatch(netloc=r"fakegitlab", path=r"/api/v4/projects/4/repository/tags/sometag$")
def tag_handler(_, request):
- if not request.headers.get('Authorization') == 'Bearer foobar':
- return {'status_code': 401}
+ if not request.headers.get("Authorization") == "Bearer foobar":
+ return {"status_code": 401}
- return {
- 'status_code': 200,
- 'headers': {
- 'Content-Type': 'application/json',
- },
- 'content': json.dumps({
- "name": "sometag",
- "message": "some cool message",
- "target": "60a8ff033665e1207714d6670fcd7b65304ec02f",
- "commit": {
- "id": "60a8ff033665e1207714d6670fcd7b65304ec02f",
- "short_id": "60a8ff03",
- "title": "Initial commit",
- "created_at": "2017-07-26T11:08:53.000+02:00",
- "parent_ids": [
- "f61c062ff8bcbdb00e0a1b3317a91aed6ceee06b"
- ],
- "message": "v5.0.0\n",
- "author_name": "Arthur Verschaeve",
- "author_email": "contact@arthurverschaeve.be",
- "authored_date": "2015-02-01T21:56:31.000+01:00",
- "committer_name": "Arthur Verschaeve",
- "committer_email": "contact@arthurverschaeve.be",
- "committed_date": "2015-02-01T21:56:31.000+01:00"
- },
- "release": None,
- }),
- }
+ return {
+ "status_code": 200,
+ "headers": {"Content-Type": "application/json"},
+ "content": json.dumps(
+ {
+ "name": "sometag",
+ "message": "some cool message",
+ "target": "60a8ff033665e1207714d6670fcd7b65304ec02f",
+ "commit": {
+ "id": "60a8ff033665e1207714d6670fcd7b65304ec02f",
+ "short_id": "60a8ff03",
+ "title": "Initial commit",
+ "created_at": "2017-07-26T11:08:53.000+02:00",
+ "parent_ids": ["f61c062ff8bcbdb00e0a1b3317a91aed6ceee06b"],
+ "message": "v5.0.0\n",
+ "author_name": "Arthur Verschaeve",
+ "author_email": "contact@arthurverschaeve.be",
+ "authored_date": "2015-02-01T21:56:31.000+01:00",
+ "committer_name": "Arthur Verschaeve",
+ "committer_email": "contact@arthurverschaeve.be",
+ "committed_date": "2015-02-01T21:56:31.000+01:00",
+ },
+ "release": None,
+ }
+ ),
+ }
-@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/projects/foo%2Fbar/repository/commits/60a8ff033665e1207714d6670fcd7b65304ec02f$')
+@urlmatch(
+ netloc=r"fakegitlab",
+ path=r"/api/v4/projects/foo%2Fbar/repository/commits/60a8ff033665e1207714d6670fcd7b65304ec02f$",
+)
def commit_handler(_, request):
- if not request.headers.get('Authorization') == 'Bearer foobar':
- return {'status_code': 401}
+ if not request.headers.get("Authorization") == "Bearer foobar":
+ return {"status_code": 401}
- return {
- 'status_code': 200,
- 'headers': {
- 'Content-Type': 'application/json',
- },
- 'content': json.dumps({
- "id": "60a8ff033665e1207714d6670fcd7b65304ec02f",
- "short_id": "60a8ff03366",
- "title": "Sanitize for network graph",
- "author_name": "someguy",
- "author_email": "some.guy@gmail.com",
- "committer_name": "Some Guy",
- "committer_email": "some.guy@gmail.com",
- "created_at": "2012-09-20T09:06:12+03:00",
- "message": "Sanitize for network graph",
- "committed_date": "2012-09-20T09:06:12+03:00",
- "authored_date": "2012-09-20T09:06:12+03:00",
- "parent_ids": [
- "ae1d9fb46aa2b07ee9836d49862ec4e2c46fbbba"
- ],
- "last_pipeline" : {
- "id": 8,
- "ref": "master",
- "sha": "2dc6aa325a317eda67812f05600bdf0fcdc70ab0",
- "status": "created",
- },
- "stats": {
- "additions": 15,
- "deletions": 10,
- "total": 25
- },
- "status": "running"
- }),
- }
+ return {
+ "status_code": 200,
+ "headers": {"Content-Type": "application/json"},
+ "content": json.dumps(
+ {
+ "id": "60a8ff033665e1207714d6670fcd7b65304ec02f",
+ "short_id": "60a8ff03366",
+ "title": "Sanitize for network graph",
+ "author_name": "someguy",
+ "author_email": "some.guy@gmail.com",
+ "committer_name": "Some Guy",
+ "committer_email": "some.guy@gmail.com",
+ "created_at": "2012-09-20T09:06:12+03:00",
+ "message": "Sanitize for network graph",
+ "committed_date": "2012-09-20T09:06:12+03:00",
+ "authored_date": "2012-09-20T09:06:12+03:00",
+ "parent_ids": ["ae1d9fb46aa2b07ee9836d49862ec4e2c46fbbba"],
+ "last_pipeline": {
+ "id": 8,
+ "ref": "master",
+ "sha": "2dc6aa325a317eda67812f05600bdf0fcdc70ab0",
+ "status": "created",
+ },
+ "stats": {"additions": 15, "deletions": 10, "total": 25},
+ "status": "running",
+ }
+ ),
+ }
-@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/projects/4/deploy_keys$', method='POST')
+@urlmatch(netloc=r"fakegitlab", path=r"/api/v4/projects/4/deploy_keys$", method="POST")
def create_deploykey_handler(_, request):
- if not request.headers.get('Authorization') == 'Bearer foobar':
- return {'status_code': 401}
+ if not request.headers.get("Authorization") == "Bearer foobar":
+ return {"status_code": 401}
- return {
- 'status_code': 200,
- 'headers': {
- 'Content-Type': 'application/json',
- },
- 'content': json.dumps({
- "id": 1,
- "title": "Public key",
- "key": "ssh-rsa some stuff",
- "created_at": "2013-10-02T10:12:29Z",
- "can_push": False,
- }),
- }
+ return {
+ "status_code": 200,
+ "headers": {"Content-Type": "application/json"},
+ "content": json.dumps(
+ {
+ "id": 1,
+ "title": "Public key",
+ "key": "ssh-rsa some stuff",
+ "created_at": "2013-10-02T10:12:29Z",
+ "can_push": False,
+ }
+ ),
+ }
-@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/projects/4/hooks$', method='POST')
+@urlmatch(netloc=r"fakegitlab", path=r"/api/v4/projects/4/hooks$", method="POST")
def create_hook_handler(_, request):
- if not request.headers.get('Authorization') == 'Bearer foobar':
- return {'status_code': 401}
+ if not request.headers.get("Authorization") == "Bearer foobar":
+ return {"status_code": 401}
- return {
- 'status_code': 200,
- 'headers': {
- 'Content-Type': 'application/json',
- },
- 'content': json.dumps({
- "id": 1,
- "url": "http://example.com/hook",
- "project_id": 4,
- "push_events": True,
- "issues_events": True,
- "confidential_issues_events": True,
- "merge_requests_events": True,
- "tag_push_events": True,
- "note_events": True,
- "job_events": True,
- "pipeline_events": True,
- "wiki_page_events": True,
- "enable_ssl_verification": True,
- "created_at": "2012-10-12T17:04:47Z",
- }),
- }
+ return {
+ "status_code": 200,
+ "headers": {"Content-Type": "application/json"},
+ "content": json.dumps(
+ {
+ "id": 1,
+ "url": "http://example.com/hook",
+ "project_id": 4,
+ "push_events": True,
+ "issues_events": True,
+ "confidential_issues_events": True,
+ "merge_requests_events": True,
+ "tag_push_events": True,
+ "note_events": True,
+ "job_events": True,
+ "pipeline_events": True,
+ "wiki_page_events": True,
+ "enable_ssl_verification": True,
+ "created_at": "2012-10-12T17:04:47Z",
+ }
+ ),
+ }
-@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/projects/4/hooks/1$', method='DELETE')
+@urlmatch(netloc=r"fakegitlab", path=r"/api/v4/projects/4/hooks/1$", method="DELETE")
def delete_hook_handler(_, request):
- if not request.headers.get('Authorization') == 'Bearer foobar':
- return {'status_code': 401}
+ if not request.headers.get("Authorization") == "Bearer foobar":
+ return {"status_code": 401}
- return {
- 'status_code': 200,
- 'headers': {
- 'Content-Type': 'application/json',
- },
- 'content': json.dumps({}),
- }
+ return {
+ "status_code": 200,
+ "headers": {"Content-Type": "application/json"},
+ "content": json.dumps({}),
+ }
-@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/projects/4/deploy_keys/1$', method='DELETE')
+@urlmatch(
+ netloc=r"fakegitlab", path=r"/api/v4/projects/4/deploy_keys/1$", method="DELETE"
+)
def delete_deploykey_handker(_, request):
- if not request.headers.get('Authorization') == 'Bearer foobar':
- return {'status_code': 401}
+ if not request.headers.get("Authorization") == "Bearer foobar":
+ return {"status_code": 401}
- return {
- 'status_code': 200,
- 'headers': {
- 'Content-Type': 'application/json',
- },
- 'content': json.dumps({}),
- }
+ return {
+ "status_code": 200,
+ "headers": {"Content-Type": "application/json"},
+ "content": json.dumps({}),
+ }
-@urlmatch(netloc=r'fakegitlab', path=r'/api/v4/users/1/projects$')
+@urlmatch(netloc=r"fakegitlab", path=r"/api/v4/users/1/projects$")
def user_projects_list_handler(_, request):
- if not request.headers.get('Authorization') == 'Bearer foobar':
- return {'status_code': 401}
+ if not request.headers.get("Authorization") == "Bearer foobar":
+ return {"status_code": 401}
- return {
- 'status_code': 200,
- 'headers': {
- 'Content-Type': 'application/json',
- },
- 'content': json.dumps([
- {
- "id": 2,
- "name": "Another project",
- "description": None,
- "default_branch": "master",
- "visibility": "public",
- "path": "anotherproject",
- "path_with_namespace": "knownuser/anotherproject",
- "last_activity_at": "2013-09-30T13:46:02Z",
- "web_url": "http://example.com/knownuser/anotherproject",
- }
- ]),
- }
+ return {
+ "status_code": 200,
+ "headers": {"Content-Type": "application/json"},
+ "content": json.dumps(
+ [
+ {
+ "id": 2,
+ "name": "Another project",
+ "description": None,
+ "default_branch": "master",
+ "visibility": "public",
+ "path": "anotherproject",
+ "path_with_namespace": "knownuser/anotherproject",
+ "last_activity_at": "2013-09-30T13:46:02Z",
+ "web_url": "http://example.com/knownuser/anotherproject",
+ }
+ ]
+ ),
+ }
@contextmanager
-def get_gitlab_trigger(dockerfile_path='', add_permissions=True, missing_avatar_url=False):
- handlers = [user_handler, users_handler, project_branches_handler, project_tree_handler,
- project_handler, get_projects_handler(add_permissions), tag_handler,
- project_branch_handler, get_group_handler(missing_avatar_url), dockerfile_handler,
- sub_dockerfile_handler, namespace_handler, user_namespace_handler, namespaces_handler,
- commit_handler, create_deploykey_handler, delete_deploykey_handker,
- create_hook_handler, delete_hook_handler, project_tags_handler,
- user_projects_list_handler, catchall_handler]
+def get_gitlab_trigger(
+ dockerfile_path="", add_permissions=True, missing_avatar_url=False
+):
+ handlers = [
+ user_handler,
+ users_handler,
+ project_branches_handler,
+ project_tree_handler,
+ project_handler,
+ get_projects_handler(add_permissions),
+ tag_handler,
+ project_branch_handler,
+ get_group_handler(missing_avatar_url),
+ dockerfile_handler,
+ sub_dockerfile_handler,
+ namespace_handler,
+ user_namespace_handler,
+ namespaces_handler,
+ commit_handler,
+ create_deploykey_handler,
+ delete_deploykey_handker,
+ create_hook_handler,
+ delete_hook_handler,
+ project_tags_handler,
+ user_projects_list_handler,
+ catchall_handler,
+ ]
- with HTTMock(*handlers):
- trigger_obj = AttrDict(dict(auth_token='foobar', id='sometrigger'))
- trigger = GitLabBuildTrigger(trigger_obj, {
- 'build_source': 'foo/bar',
- 'dockerfile_path': dockerfile_path,
- 'username': 'knownuser'
- })
+ with HTTMock(*handlers):
+ trigger_obj = AttrDict(dict(auth_token="foobar", id="sometrigger"))
+ trigger = GitLabBuildTrigger(
+ trigger_obj,
+ {
+ "build_source": "foo/bar",
+ "dockerfile_path": dockerfile_path,
+ "username": "knownuser",
+ },
+ )
- client = gitlab.Gitlab('http://fakegitlab', oauth_token='foobar', timeout=20, api_version=4)
- client.auth()
+ client = gitlab.Gitlab(
+ "http://fakegitlab", oauth_token="foobar", timeout=20, api_version=4
+ )
+ client.auth()
- trigger._get_authorized_client = lambda: client
- yield trigger
+ trigger._get_authorized_client = lambda: client
+ yield trigger
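Unlike the Bitbucket and GitHub fixtures, the GitLab fixture above stubs the API at the HTTP layer with httmock instead of mocking client objects. Below is a reduced sketch of that pattern, assuming the httmock and requests packages already used in this file; the endpoint and handler name are invented for illustration.

import json

import requests
from httmock import HTTMock, urlmatch

@urlmatch(netloc=r"fakegitlab", path=r"/api/v4/version$")
def version_handler(url, request):
    # Mirror the bearer-token check performed by every handler in gitlabmock.py.
    if request.headers.get("Authorization") != "Bearer foobar":
        return {"status_code": 401}
    return {
        "status_code": 200,
        "headers": {"Content-Type": "application/json"},
        "content": json.dumps({"version": "11.0.0"}),
    }

with HTTMock(version_handler):
    resp = requests.get(
        "http://fakegitlab/api/v4/version",
        headers={"Authorization": "Bearer foobar"},
    )
    assert resp.status_code == 200
    assert resp.json()["version"] == "11.0.0"
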
diff --git a/buildtrigger/test/test_basehandler.py b/buildtrigger/test/test_basehandler.py
index 7162c2535..50bdb5022 100644
--- a/buildtrigger/test/test_basehandler.py
+++ b/buildtrigger/test/test_basehandler.py
@@ -3,53 +3,74 @@ import pytest
from buildtrigger.basehandler import BuildTriggerHandler
-@pytest.mark.parametrize('input,output', [
- ("Dockerfile", True),
- ("server.Dockerfile", True),
- (u"Dockerfile", True),
- (u"server.Dockerfile", True),
- ("bad file name", False),
- (u"bad file name", False),
-])
+@pytest.mark.parametrize(
+ "input,output",
+ [
+ ("Dockerfile", True),
+ ("server.Dockerfile", True),
+ (u"Dockerfile", True),
+ (u"server.Dockerfile", True),
+ ("bad file name", False),
+ (u"bad file name", False),
+ ],
+)
def test_path_is_dockerfile(input, output):
- assert BuildTriggerHandler.filename_is_dockerfile(input) == output
+ assert BuildTriggerHandler.filename_is_dockerfile(input) == output
-@pytest.mark.parametrize('input,output', [
- ("", {}),
- ("/a", {"/a": ["/"]}),
- ("a", {"/a": ["/"]}),
- ("/b/a", {"/b/a": ["/b", "/"]}),
- ("b/a", {"/b/a": ["/b", "/"]}),
- ("/c/b/a", {"/c/b/a": ["/c/b", "/c", "/"]}),
- ("/a//b//c", {"/a/b/c": ["/", "/a", "/a/b"]}),
- ("/a", {"/a": ["/"]}),
-])
+@pytest.mark.parametrize(
+ "input,output",
+ [
+ ("", {}),
+ ("/a", {"/a": ["/"]}),
+ ("a", {"/a": ["/"]}),
+ ("/b/a", {"/b/a": ["/b", "/"]}),
+ ("b/a", {"/b/a": ["/b", "/"]}),
+ ("/c/b/a", {"/c/b/a": ["/c/b", "/c", "/"]}),
+ ("/a//b//c", {"/a/b/c": ["/", "/a", "/a/b"]}),
+ ("/a", {"/a": ["/"]}),
+ ],
+)
def test_subdir_path_map_no_previous(input, output):
- actual_mapping = BuildTriggerHandler.get_parent_directory_mappings(input)
- for key in actual_mapping:
- value = actual_mapping[key]
- actual_mapping[key] = value.sort()
- for key in output:
- value = output[key]
- output[key] = value.sort()
+ actual_mapping = BuildTriggerHandler.get_parent_directory_mappings(input)
+ for key in actual_mapping:
+ value = actual_mapping[key]
+ actual_mapping[key] = value.sort()
+ for key in output:
+ value = output[key]
+ output[key] = value.sort()
- assert actual_mapping == output
+ assert actual_mapping == output
-@pytest.mark.parametrize('new_path,original_dictionary,output', [
- ("/a", {}, {"/a": ["/"]}),
- ("b", {"/a": ["some_path", "another_path"]}, {"/a": ["some_path", "another_path"], "/b": ["/"]}),
- ("/a/b/c/d", {"/e": ["some_path", "another_path"]},
- {"/e": ["some_path", "another_path"], "/a/b/c/d": ["/", "/a", "/a/b", "/a/b/c"]}),
-])
+@pytest.mark.parametrize(
+ "new_path,original_dictionary,output",
+ [
+ ("/a", {}, {"/a": ["/"]}),
+ (
+ "b",
+ {"/a": ["some_path", "another_path"]},
+ {"/a": ["some_path", "another_path"], "/b": ["/"]},
+ ),
+ (
+ "/a/b/c/d",
+ {"/e": ["some_path", "another_path"]},
+ {
+ "/e": ["some_path", "another_path"],
+ "/a/b/c/d": ["/", "/a", "/a/b", "/a/b/c"],
+ },
+ ),
+ ],
+)
def test_subdir_path_map(new_path, original_dictionary, output):
- actual_mapping = BuildTriggerHandler.get_parent_directory_mappings(new_path, original_dictionary)
- for key in actual_mapping:
- value = actual_mapping[key]
- actual_mapping[key] = value.sort()
- for key in output:
- value = output[key]
- output[key] = value.sort()
+ actual_mapping = BuildTriggerHandler.get_parent_directory_mappings(
+ new_path, original_dictionary
+ )
+ for key in actual_mapping:
+ value = actual_mapping[key]
+ actual_mapping[key] = value.sort()
+ for key in output:
+ value = output[key]
+ output[key] = value.sort()
- assert actual_mapping == output
+ assert actual_mapping == output
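Restated as a standalone example, illustration only: the parametrized cases above encode that get_parent_directory_mappings maps a path to the chain of its parent directories.

from buildtrigger.basehandler import BuildTriggerHandler

# Expected values taken from the parametrized table above.
mapping = BuildTriggerHandler.get_parent_directory_mappings("/c/b/a")
assert sorted(mapping["/c/b/a"]) == ["/", "/c", "/c/b"]
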
diff --git a/buildtrigger/test/test_bitbuckethandler.py b/buildtrigger/test/test_bitbuckethandler.py
index dbb47521a..3b08917f7 100644
--- a/buildtrigger/test/test_bitbuckethandler.py
+++ b/buildtrigger/test/test_bitbuckethandler.py
@@ -2,35 +2,44 @@ import json
import pytest
from buildtrigger.test.bitbucketmock import get_bitbucket_trigger
-from buildtrigger.triggerutil import (SkipRequestException, ValidationRequestException,
- InvalidPayloadException)
+from buildtrigger.triggerutil import (
+ SkipRequestException,
+ ValidationRequestException,
+ InvalidPayloadException,
+)
from endpoints.building import PreparedBuild
from util.morecollections import AttrDict
+
@pytest.fixture
def bitbucket_trigger():
- return get_bitbucket_trigger()
+ return get_bitbucket_trigger()
def test_list_build_subdirs(bitbucket_trigger):
- assert bitbucket_trigger.list_build_subdirs() == ["/Dockerfile"]
+ assert bitbucket_trigger.list_build_subdirs() == ["/Dockerfile"]
-@pytest.mark.parametrize('dockerfile_path, contents', [
- ('/Dockerfile', 'hello world'),
- ('somesubdir/Dockerfile', 'hi universe'),
- ('unknownpath', None),
-])
+@pytest.mark.parametrize(
+ "dockerfile_path, contents",
+ [
+ ("/Dockerfile", "hello world"),
+ ("somesubdir/Dockerfile", "hi universe"),
+ ("unknownpath", None),
+ ],
+)
def test_load_dockerfile_contents(dockerfile_path, contents):
- trigger = get_bitbucket_trigger(dockerfile_path)
- assert trigger.load_dockerfile_contents() == contents
+ trigger = get_bitbucket_trigger(dockerfile_path)
+ assert trigger.load_dockerfile_contents() == contents
-@pytest.mark.parametrize('payload, expected_error, expected_message', [
- ('{}', InvalidPayloadException, "'push' is a required property"),
-
- # Valid payload:
- ('''{
+@pytest.mark.parametrize(
+ "payload, expected_error, expected_message",
+ [
+ ("{}", InvalidPayloadException, "'push' is a required property"),
+ # Valid payload:
+ (
+ """{
"push": {
"changes": [{
"new": {
@@ -51,10 +60,13 @@ def test_load_dockerfile_contents(dockerfile_path, contents):
"repository": {
"full_name": "foo/bar"
}
- }''', None, None),
-
- # Skip message:
- ('''{
+ }""",
+ None,
+ None,
+ ),
+ # Skip message:
+ (
+ """{
"push": {
"changes": [{
"new": {
@@ -75,17 +87,25 @@ def test_load_dockerfile_contents(dockerfile_path, contents):
"repository": {
"full_name": "foo/bar"
}
- }''', SkipRequestException, ''),
-])
-def test_handle_trigger_request(bitbucket_trigger, payload, expected_error, expected_message):
- def get_payload():
- return json.loads(payload)
+ }""",
+ SkipRequestException,
+ "",
+ ),
+ ],
+)
+def test_handle_trigger_request(
+ bitbucket_trigger, payload, expected_error, expected_message
+):
+ def get_payload():
+ return json.loads(payload)
- request = AttrDict(dict(get_json=get_payload))
+ request = AttrDict(dict(get_json=get_payload))
- if expected_error is not None:
- with pytest.raises(expected_error) as ipe:
- bitbucket_trigger.handle_trigger_request(request)
- assert str(ipe.value) == expected_message
- else:
- assert isinstance(bitbucket_trigger.handle_trigger_request(request), PreparedBuild)
+ if expected_error is not None:
+ with pytest.raises(expected_error) as ipe:
+ bitbucket_trigger.handle_trigger_request(request)
+ assert str(ipe.value) == expected_message
+ else:
+ assert isinstance(
+ bitbucket_trigger.handle_trigger_request(request), PreparedBuild
+ )
diff --git a/buildtrigger/test/test_customhandler.py b/buildtrigger/test/test_customhandler.py
index cbb5f484e..984eb27ce 100644
--- a/buildtrigger/test/test_customhandler.py
+++ b/buildtrigger/test/test_customhandler.py
@@ -1,20 +1,32 @@
import pytest
from buildtrigger.customhandler import CustomBuildTrigger
-from buildtrigger.triggerutil import (InvalidPayloadException, SkipRequestException,
- TriggerStartException)
+from buildtrigger.triggerutil import (
+ InvalidPayloadException,
+ SkipRequestException,
+ TriggerStartException,
+)
from endpoints.building import PreparedBuild
from util.morecollections import AttrDict
-@pytest.mark.parametrize('payload, expected_error, expected_message', [
- ('', InvalidPayloadException, 'Missing expected payload'),
- ('{}', InvalidPayloadException, "'commit' is a required property"),
- ('{"commit": "foo", "ref": "refs/heads/something", "default_branch": "baz"}',
- InvalidPayloadException, "u'foo' does not match '^([A-Fa-f0-9]{7,})$'"),
-
- ('{"commit": "11d6fbc", "ref": "refs/heads/something", "default_branch": "baz"}', None, None),
- ('''{
+@pytest.mark.parametrize(
+ "payload, expected_error, expected_message",
+ [
+ ("", InvalidPayloadException, "Missing expected payload"),
+ ("{}", InvalidPayloadException, "'commit' is a required property"),
+ (
+ '{"commit": "foo", "ref": "refs/heads/something", "default_branch": "baz"}',
+ InvalidPayloadException,
+ "u'foo' does not match '^([A-Fa-f0-9]{7,})$'",
+ ),
+ (
+ '{"commit": "11d6fbc", "ref": "refs/heads/something", "default_branch": "baz"}',
+ None,
+ None,
+ ),
+ (
+ """{
"commit": "11d6fbc",
"ref": "refs/heads/something",
"default_branch": "baz",
@@ -23,29 +35,41 @@ from util.morecollections import AttrDict
"url": "http://foo.bar",
"date": "NOW"
}
- }''', SkipRequestException, ''),
-])
+ }""",
+ SkipRequestException,
+ "",
+ ),
+ ],
+)
def test_handle_trigger_request(payload, expected_error, expected_message):
- trigger = CustomBuildTrigger(None, {'build_source': 'foo'})
- request = AttrDict(dict(data=payload))
+ trigger = CustomBuildTrigger(None, {"build_source": "foo"})
+ request = AttrDict(dict(data=payload))
- if expected_error is not None:
- with pytest.raises(expected_error) as ipe:
- trigger.handle_trigger_request(request)
- assert str(ipe.value) == expected_message
- else:
- assert isinstance(trigger.handle_trigger_request(request), PreparedBuild)
+ if expected_error is not None:
+ with pytest.raises(expected_error) as ipe:
+ trigger.handle_trigger_request(request)
+ assert str(ipe.value) == expected_message
+ else:
+ assert isinstance(trigger.handle_trigger_request(request), PreparedBuild)
-@pytest.mark.parametrize('run_parameters, expected_error, expected_message', [
- ({}, TriggerStartException, 'missing required parameter'),
- ({'commit_sha': 'foo'}, TriggerStartException, "'foo' does not match '^([A-Fa-f0-9]{7,})$'"),
- ({'commit_sha': '11d6fbc'}, None, None),
-])
+
+@pytest.mark.parametrize(
+ "run_parameters, expected_error, expected_message",
+ [
+ ({}, TriggerStartException, "missing required parameter"),
+ (
+ {"commit_sha": "foo"},
+ TriggerStartException,
+ "'foo' does not match '^([A-Fa-f0-9]{7,})$'",
+ ),
+ ({"commit_sha": "11d6fbc"}, None, None),
+ ],
+)
def test_manual_start(run_parameters, expected_error, expected_message):
- trigger = CustomBuildTrigger(None, {'build_source': 'foo'})
- if expected_error is not None:
- with pytest.raises(expected_error) as ipe:
- trigger.manual_start(run_parameters)
- assert str(ipe.value) == expected_message
- else:
- assert isinstance(trigger.manual_start(run_parameters), PreparedBuild)
+ trigger = CustomBuildTrigger(None, {"build_source": "foo"})
+ if expected_error is not None:
+ with pytest.raises(expected_error) as ipe:
+ trigger.manual_start(run_parameters)
+ assert str(ipe.value) == expected_message
+ else:
+ assert isinstance(trigger.manual_start(run_parameters), PreparedBuild)
diff --git a/buildtrigger/test/test_githosthandler.py b/buildtrigger/test/test_githosthandler.py
index fadf8dce5..f0c43b458 100644
--- a/buildtrigger/test/test_githosthandler.py
+++ b/buildtrigger/test/test_githosthandler.py
@@ -9,113 +9,145 @@ from endpoints.building import PreparedBuild
# in this fixture. Each trigger's mock is expected to return the same data for all of these calls.
@pytest.fixture(params=[get_github_trigger(), get_bitbucket_trigger()])
def githost_trigger(request):
- return request.param
-
-@pytest.mark.parametrize('run_parameters, expected_error, expected_message', [
- # No branch or tag specified: use the commit of the default branch.
- ({}, None, None),
-
- # Invalid branch.
- ({'refs': {'kind': 'branch', 'name': 'invalid'}}, TriggerStartException,
- 'Could not find branch in repository'),
-
- # Invalid tag.
- ({'refs': {'kind': 'tag', 'name': 'invalid'}}, TriggerStartException,
- 'Could not find tag in repository'),
-
- # Valid branch.
- ({'refs': {'kind': 'branch', 'name': 'master'}}, None, None),
-
- # Valid tag.
- ({'refs': {'kind': 'tag', 'name': 'sometag'}}, None, None),
-])
-def test_manual_start(run_parameters, expected_error, expected_message, githost_trigger):
- if expected_error is not None:
- with pytest.raises(expected_error) as ipe:
- githost_trigger.manual_start(run_parameters)
- assert str(ipe.value) == expected_message
- else:
- assert isinstance(githost_trigger.manual_start(run_parameters), PreparedBuild)
+ return request.param
-@pytest.mark.parametrize('name, expected', [
- ('refs', [
- {'kind': 'branch', 'name': 'master'},
- {'kind': 'branch', 'name': 'otherbranch'},
- {'kind': 'tag', 'name': 'sometag'},
- {'kind': 'tag', 'name': 'someothertag'},
- ]),
- ('tag_name', set(['sometag', 'someothertag'])),
- ('branch_name', set(['master', 'otherbranch'])),
- ('invalid', None)
-])
+@pytest.mark.parametrize(
+ "run_parameters, expected_error, expected_message",
+ [
+ # No branch or tag specified: use the commit of the default branch.
+ ({}, None, None),
+ # Invalid branch.
+ (
+ {"refs": {"kind": "branch", "name": "invalid"}},
+ TriggerStartException,
+ "Could not find branch in repository",
+ ),
+ # Invalid tag.
+ (
+ {"refs": {"kind": "tag", "name": "invalid"}},
+ TriggerStartException,
+ "Could not find tag in repository",
+ ),
+ # Valid branch.
+ ({"refs": {"kind": "branch", "name": "master"}}, None, None),
+ # Valid tag.
+ ({"refs": {"kind": "tag", "name": "sometag"}}, None, None),
+ ],
+)
+def test_manual_start(
+ run_parameters, expected_error, expected_message, githost_trigger
+):
+ if expected_error is not None:
+ with pytest.raises(expected_error) as ipe:
+ githost_trigger.manual_start(run_parameters)
+ assert str(ipe.value) == expected_message
+ else:
+ assert isinstance(githost_trigger.manual_start(run_parameters), PreparedBuild)
+
+
+@pytest.mark.parametrize(
+ "name, expected",
+ [
+ (
+ "refs",
+ [
+ {"kind": "branch", "name": "master"},
+ {"kind": "branch", "name": "otherbranch"},
+ {"kind": "tag", "name": "sometag"},
+ {"kind": "tag", "name": "someothertag"},
+ ],
+ ),
+ ("tag_name", set(["sometag", "someothertag"])),
+ ("branch_name", set(["master", "otherbranch"])),
+ ("invalid", None),
+ ],
+)
def test_list_field_values(name, expected, githost_trigger):
- if expected is None:
- assert githost_trigger.list_field_values(name) is None
- elif isinstance(expected, set):
- assert set(githost_trigger.list_field_values(name)) == set(expected)
- else:
- assert githost_trigger.list_field_values(name) == expected
+ if expected is None:
+ assert githost_trigger.list_field_values(name) is None
+ elif isinstance(expected, set):
+ assert set(githost_trigger.list_field_values(name)) == set(expected)
+ else:
+ assert githost_trigger.list_field_values(name) == expected
def test_list_build_source_namespaces():
- namespaces_expected = [
- {
- 'personal': True,
- 'score': 1,
- 'avatar_url': 'avatarurl',
- 'id': 'knownuser',
- 'title': 'knownuser',
- 'url': 'https://bitbucket.org/knownuser',
- },
- {
- 'score': 2,
- 'title': 'someorg',
- 'personal': False,
- 'url': 'https://bitbucket.org/someorg',
- 'avatar_url': 'avatarurl',
- 'id': 'someorg'
- }
- ]
+ namespaces_expected = [
+ {
+ "personal": True,
+ "score": 1,
+ "avatar_url": "avatarurl",
+ "id": "knownuser",
+ "title": "knownuser",
+ "url": "https://bitbucket.org/knownuser",
+ },
+ {
+ "score": 2,
+ "title": "someorg",
+ "personal": False,
+ "url": "https://bitbucket.org/someorg",
+ "avatar_url": "avatarurl",
+ "id": "someorg",
+ },
+ ]
- found = get_bitbucket_trigger().list_build_source_namespaces()
- found.sort()
+ found = get_bitbucket_trigger().list_build_source_namespaces()
+ found.sort()
- namespaces_expected.sort()
- assert found == namespaces_expected
+ namespaces_expected.sort()
+ assert found == namespaces_expected
-@pytest.mark.parametrize('namespace, expected', [
- ('', []),
- ('unknown', []),
-
- ('knownuser', [
- {
- 'last_updated': 0, 'name': 'somerepo',
- 'url': 'https://bitbucket.org/knownuser/somerepo', 'private': True,
- 'full_name': 'knownuser/somerepo', 'has_admin_permissions': True,
- 'description': 'some somerepo repo'
- }]),
-
- ('someorg', [
- {
- 'last_updated': 0, 'name': 'somerepo',
- 'url': 'https://bitbucket.org/someorg/somerepo', 'private': True,
- 'full_name': 'someorg/somerepo', 'has_admin_permissions': False,
- 'description': 'some somerepo repo'
- },
- {
- 'last_updated': 0, 'name': 'anotherrepo',
- 'url': 'https://bitbucket.org/someorg/anotherrepo', 'private': False,
- 'full_name': 'someorg/anotherrepo', 'has_admin_permissions': False,
- 'description': 'some anotherrepo repo'
- }]),
-])
+@pytest.mark.parametrize(
+ "namespace, expected",
+ [
+ ("", []),
+ ("unknown", []),
+ (
+ "knownuser",
+ [
+ {
+ "last_updated": 0,
+ "name": "somerepo",
+ "url": "https://bitbucket.org/knownuser/somerepo",
+ "private": True,
+ "full_name": "knownuser/somerepo",
+ "has_admin_permissions": True,
+ "description": "some somerepo repo",
+ }
+ ],
+ ),
+ (
+ "someorg",
+ [
+ {
+ "last_updated": 0,
+ "name": "somerepo",
+ "url": "https://bitbucket.org/someorg/somerepo",
+ "private": True,
+ "full_name": "someorg/somerepo",
+ "has_admin_permissions": False,
+ "description": "some somerepo repo",
+ },
+ {
+ "last_updated": 0,
+ "name": "anotherrepo",
+ "url": "https://bitbucket.org/someorg/anotherrepo",
+ "private": False,
+ "full_name": "someorg/anotherrepo",
+ "has_admin_permissions": False,
+ "description": "some anotherrepo repo",
+ },
+ ],
+ ),
+ ],
+)
def test_list_build_sources_for_namespace(namespace, expected, githost_trigger):
- assert githost_trigger.list_build_sources_for_namespace(namespace) == expected
+ assert githost_trigger.list_build_sources_for_namespace(namespace) == expected
def test_activate_and_deactivate(githost_trigger):
- _, private_key = githost_trigger.activate('http://some/url')
- assert 'private_key' in private_key
- githost_trigger.deactivate()
+ _, private_key = githost_trigger.activate("http://some/url")
+ assert "private_key" in private_key
+ githost_trigger.deactivate()
diff --git a/buildtrigger/test/test_githubhandler.py b/buildtrigger/test/test_githubhandler.py
index f7012b0cf..7866359ce 100644
--- a/buildtrigger/test/test_githubhandler.py
+++ b/buildtrigger/test/test_githubhandler.py
@@ -2,24 +2,33 @@ import json
import pytest
from buildtrigger.test.githubmock import get_github_trigger
-from buildtrigger.triggerutil import (SkipRequestException, ValidationRequestException,
- InvalidPayloadException)
+from buildtrigger.triggerutil import (
+ SkipRequestException,
+ ValidationRequestException,
+ InvalidPayloadException,
+)
from endpoints.building import PreparedBuild
from util.morecollections import AttrDict
+
@pytest.fixture
def github_trigger():
- return get_github_trigger()
+ return get_github_trigger()
-@pytest.mark.parametrize('payload, expected_error, expected_message', [
- ('{"zen": true}', SkipRequestException, ""),
-
- ('{}', InvalidPayloadException, "Missing 'repository' on request"),
- ('{"repository": "foo"}', InvalidPayloadException, "Missing 'owner' on repository"),
-
- # Valid payload:
- ('''{
+@pytest.mark.parametrize(
+ "payload, expected_error, expected_message",
+ [
+ ('{"zen": true}', SkipRequestException, ""),
+ ("{}", InvalidPayloadException, "Missing 'repository' on request"),
+ (
+ '{"repository": "foo"}',
+ InvalidPayloadException,
+ "Missing 'owner' on repository",
+ ),
+ # Valid payload:
+ (
+ """{
"repository": {
"owner": {
"name": "someguy"
@@ -34,10 +43,13 @@ def github_trigger():
"message": "some message",
"timestamp": "NOW"
}
- }''', None, None),
-
- # Skip message:
- ('''{
+ }""",
+ None,
+ None,
+ ),
+ # Skip message:
+ (
+ """{
"repository": {
"owner": {
"name": "someguy"
@@ -52,66 +64,84 @@ def github_trigger():
"message": "[skip build]",
"timestamp": "NOW"
}
- }''', SkipRequestException, ''),
-])
-def test_handle_trigger_request(github_trigger, payload, expected_error, expected_message):
- def get_payload():
- return json.loads(payload)
+ }""",
+ SkipRequestException,
+ "",
+ ),
+ ],
+)
+def test_handle_trigger_request(
+ github_trigger, payload, expected_error, expected_message
+):
+ def get_payload():
+ return json.loads(payload)
- request = AttrDict(dict(get_json=get_payload))
+ request = AttrDict(dict(get_json=get_payload))
- if expected_error is not None:
- with pytest.raises(expected_error) as ipe:
- github_trigger.handle_trigger_request(request)
- assert str(ipe.value) == expected_message
- else:
- assert isinstance(github_trigger.handle_trigger_request(request), PreparedBuild)
+ if expected_error is not None:
+ with pytest.raises(expected_error) as ipe:
+ github_trigger.handle_trigger_request(request)
+ assert str(ipe.value) == expected_message
+ else:
+ assert isinstance(github_trigger.handle_trigger_request(request), PreparedBuild)
-@pytest.mark.parametrize('dockerfile_path, contents', [
- ('/Dockerfile', 'hello world'),
- ('somesubdir/Dockerfile', 'hi universe'),
- ('unknownpath', None),
-])
+@pytest.mark.parametrize(
+ "dockerfile_path, contents",
+ [
+ ("/Dockerfile", "hello world"),
+ ("somesubdir/Dockerfile", "hi universe"),
+ ("unknownpath", None),
+ ],
+)
def test_load_dockerfile_contents(dockerfile_path, contents):
- trigger = get_github_trigger(dockerfile_path)
- assert trigger.load_dockerfile_contents() == contents
+ trigger = get_github_trigger(dockerfile_path)
+ assert trigger.load_dockerfile_contents() == contents
-@pytest.mark.parametrize('username, expected_response', [
- ('unknownuser', None),
- ('knownuser', {'html_url': 'https://bitbucket.org/knownuser', 'avatar_url': 'avatarurl'}),
-])
+@pytest.mark.parametrize(
+ "username, expected_response",
+ [
+ ("unknownuser", None),
+ (
+ "knownuser",
+ {"html_url": "https://bitbucket.org/knownuser", "avatar_url": "avatarurl"},
+ ),
+ ],
+)
def test_lookup_user(username, expected_response, github_trigger):
- assert github_trigger.lookup_user(username) == expected_response
+ assert github_trigger.lookup_user(username) == expected_response
def test_list_build_subdirs(github_trigger):
- assert github_trigger.list_build_subdirs() == ['Dockerfile', 'somesubdir/Dockerfile']
+ assert github_trigger.list_build_subdirs() == [
+ "Dockerfile",
+ "somesubdir/Dockerfile",
+ ]
def test_list_build_source_namespaces(github_trigger):
- namespaces_expected = [
- {
- 'personal': True,
- 'score': 1,
- 'avatar_url': 'avatarurl',
- 'id': 'knownuser',
- 'title': 'knownuser',
- 'url': 'https://bitbucket.org/knownuser',
- },
- {
- 'score': 0,
- 'title': 'someorg',
- 'personal': False,
- 'url': '',
- 'avatar_url': 'avatarurl',
- 'id': 'someorg'
- }
- ]
+ namespaces_expected = [
+ {
+ "personal": True,
+ "score": 1,
+ "avatar_url": "avatarurl",
+ "id": "knownuser",
+ "title": "knownuser",
+ "url": "https://bitbucket.org/knownuser",
+ },
+ {
+ "score": 0,
+ "title": "someorg",
+ "personal": False,
+ "url": "",
+ "avatar_url": "avatarurl",
+ "id": "someorg",
+ },
+ ]
- found = github_trigger.list_build_source_namespaces()
- found.sort()
+ found = github_trigger.list_build_source_namespaces()
+ found.sort()
- namespaces_expected.sort()
- assert found == namespaces_expected
+ namespaces_expected.sort()
+ assert found == namespaces_expected
diff --git a/buildtrigger/test/test_gitlabhandler.py b/buildtrigger/test/test_gitlabhandler.py
index b74095a8c..cb9b50581 100644
--- a/buildtrigger/test/test_gitlabhandler.py
+++ b/buildtrigger/test/test_gitlabhandler.py
@@ -4,91 +4,111 @@ import pytest
from mock import Mock
from buildtrigger.test.gitlabmock import get_gitlab_trigger
-from buildtrigger.triggerutil import (SkipRequestException, ValidationRequestException,
- InvalidPayloadException, TriggerStartException)
+from buildtrigger.triggerutil import (
+ SkipRequestException,
+ ValidationRequestException,
+ InvalidPayloadException,
+ TriggerStartException,
+)
from endpoints.building import PreparedBuild
from util.morecollections import AttrDict
+
@pytest.fixture()
def gitlab_trigger():
- with get_gitlab_trigger() as t:
- yield t
+ with get_gitlab_trigger() as t:
+ yield t
def test_list_build_subdirs(gitlab_trigger):
- assert gitlab_trigger.list_build_subdirs() == ['Dockerfile']
+ assert gitlab_trigger.list_build_subdirs() == ["Dockerfile"]
-@pytest.mark.parametrize('dockerfile_path, contents', [
- ('/Dockerfile', 'hello world'),
- ('somesubdir/Dockerfile', 'hi universe'),
- ('unknownpath', None),
-])
+@pytest.mark.parametrize(
+ "dockerfile_path, contents",
+ [
+ ("/Dockerfile", "hello world"),
+ ("somesubdir/Dockerfile", "hi universe"),
+ ("unknownpath", None),
+ ],
+)
def test_load_dockerfile_contents(dockerfile_path, contents):
- with get_gitlab_trigger(dockerfile_path=dockerfile_path) as trigger:
- assert trigger.load_dockerfile_contents() == contents
+ with get_gitlab_trigger(dockerfile_path=dockerfile_path) as trigger:
+ assert trigger.load_dockerfile_contents() == contents
-@pytest.mark.parametrize('email, expected_response', [
- ('unknown@email.com', None),
- ('knownuser', {'username': 'knownuser', 'html_url': 'https://bitbucket.org/knownuser',
- 'avatar_url': 'avatarurl'}),
-])
+@pytest.mark.parametrize(
+ "email, expected_response",
+ [
+ ("unknown@email.com", None),
+ (
+ "knownuser",
+ {
+ "username": "knownuser",
+ "html_url": "https://bitbucket.org/knownuser",
+ "avatar_url": "avatarurl",
+ },
+ ),
+ ],
+)
def test_lookup_user(email, expected_response, gitlab_trigger):
- assert gitlab_trigger.lookup_user(email) == expected_response
+ assert gitlab_trigger.lookup_user(email) == expected_response
def test_null_permissions():
- with get_gitlab_trigger(add_permissions=False) as trigger:
- sources = trigger.list_build_sources_for_namespace('someorg')
- source = sources[0]
- assert source['has_admin_permissions']
+ with get_gitlab_trigger(add_permissions=False) as trigger:
+ sources = trigger.list_build_sources_for_namespace("someorg")
+ source = sources[0]
+ assert source["has_admin_permissions"]
def test_list_build_sources():
- with get_gitlab_trigger() as trigger:
- sources = trigger.list_build_sources_for_namespace('someorg')
- assert sources == [
- {
- 'last_updated': 1380548762,
- 'name': u'someproject',
- 'url': u'http://example.com/someorg/someproject',
- 'private': True,
- 'full_name': u'someorg/someproject',
- 'has_admin_permissions': False,
- 'description': ''
- },
- {
- 'last_updated': 1380548762,
- 'name': u'anotherproject',
- 'url': u'http://example.com/someorg/anotherproject',
- 'private': False,
- 'full_name': u'someorg/anotherproject',
- 'has_admin_permissions': True,
- 'description': '',
- }]
+ with get_gitlab_trigger() as trigger:
+ sources = trigger.list_build_sources_for_namespace("someorg")
+ assert sources == [
+ {
+ "last_updated": 1380548762,
+ "name": u"someproject",
+ "url": u"http://example.com/someorg/someproject",
+ "private": True,
+ "full_name": u"someorg/someproject",
+ "has_admin_permissions": False,
+ "description": "",
+ },
+ {
+ "last_updated": 1380548762,
+ "name": u"anotherproject",
+ "url": u"http://example.com/someorg/anotherproject",
+ "private": False,
+ "full_name": u"someorg/anotherproject",
+ "has_admin_permissions": True,
+ "description": "",
+ },
+ ]
def test_null_avatar():
- with get_gitlab_trigger(missing_avatar_url=True) as trigger:
- namespace_data = trigger.list_build_source_namespaces()
- expected = {
- 'avatar_url': None,
- 'personal': False,
- 'title': u'someorg',
- 'url': u'http://gitlab.com/groups/someorg',
- 'score': 1,
- 'id': '2',
- }
+ with get_gitlab_trigger(missing_avatar_url=True) as trigger:
+ namespace_data = trigger.list_build_source_namespaces()
+ expected = {
+ "avatar_url": None,
+ "personal": False,
+ "title": u"someorg",
+ "url": u"http://gitlab.com/groups/someorg",
+ "score": 1,
+ "id": "2",
+ }
- assert namespace_data == [expected]
+ assert namespace_data == [expected]
-@pytest.mark.parametrize('payload, expected_error, expected_message', [
- ('{}', InvalidPayloadException, ''),
-
- # Valid payload:
- ('''{
+@pytest.mark.parametrize(
+ "payload, expected_error, expected_message",
+ [
+ ("{}", InvalidPayloadException, ""),
+ # Valid payload:
+ (
+ """{
"object_kind": "push",
"ref": "refs/heads/master",
"checkout_sha": "aaaaaaa",
@@ -103,10 +123,13 @@ def test_null_avatar():
"timestamp": "now"
}
]
- }''', None, None),
-
- # Skip message:
- ('''{
+ }""",
+ None,
+ None,
+ ),
+ # Skip message:
+ (
+ """{
"object_kind": "push",
"ref": "refs/heads/master",
"checkout_sha": "aaaaaaa",
@@ -121,111 +144,136 @@ def test_null_avatar():
"timestamp": "now"
}
]
- }''', SkipRequestException, ''),
-])
-def test_handle_trigger_request(gitlab_trigger, payload, expected_error, expected_message):
- def get_payload():
- return json.loads(payload)
+ }""",
+ SkipRequestException,
+ "",
+ ),
+ ],
+)
+def test_handle_trigger_request(
+ gitlab_trigger, payload, expected_error, expected_message
+):
+ def get_payload():
+ return json.loads(payload)
- request = AttrDict(dict(get_json=get_payload))
+ request = AttrDict(dict(get_json=get_payload))
- if expected_error is not None:
- with pytest.raises(expected_error) as ipe:
- gitlab_trigger.handle_trigger_request(request)
- assert str(ipe.value) == expected_message
- else:
- assert isinstance(gitlab_trigger.handle_trigger_request(request), PreparedBuild)
+ if expected_error is not None:
+ with pytest.raises(expected_error) as ipe:
+ gitlab_trigger.handle_trigger_request(request)
+ assert str(ipe.value) == expected_message
+ else:
+ assert isinstance(gitlab_trigger.handle_trigger_request(request), PreparedBuild)
-@pytest.mark.parametrize('run_parameters, expected_error, expected_message', [
- # No branch or tag specified: use the commit of the default branch.
- ({}, None, None),
-
- # Invalid branch.
- ({'refs': {'kind': 'branch', 'name': 'invalid'}}, TriggerStartException,
- 'Could not find branch in repository'),
-
- # Invalid tag.
- ({'refs': {'kind': 'tag', 'name': 'invalid'}}, TriggerStartException,
- 'Could not find tag in repository'),
-
- # Valid branch.
- ({'refs': {'kind': 'branch', 'name': 'master'}}, None, None),
-
- # Valid tag.
- ({'refs': {'kind': 'tag', 'name': 'sometag'}}, None, None),
-])
+@pytest.mark.parametrize(
+ "run_parameters, expected_error, expected_message",
+ [
+ # No branch or tag specified: use the commit of the default branch.
+ ({}, None, None),
+ # Invalid branch.
+ (
+ {"refs": {"kind": "branch", "name": "invalid"}},
+ TriggerStartException,
+ "Could not find branch in repository",
+ ),
+ # Invalid tag.
+ (
+ {"refs": {"kind": "tag", "name": "invalid"}},
+ TriggerStartException,
+ "Could not find tag in repository",
+ ),
+ # Valid branch.
+ ({"refs": {"kind": "branch", "name": "master"}}, None, None),
+ # Valid tag.
+ ({"refs": {"kind": "tag", "name": "sometag"}}, None, None),
+ ],
+)
def test_manual_start(run_parameters, expected_error, expected_message, gitlab_trigger):
- if expected_error is not None:
- with pytest.raises(expected_error) as ipe:
- gitlab_trigger.manual_start(run_parameters)
- assert str(ipe.value) == expected_message
- else:
- assert isinstance(gitlab_trigger.manual_start(run_parameters), PreparedBuild)
+ if expected_error is not None:
+ with pytest.raises(expected_error) as ipe:
+ gitlab_trigger.manual_start(run_parameters)
+ assert str(ipe.value) == expected_message
+ else:
+ assert isinstance(gitlab_trigger.manual_start(run_parameters), PreparedBuild)
def test_activate_and_deactivate(gitlab_trigger):
- _, private_key = gitlab_trigger.activate('http://some/url')
- assert 'private_key' in private_key
+ _, private_key = gitlab_trigger.activate("http://some/url")
+ assert "private_key" in private_key
- gitlab_trigger.deactivate()
+ gitlab_trigger.deactivate()
-@pytest.mark.parametrize('name, expected', [
- ('refs', [
- {'kind': 'branch', 'name': 'master'},
- {'kind': 'branch', 'name': 'otherbranch'},
- {'kind': 'tag', 'name': 'sometag'},
- {'kind': 'tag', 'name': 'someothertag'},
- ]),
- ('tag_name', set(['sometag', 'someothertag'])),
- ('branch_name', set(['master', 'otherbranch'])),
- ('invalid', None)
-])
+@pytest.mark.parametrize(
+ "name, expected",
+ [
+ (
+ "refs",
+ [
+ {"kind": "branch", "name": "master"},
+ {"kind": "branch", "name": "otherbranch"},
+ {"kind": "tag", "name": "sometag"},
+ {"kind": "tag", "name": "someothertag"},
+ ],
+ ),
+ ("tag_name", set(["sometag", "someothertag"])),
+ ("branch_name", set(["master", "otherbranch"])),
+ ("invalid", None),
+ ],
+)
def test_list_field_values(name, expected, gitlab_trigger):
- if expected is None:
- assert gitlab_trigger.list_field_values(name) is None
- elif isinstance(expected, set):
- assert set(gitlab_trigger.list_field_values(name)) == set(expected)
- else:
- assert gitlab_trigger.list_field_values(name) == expected
+ if expected is None:
+ assert gitlab_trigger.list_field_values(name) is None
+ elif isinstance(expected, set):
+ assert set(gitlab_trigger.list_field_values(name)) == set(expected)
+ else:
+ assert gitlab_trigger.list_field_values(name) == expected
-@pytest.mark.parametrize('namespace, expected', [
- ('', []),
- ('unknown', []),
-
- ('knownuser', [
- {
- 'last_updated': 1380548762,
- 'name': u'anotherproject',
- 'url': u'http://example.com/knownuser/anotherproject',
- 'private': False,
- 'full_name': u'knownuser/anotherproject',
- 'has_admin_permissions': True,
- 'description': ''
- },
- ]),
-
- ('someorg', [
- {
- 'last_updated': 1380548762,
- 'name': u'someproject',
- 'url': u'http://example.com/someorg/someproject',
- 'private': True,
- 'full_name': u'someorg/someproject',
- 'has_admin_permissions': False,
- 'description': ''
- },
- {
- 'last_updated': 1380548762,
- 'name': u'anotherproject',
- 'url': u'http://example.com/someorg/anotherproject',
- 'private': False,
- 'full_name': u'someorg/anotherproject',
- 'has_admin_permissions': True,
- 'description': '',
- }]),
-])
+@pytest.mark.parametrize(
+ "namespace, expected",
+ [
+ ("", []),
+ ("unknown", []),
+ (
+ "knownuser",
+ [
+ {
+ "last_updated": 1380548762,
+ "name": u"anotherproject",
+ "url": u"http://example.com/knownuser/anotherproject",
+ "private": False,
+ "full_name": u"knownuser/anotherproject",
+ "has_admin_permissions": True,
+ "description": "",
+ }
+ ],
+ ),
+ (
+ "someorg",
+ [
+ {
+ "last_updated": 1380548762,
+ "name": u"someproject",
+ "url": u"http://example.com/someorg/someproject",
+ "private": True,
+ "full_name": u"someorg/someproject",
+ "has_admin_permissions": False,
+ "description": "",
+ },
+ {
+ "last_updated": 1380548762,
+ "name": u"anotherproject",
+ "url": u"http://example.com/someorg/anotherproject",
+ "private": False,
+ "full_name": u"someorg/anotherproject",
+ "has_admin_permissions": True,
+ "description": "",
+ },
+ ],
+ ),
+ ],
+)
def test_list_build_sources_for_namespace(namespace, expected, gitlab_trigger):
- assert gitlab_trigger.list_build_sources_for_namespace(namespace) == expected
+ assert gitlab_trigger.list_build_sources_for_namespace(namespace) == expected
diff --git a/buildtrigger/test/test_prepare_trigger.py b/buildtrigger/test/test_prepare_trigger.py
index e3aab6b48..839c0a91a 100644
--- a/buildtrigger/test/test_prepare_trigger.py
+++ b/buildtrigger/test/test_prepare_trigger.py
@@ -12,561 +12,577 @@ from buildtrigger.githubhandler import get_transformed_webhook_payload as gh_web
from buildtrigger.gitlabhandler import get_transformed_webhook_payload as gl_webhook
from buildtrigger.triggerutil import SkipRequestException
+
def assertSkipped(filename, processor, *args, **kwargs):
- with open('buildtrigger/test/triggerjson/%s.json' % filename) as f:
- payload = json.loads(f.read())
+ with open("buildtrigger/test/triggerjson/%s.json" % filename) as f:
+ payload = json.loads(f.read())
- nargs = [payload]
- nargs.extend(args)
+ nargs = [payload]
+ nargs.extend(args)
- with pytest.raises(SkipRequestException):
- processor(*nargs, **kwargs)
+ with pytest.raises(SkipRequestException):
+ processor(*nargs, **kwargs)
def assertSchema(filename, expected, processor, *args, **kwargs):
- with open('buildtrigger/test/triggerjson/%s.json' % filename) as f:
- payload = json.loads(f.read())
+ with open("buildtrigger/test/triggerjson/%s.json" % filename) as f:
+ payload = json.loads(f.read())
- nargs = [payload]
- nargs.extend(args)
+ nargs = [payload]
+ nargs.extend(args)
- created = processor(*nargs, **kwargs)
- assert created == expected
- validate(created, METADATA_SCHEMA)
+ created = processor(*nargs, **kwargs)
+ assert created == expected
+ validate(created, METADATA_SCHEMA)
def test_custom_custom():
- expected = {
- u'commit':u'1c002dd',
- u'commit_info': {
- u'url': u'gitsoftware.com/repository/commits/1234567',
- u'date': u'timestamp',
- u'message': u'initial commit',
- u'committer': {
- u'username': u'user',
- u'url': u'gitsoftware.com/users/user',
- u'avatar_url': u'gravatar.com/user.png'
- },
- u'author': {
- u'username': u'user',
- u'url': u'gitsoftware.com/users/user',
- u'avatar_url': u'gravatar.com/user.png'
- }
- },
- u'ref': u'refs/heads/master',
- u'default_branch': u'master',
- u'git_url': u'foobar',
- }
+ expected = {
+ u"commit": u"1c002dd",
+ u"commit_info": {
+ u"url": u"gitsoftware.com/repository/commits/1234567",
+ u"date": u"timestamp",
+ u"message": u"initial commit",
+ u"committer": {
+ u"username": u"user",
+ u"url": u"gitsoftware.com/users/user",
+ u"avatar_url": u"gravatar.com/user.png",
+ },
+ u"author": {
+ u"username": u"user",
+ u"url": u"gitsoftware.com/users/user",
+ u"avatar_url": u"gravatar.com/user.png",
+ },
+ },
+ u"ref": u"refs/heads/master",
+ u"default_branch": u"master",
+ u"git_url": u"foobar",
+ }
- assertSchema('custom_webhook', expected, custom_trigger_payload, git_url='foobar')
+ assertSchema("custom_webhook", expected, custom_trigger_payload, git_url="foobar")
def test_custom_gitlab():
- expected = {
- 'commit': u'fb88379ee45de28a0a4590fddcbd8eff8b36026e',
- 'ref': u'refs/heads/master',
- 'git_url': u'git@gitlab.com:jsmith/somerepo.git',
- 'commit_info': {
- 'url': u'https://gitlab.com/jsmith/somerepo/commit/fb88379ee45de28a0a4590fddcbd8eff8b36026e',
- 'date': u'2015-08-13T19:33:18+00:00',
- 'message': u'Fix link\n',
- },
- }
+ expected = {
+ "commit": u"fb88379ee45de28a0a4590fddcbd8eff8b36026e",
+ "ref": u"refs/heads/master",
+ "git_url": u"git@gitlab.com:jsmith/somerepo.git",
+ "commit_info": {
+ "url": u"https://gitlab.com/jsmith/somerepo/commit/fb88379ee45de28a0a4590fddcbd8eff8b36026e",
+ "date": u"2015-08-13T19:33:18+00:00",
+ "message": u"Fix link\n",
+ },
+ }
- assertSchema('gitlab_webhook', expected, custom_trigger_payload, git_url='git@gitlab.com:jsmith/somerepo.git')
+ assertSchema(
+ "gitlab_webhook",
+ expected,
+ custom_trigger_payload,
+ git_url="git@gitlab.com:jsmith/somerepo.git",
+ )
def test_custom_github():
- expected = {
- 'commit': u'410f4cdf8ff09b87f245b13845e8497f90b90a4c',
- 'ref': u'refs/heads/master',
- 'default_branch': u'master',
- 'git_url': u'git@github.com:jsmith/anothertest.git',
- 'commit_info': {
- 'url': u'https://github.com/jsmith/anothertest/commit/410f4cdf8ff09b87f245b13845e8497f90b90a4c',
- 'date': u'2015-09-11T14:26:16-04:00',
- 'message': u'Update Dockerfile',
- 'committer': {
- 'username': u'jsmith',
- },
- 'author': {
- 'username': u'jsmith',
- },
- },
- }
+ expected = {
+ "commit": u"410f4cdf8ff09b87f245b13845e8497f90b90a4c",
+ "ref": u"refs/heads/master",
+ "default_branch": u"master",
+ "git_url": u"git@github.com:jsmith/anothertest.git",
+ "commit_info": {
+ "url": u"https://github.com/jsmith/anothertest/commit/410f4cdf8ff09b87f245b13845e8497f90b90a4c",
+ "date": u"2015-09-11T14:26:16-04:00",
+ "message": u"Update Dockerfile",
+ "committer": {"username": u"jsmith"},
+ "author": {"username": u"jsmith"},
+ },
+ }
- assertSchema('github_webhook', expected, custom_trigger_payload,
- git_url='git@github.com:jsmith/anothertest.git')
+ assertSchema(
+ "github_webhook",
+ expected,
+ custom_trigger_payload,
+ git_url="git@github.com:jsmith/anothertest.git",
+ )
def test_custom_bitbucket():
- expected = {
- "commit": u"af64ae7188685f8424040b4735ad12941b980d75",
- "ref": u"refs/heads/master",
- "git_url": u"git@bitbucket.org:jsmith/another-repo.git",
- "commit_info": {
- "url": u"https://bitbucket.org/jsmith/another-repo/commits/af64ae7188685f8424040b4735ad12941b980d75",
- "date": u"2015-09-10T20:40:54+00:00",
- "message": u"Dockerfile edited online with Bitbucket",
- "author": {
- "username": u"John Smith",
- "avatar_url": u"https://bitbucket.org/account/jsmith/avatar/32/",
- },
- "committer": {
- "username": u"John Smith",
- "avatar_url": u"https://bitbucket.org/account/jsmith/avatar/32/",
- },
- },
- }
+ expected = {
+ "commit": u"af64ae7188685f8424040b4735ad12941b980d75",
+ "ref": u"refs/heads/master",
+ "git_url": u"git@bitbucket.org:jsmith/another-repo.git",
+ "commit_info": {
+ "url": u"https://bitbucket.org/jsmith/another-repo/commits/af64ae7188685f8424040b4735ad12941b980d75",
+ "date": u"2015-09-10T20:40:54+00:00",
+ "message": u"Dockerfile edited online with Bitbucket",
+ "author": {
+ "username": u"John Smith",
+ "avatar_url": u"https://bitbucket.org/account/jsmith/avatar/32/",
+ },
+ "committer": {
+ "username": u"John Smith",
+ "avatar_url": u"https://bitbucket.org/account/jsmith/avatar/32/",
+ },
+ },
+ }
- assertSchema('bitbucket_webhook', expected, custom_trigger_payload, git_url='git@bitbucket.org:jsmith/another-repo.git')
+ assertSchema(
+ "bitbucket_webhook",
+ expected,
+ custom_trigger_payload,
+ git_url="git@bitbucket.org:jsmith/another-repo.git",
+ )
def test_bitbucket_customer_payload_noauthor():
- expected = {
- "commit": "a0ec139843b2bb281ab21a433266ddc498e605dc",
- "ref": "refs/heads/master",
- "git_url": "git@bitbucket.org:somecoollabs/svc-identity.git",
- "commit_info": {
- "url": "https://bitbucket.org/somecoollabs/svc-identity/commits/a0ec139843b2bb281ab21a433266ddc498e605dc",
- "date": "2015-09-25T00:55:08+00:00",
- "message": "Update version.py to 0.1.2 [skip ci]\n\n(by utilitybelt/scripts/autotag_version.py)\n",
- "committer": {
- "username": "CodeShip Tagging",
- "avatar_url": "https://bitbucket.org/account/SomeCoolLabs_CodeShip/avatar/32/",
- },
- },
- }
+ expected = {
+ "commit": "a0ec139843b2bb281ab21a433266ddc498e605dc",
+ "ref": "refs/heads/master",
+ "git_url": "git@bitbucket.org:somecoollabs/svc-identity.git",
+ "commit_info": {
+ "url": "https://bitbucket.org/somecoollabs/svc-identity/commits/a0ec139843b2bb281ab21a433266ddc498e605dc",
+ "date": "2015-09-25T00:55:08+00:00",
+ "message": "Update version.py to 0.1.2 [skip ci]\n\n(by utilitybelt/scripts/autotag_version.py)\n",
+ "committer": {
+ "username": "CodeShip Tagging",
+ "avatar_url": "https://bitbucket.org/account/SomeCoolLabs_CodeShip/avatar/32/",
+ },
+ },
+ }
- assertSchema('bitbucket_customer_example_noauthor', expected, bb_webhook)
+ assertSchema("bitbucket_customer_example_noauthor", expected, bb_webhook)
def test_bitbucket_customer_payload_tag():
- expected = {
- "commit": "a0ec139843b2bb281ab21a433266ddc498e605dc",
- "ref": "refs/tags/0.1.2",
- "git_url": "git@bitbucket.org:somecoollabs/svc-identity.git",
- "commit_info": {
- "url": "https://bitbucket.org/somecoollabs/svc-identity/commits/a0ec139843b2bb281ab21a433266ddc498e605dc",
- "date": "2015-09-25T00:55:08+00:00",
- "message": "Update version.py to 0.1.2 [skip ci]\n\n(by utilitybelt/scripts/autotag_version.py)\n",
- "committer": {
- "username": "CodeShip Tagging",
- "avatar_url": "https://bitbucket.org/account/SomeCoolLabs_CodeShip/avatar/32/",
- },
- },
- }
+ expected = {
+ "commit": "a0ec139843b2bb281ab21a433266ddc498e605dc",
+ "ref": "refs/tags/0.1.2",
+ "git_url": "git@bitbucket.org:somecoollabs/svc-identity.git",
+ "commit_info": {
+ "url": "https://bitbucket.org/somecoollabs/svc-identity/commits/a0ec139843b2bb281ab21a433266ddc498e605dc",
+ "date": "2015-09-25T00:55:08+00:00",
+ "message": "Update version.py to 0.1.2 [skip ci]\n\n(by utilitybelt/scripts/autotag_version.py)\n",
+ "committer": {
+ "username": "CodeShip Tagging",
+ "avatar_url": "https://bitbucket.org/account/SomeCoolLabs_CodeShip/avatar/32/",
+ },
+ },
+ }
- assertSchema('bitbucket_customer_example_tag', expected, bb_webhook)
+ assertSchema("bitbucket_customer_example_tag", expected, bb_webhook)
def test_bitbucket_commit():
- ref = 'refs/heads/somebranch'
- default_branch = 'somebranch'
- repository_name = 'foo/bar'
+ ref = "refs/heads/somebranch"
+ default_branch = "somebranch"
+ repository_name = "foo/bar"
- def lookup_author(_):
- return {
- 'user': {
- 'display_name': 'cooluser',
- 'avatar': 'http://some/avatar/url'
- }
+ def lookup_author(_):
+ return {
+ "user": {"display_name": "cooluser", "avatar": "http://some/avatar/url"}
+ }
+
+ expected = {
+ "commit": u"abdeaf1b2b4a6b9ddf742c1e1754236380435a62",
+ "ref": u"refs/heads/somebranch",
+ "git_url": u"git@bitbucket.org:foo/bar.git",
+ "default_branch": u"somebranch",
+ "commit_info": {
+ "url": u"https://bitbucket.org/foo/bar/commits/abdeaf1b2b4a6b9ddf742c1e1754236380435a62",
+ "date": u"2012-07-24 00:26:36",
+ "message": u"making some changes\n",
+ "author": {
+ "avatar_url": u"http://some/avatar/url",
+ "username": u"cooluser",
+ },
+ },
}
- expected = {
- "commit": u"abdeaf1b2b4a6b9ddf742c1e1754236380435a62",
- "ref": u"refs/heads/somebranch",
- "git_url": u"git@bitbucket.org:foo/bar.git",
- "default_branch": u"somebranch",
- "commit_info": {
- "url": u"https://bitbucket.org/foo/bar/commits/abdeaf1b2b4a6b9ddf742c1e1754236380435a62",
- "date": u"2012-07-24 00:26:36",
- "message": u"making some changes\n",
- "author": {
- "avatar_url": u"http://some/avatar/url",
- "username": u"cooluser",
- }
- }
- }
+ assertSchema(
+ "bitbucket_commit",
+ expected,
+ bb_commit,
+ ref,
+ default_branch,
+ repository_name,
+ lookup_author,
+ )
- assertSchema('bitbucket_commit', expected, bb_commit, ref, default_branch,
- repository_name, lookup_author)
def test_bitbucket_webhook_payload():
- expected = {
- "commit": u"af64ae7188685f8424040b4735ad12941b980d75",
- "ref": u"refs/heads/master",
- "git_url": u"git@bitbucket.org:jsmith/another-repo.git",
- "commit_info": {
- "url": u"https://bitbucket.org/jsmith/another-repo/commits/af64ae7188685f8424040b4735ad12941b980d75",
- "date": u"2015-09-10T20:40:54+00:00",
- "message": u"Dockerfile edited online with Bitbucket",
- "author": {
- "username": u"John Smith",
- "avatar_url": u"https://bitbucket.org/account/jsmith/avatar/32/",
- },
- "committer": {
- "username": u"John Smith",
- "avatar_url": u"https://bitbucket.org/account/jsmith/avatar/32/",
- },
- },
- }
+ expected = {
+ "commit": u"af64ae7188685f8424040b4735ad12941b980d75",
+ "ref": u"refs/heads/master",
+ "git_url": u"git@bitbucket.org:jsmith/another-repo.git",
+ "commit_info": {
+ "url": u"https://bitbucket.org/jsmith/another-repo/commits/af64ae7188685f8424040b4735ad12941b980d75",
+ "date": u"2015-09-10T20:40:54+00:00",
+ "message": u"Dockerfile edited online with Bitbucket",
+ "author": {
+ "username": u"John Smith",
+ "avatar_url": u"https://bitbucket.org/account/jsmith/avatar/32/",
+ },
+ "committer": {
+ "username": u"John Smith",
+ "avatar_url": u"https://bitbucket.org/account/jsmith/avatar/32/",
+ },
+ },
+ }
- assertSchema('bitbucket_webhook', expected, bb_webhook)
+ assertSchema("bitbucket_webhook", expected, bb_webhook)
def test_github_webhook_payload_slash_branch():
- expected = {
- 'commit': u'410f4cdf8ff09b87f245b13845e8497f90b90a4c',
- 'ref': u'refs/heads/slash/branch',
- 'default_branch': u'master',
- 'git_url': u'git@github.com:jsmith/anothertest.git',
- 'commit_info': {
- 'url': u'https://github.com/jsmith/anothertest/commit/410f4cdf8ff09b87f245b13845e8497f90b90a4c',
- 'date': u'2015-09-11T14:26:16-04:00',
- 'message': u'Update Dockerfile',
- 'committer': {
- 'username': u'jsmith',
- },
- 'author': {
- 'username': u'jsmith',
- },
- },
- }
+ expected = {
+ "commit": u"410f4cdf8ff09b87f245b13845e8497f90b90a4c",
+ "ref": u"refs/heads/slash/branch",
+ "default_branch": u"master",
+ "git_url": u"git@github.com:jsmith/anothertest.git",
+ "commit_info": {
+ "url": u"https://github.com/jsmith/anothertest/commit/410f4cdf8ff09b87f245b13845e8497f90b90a4c",
+ "date": u"2015-09-11T14:26:16-04:00",
+ "message": u"Update Dockerfile",
+ "committer": {"username": u"jsmith"},
+ "author": {"username": u"jsmith"},
+ },
+ }
- assertSchema('github_webhook_slash_branch', expected, gh_webhook)
+ assertSchema("github_webhook_slash_branch", expected, gh_webhook)
def test_github_webhook_payload():
- expected = {
- 'commit': u'410f4cdf8ff09b87f245b13845e8497f90b90a4c',
- 'ref': u'refs/heads/master',
- 'default_branch': u'master',
- 'git_url': u'git@github.com:jsmith/anothertest.git',
- 'commit_info': {
- 'url': u'https://github.com/jsmith/anothertest/commit/410f4cdf8ff09b87f245b13845e8497f90b90a4c',
- 'date': u'2015-09-11T14:26:16-04:00',
- 'message': u'Update Dockerfile',
- 'committer': {
- 'username': u'jsmith',
- },
- 'author': {
- 'username': u'jsmith',
- },
- },
- }
+ expected = {
+ "commit": u"410f4cdf8ff09b87f245b13845e8497f90b90a4c",
+ "ref": u"refs/heads/master",
+ "default_branch": u"master",
+ "git_url": u"git@github.com:jsmith/anothertest.git",
+ "commit_info": {
+ "url": u"https://github.com/jsmith/anothertest/commit/410f4cdf8ff09b87f245b13845e8497f90b90a4c",
+ "date": u"2015-09-11T14:26:16-04:00",
+ "message": u"Update Dockerfile",
+ "committer": {"username": u"jsmith"},
+ "author": {"username": u"jsmith"},
+ },
+ }
- assertSchema('github_webhook', expected, gh_webhook)
+ assertSchema("github_webhook", expected, gh_webhook)
def test_github_webhook_payload_with_lookup():
- expected = {
- 'commit': u'410f4cdf8ff09b87f245b13845e8497f90b90a4c',
- 'ref': u'refs/heads/master',
- 'default_branch': u'master',
- 'git_url': u'git@github.com:jsmith/anothertest.git',
- 'commit_info': {
- 'url': u'https://github.com/jsmith/anothertest/commit/410f4cdf8ff09b87f245b13845e8497f90b90a4c',
- 'date': u'2015-09-11T14:26:16-04:00',
- 'message': u'Update Dockerfile',
- 'committer': {
- 'username': u'jsmith',
- 'url': u'http://github.com/jsmith',
- 'avatar_url': u'http://some/avatar/url',
- },
- 'author': {
- 'username': u'jsmith',
- 'url': u'http://github.com/jsmith',
- 'avatar_url': u'http://some/avatar/url',
- },
- },
- }
-
- def lookup_user(_):
- return {
- 'html_url': 'http://github.com/jsmith',
- 'avatar_url': 'http://some/avatar/url'
+ expected = {
+ "commit": u"410f4cdf8ff09b87f245b13845e8497f90b90a4c",
+ "ref": u"refs/heads/master",
+ "default_branch": u"master",
+ "git_url": u"git@github.com:jsmith/anothertest.git",
+ "commit_info": {
+ "url": u"https://github.com/jsmith/anothertest/commit/410f4cdf8ff09b87f245b13845e8497f90b90a4c",
+ "date": u"2015-09-11T14:26:16-04:00",
+ "message": u"Update Dockerfile",
+ "committer": {
+ "username": u"jsmith",
+ "url": u"http://github.com/jsmith",
+ "avatar_url": u"http://some/avatar/url",
+ },
+ "author": {
+ "username": u"jsmith",
+ "url": u"http://github.com/jsmith",
+ "avatar_url": u"http://some/avatar/url",
+ },
+ },
}
- assertSchema('github_webhook', expected, gh_webhook, lookup_user=lookup_user)
+ def lookup_user(_):
+ return {
+ "html_url": "http://github.com/jsmith",
+ "avatar_url": "http://some/avatar/url",
+ }
+
+ assertSchema("github_webhook", expected, gh_webhook, lookup_user=lookup_user)
def test_github_webhook_payload_missing_fields_with_lookup():
- expected = {
- 'commit': u'410f4cdf8ff09b87f245b13845e8497f90b90a4c',
- 'ref': u'refs/heads/master',
- 'default_branch': u'master',
- 'git_url': u'git@github.com:jsmith/anothertest.git',
- 'commit_info': {
- 'url': u'https://github.com/jsmith/anothertest/commit/410f4cdf8ff09b87f245b13845e8497f90b90a4c',
- 'date': u'2015-09-11T14:26:16-04:00',
- 'message': u'Update Dockerfile'
- },
- }
-
- def lookup_user(username):
- if not username:
- raise Exception('Fail!')
-
- return {
- 'html_url': 'http://github.com/jsmith',
- 'avatar_url': 'http://some/avatar/url'
+ expected = {
+ "commit": u"410f4cdf8ff09b87f245b13845e8497f90b90a4c",
+ "ref": u"refs/heads/master",
+ "default_branch": u"master",
+ "git_url": u"git@github.com:jsmith/anothertest.git",
+ "commit_info": {
+ "url": u"https://github.com/jsmith/anothertest/commit/410f4cdf8ff09b87f245b13845e8497f90b90a4c",
+ "date": u"2015-09-11T14:26:16-04:00",
+ "message": u"Update Dockerfile",
+ },
}
- assertSchema('github_webhook_missing', expected, gh_webhook, lookup_user=lookup_user)
+ def lookup_user(username):
+ if not username:
+ raise Exception("Fail!")
+
+ return {
+ "html_url": "http://github.com/jsmith",
+ "avatar_url": "http://some/avatar/url",
+ }
+
+ assertSchema(
+ "github_webhook_missing", expected, gh_webhook, lookup_user=lookup_user
+ )
def test_gitlab_webhook_payload():
- expected = {
- 'commit': u'fb88379ee45de28a0a4590fddcbd8eff8b36026e',
- 'ref': u'refs/heads/master',
- 'git_url': u'git@gitlab.com:jsmith/somerepo.git',
- 'commit_info': {
- 'url': u'https://gitlab.com/jsmith/somerepo/commit/fb88379ee45de28a0a4590fddcbd8eff8b36026e',
- 'date': u'2015-08-13T19:33:18+00:00',
- 'message': u'Fix link\n',
- },
- }
+ expected = {
+ "commit": u"fb88379ee45de28a0a4590fddcbd8eff8b36026e",
+ "ref": u"refs/heads/master",
+ "git_url": u"git@gitlab.com:jsmith/somerepo.git",
+ "commit_info": {
+ "url": u"https://gitlab.com/jsmith/somerepo/commit/fb88379ee45de28a0a4590fddcbd8eff8b36026e",
+ "date": u"2015-08-13T19:33:18+00:00",
+ "message": u"Fix link\n",
+ },
+ }
- assertSchema('gitlab_webhook', expected, gl_webhook)
+ assertSchema("gitlab_webhook", expected, gl_webhook)
def test_github_webhook_payload_known_issue():
- expected = {
- "commit": "118b07121695d9f2e40a5ff264fdcc2917680870",
- "ref": "refs/heads/master",
- "default_branch": "master",
- "git_url": "git@github.com:jsmith/docker-test.git",
- "commit_info": {
- "url": "https://github.com/jsmith/docker-test/commit/118b07121695d9f2e40a5ff264fdcc2917680870",
- "date": "2015-09-25T14:55:11-04:00",
- "message": "Fail",
- },
- }
+ expected = {
+ "commit": "118b07121695d9f2e40a5ff264fdcc2917680870",
+ "ref": "refs/heads/master",
+ "default_branch": "master",
+ "git_url": "git@github.com:jsmith/docker-test.git",
+ "commit_info": {
+ "url": "https://github.com/jsmith/docker-test/commit/118b07121695d9f2e40a5ff264fdcc2917680870",
+ "date": "2015-09-25T14:55:11-04:00",
+ "message": "Fail",
+ },
+ }
- assertSchema('github_webhook_noname', expected, gh_webhook)
+ assertSchema("github_webhook_noname", expected, gh_webhook)
def test_github_webhook_payload_missing_fields():
- expected = {
- 'commit': u'410f4cdf8ff09b87f245b13845e8497f90b90a4c',
- 'ref': u'refs/heads/master',
- 'default_branch': u'master',
- 'git_url': u'git@github.com:jsmith/anothertest.git',
- 'commit_info': {
- 'url': u'https://github.com/jsmith/anothertest/commit/410f4cdf8ff09b87f245b13845e8497f90b90a4c',
- 'date': u'2015-09-11T14:26:16-04:00',
- 'message': u'Update Dockerfile'
- },
- }
+ expected = {
+ "commit": u"410f4cdf8ff09b87f245b13845e8497f90b90a4c",
+ "ref": u"refs/heads/master",
+ "default_branch": u"master",
+ "git_url": u"git@github.com:jsmith/anothertest.git",
+ "commit_info": {
+ "url": u"https://github.com/jsmith/anothertest/commit/410f4cdf8ff09b87f245b13845e8497f90b90a4c",
+ "date": u"2015-09-11T14:26:16-04:00",
+ "message": u"Update Dockerfile",
+ },
+ }
- assertSchema('github_webhook_missing', expected, gh_webhook)
+ assertSchema("github_webhook_missing", expected, gh_webhook)
def test_gitlab_webhook_nocommit_payload():
- assertSkipped('gitlab_webhook_nocommit', gl_webhook)
+ assertSkipped("gitlab_webhook_nocommit", gl_webhook)
def test_gitlab_webhook_multiple_commits():
- expected = {
- 'commit': u'9a052a0b2fbe01d4a1a88638dd9fe31c1c56ef53',
- 'ref': u'refs/heads/master',
- 'git_url': u'git@gitlab.com:jsmith/some-test-project.git',
- 'commit_info': {
- 'url': u'https://gitlab.com/jsmith/some-test-project/commit/9a052a0b2fbe01d4a1a88638dd9fe31c1c56ef53',
- 'date': u'2016-09-29T15:02:41+00:00',
- 'message': u"Merge branch 'foobar' into 'master'\r\n\r\nAdd changelog\r\n\r\nSome merge thing\r\n\r\nSee merge request !1",
- 'author': {
- 'username': 'jsmith',
- 'url': 'http://gitlab.com/jsmith',
- 'avatar_url': 'http://some/avatar/url'
- },
- },
- }
-
- def lookup_user(_):
- return {
- 'username': 'jsmith',
- 'html_url': 'http://gitlab.com/jsmith',
- 'avatar_url': 'http://some/avatar/url',
+ expected = {
+ "commit": u"9a052a0b2fbe01d4a1a88638dd9fe31c1c56ef53",
+ "ref": u"refs/heads/master",
+ "git_url": u"git@gitlab.com:jsmith/some-test-project.git",
+ "commit_info": {
+ "url": u"https://gitlab.com/jsmith/some-test-project/commit/9a052a0b2fbe01d4a1a88638dd9fe31c1c56ef53",
+ "date": u"2016-09-29T15:02:41+00:00",
+ "message": u"Merge branch 'foobar' into 'master'\r\n\r\nAdd changelog\r\n\r\nSome merge thing\r\n\r\nSee merge request !1",
+ "author": {
+ "username": "jsmith",
+ "url": "http://gitlab.com/jsmith",
+ "avatar_url": "http://some/avatar/url",
+ },
+ },
}
- assertSchema('gitlab_webhook_multicommit', expected, gl_webhook, lookup_user=lookup_user)
+ def lookup_user(_):
+ return {
+ "username": "jsmith",
+ "html_url": "http://gitlab.com/jsmith",
+ "avatar_url": "http://some/avatar/url",
+ }
+
+ assertSchema(
+ "gitlab_webhook_multicommit", expected, gl_webhook, lookup_user=lookup_user
+ )
def test_gitlab_webhook_for_tag():
- expected = {
- 'commit': u'82b3d5ae55f7080f1e6022629cdb57bfae7cccc7',
- 'commit_info': {
- 'author': {
- 'avatar_url': 'http://some/avatar/url',
- 'url': 'http://gitlab.com/jsmith',
- 'username': 'jsmith'
- },
- 'date': '2015-08-13T19:33:18+00:00',
- 'message': 'Fix link\n',
- 'url': 'https://some/url',
- },
- 'git_url': u'git@example.com:jsmith/example.git',
- 'ref': u'refs/tags/v1.0.0',
- }
-
- def lookup_user(_):
- return {
- 'username': 'jsmith',
- 'html_url': 'http://gitlab.com/jsmith',
- 'avatar_url': 'http://some/avatar/url',
+ expected = {
+ "commit": u"82b3d5ae55f7080f1e6022629cdb57bfae7cccc7",
+ "commit_info": {
+ "author": {
+ "avatar_url": "http://some/avatar/url",
+ "url": "http://gitlab.com/jsmith",
+ "username": "jsmith",
+ },
+ "date": "2015-08-13T19:33:18+00:00",
+ "message": "Fix link\n",
+ "url": "https://some/url",
+ },
+ "git_url": u"git@example.com:jsmith/example.git",
+ "ref": u"refs/tags/v1.0.0",
}
- def lookup_commit(repo_id, commit_sha):
- if commit_sha == '82b3d5ae55f7080f1e6022629cdb57bfae7cccc7':
- return {
- "id": "82b3d5ae55f7080f1e6022629cdb57bfae7cccc7",
- "message": "Fix link\n",
- "timestamp": "2015-08-13T19:33:18+00:00",
- "url": "https://some/url",
- "author_name": "Foo Guy",
- "author_email": "foo@bar.com",
- }
+ def lookup_user(_):
+ return {
+ "username": "jsmith",
+ "html_url": "http://gitlab.com/jsmith",
+ "avatar_url": "http://some/avatar/url",
+ }
- return None
+ def lookup_commit(repo_id, commit_sha):
+ if commit_sha == "82b3d5ae55f7080f1e6022629cdb57bfae7cccc7":
+ return {
+ "id": "82b3d5ae55f7080f1e6022629cdb57bfae7cccc7",
+ "message": "Fix link\n",
+ "timestamp": "2015-08-13T19:33:18+00:00",
+ "url": "https://some/url",
+ "author_name": "Foo Guy",
+ "author_email": "foo@bar.com",
+ }
- assertSchema('gitlab_webhook_tag', expected, gl_webhook, lookup_user=lookup_user,
- lookup_commit=lookup_commit)
+ return None
+
+ assertSchema(
+ "gitlab_webhook_tag",
+ expected,
+ gl_webhook,
+ lookup_user=lookup_user,
+ lookup_commit=lookup_commit,
+ )
def test_gitlab_webhook_for_tag_nocommit():
- assertSkipped('gitlab_webhook_tag', gl_webhook)
+ assertSkipped("gitlab_webhook_tag", gl_webhook)
def test_gitlab_webhook_for_tag_commit_sha_null():
- assertSkipped('gitlab_webhook_tag_commit_sha_null', gl_webhook)
+ assertSkipped("gitlab_webhook_tag_commit_sha_null", gl_webhook)
def test_gitlab_webhook_for_tag_known_issue():
- expected = {
- 'commit': u'770830e7ca132856991e6db4f7fc0f4dbe20bd5f',
- 'ref': u'refs/tags/thirdtag',
- 'git_url': u'git@gitlab.com:someuser/some-test-project.git',
- 'commit_info': {
- 'url': u'https://gitlab.com/someuser/some-test-project/commit/770830e7ca132856991e6db4f7fc0f4dbe20bd5f',
- 'date': u'2019-10-17T18:07:48Z',
- 'message': u'Update Dockerfile',
- 'author': {
- 'username': 'someuser',
- 'url': 'http://gitlab.com/someuser',
- 'avatar_url': 'http://some/avatar/url',
- },
- },
- }
-
- def lookup_user(_):
- return {
- 'username': 'someuser',
- 'html_url': 'http://gitlab.com/someuser',
- 'avatar_url': 'http://some/avatar/url',
+ expected = {
+ "commit": u"770830e7ca132856991e6db4f7fc0f4dbe20bd5f",
+ "ref": u"refs/tags/thirdtag",
+ "git_url": u"git@gitlab.com:someuser/some-test-project.git",
+ "commit_info": {
+ "url": u"https://gitlab.com/someuser/some-test-project/commit/770830e7ca132856991e6db4f7fc0f4dbe20bd5f",
+ "date": u"2019-10-17T18:07:48Z",
+ "message": u"Update Dockerfile",
+ "author": {
+ "username": "someuser",
+ "url": "http://gitlab.com/someuser",
+ "avatar_url": "http://some/avatar/url",
+ },
+ },
}
- assertSchema('gitlab_webhook_tag_commit_issue', expected, gl_webhook, lookup_user=lookup_user)
+ def lookup_user(_):
+ return {
+ "username": "someuser",
+ "html_url": "http://gitlab.com/someuser",
+ "avatar_url": "http://some/avatar/url",
+ }
+
+ assertSchema(
+ "gitlab_webhook_tag_commit_issue", expected, gl_webhook, lookup_user=lookup_user
+ )
def test_gitlab_webhook_payload_known_issue():
- expected = {
- 'commit': u'770830e7ca132856991e6db4f7fc0f4dbe20bd5f',
- 'ref': u'refs/tags/fourthtag',
- 'git_url': u'git@gitlab.com:someuser/some-test-project.git',
- 'commit_info': {
- 'url': u'https://gitlab.com/someuser/some-test-project/commit/770830e7ca132856991e6db4f7fc0f4dbe20bd5f',
- 'date': u'2019-10-17T18:07:48Z',
- 'message': u'Update Dockerfile',
- },
- }
+ expected = {
+ "commit": u"770830e7ca132856991e6db4f7fc0f4dbe20bd5f",
+ "ref": u"refs/tags/fourthtag",
+ "git_url": u"git@gitlab.com:someuser/some-test-project.git",
+ "commit_info": {
+ "url": u"https://gitlab.com/someuser/some-test-project/commit/770830e7ca132856991e6db4f7fc0f4dbe20bd5f",
+ "date": u"2019-10-17T18:07:48Z",
+ "message": u"Update Dockerfile",
+ },
+ }
- def lookup_commit(repo_id, commit_sha):
- if commit_sha == '770830e7ca132856991e6db4f7fc0f4dbe20bd5f':
- return {
- "added": [],
- "author": {
- "name": "Some User",
- "email": "someuser@somedomain.com"
- },
- "url": "https://gitlab.com/someuser/some-test-project/commit/770830e7ca132856991e6db4f7fc0f4dbe20bd5f",
- "message": "Update Dockerfile",
- "removed": [],
- "modified": [
- "Dockerfile"
- ],
- "id": "770830e7ca132856991e6db4f7fc0f4dbe20bd5f"
- }
+ def lookup_commit(repo_id, commit_sha):
+ if commit_sha == "770830e7ca132856991e6db4f7fc0f4dbe20bd5f":
+ return {
+ "added": [],
+ "author": {"name": "Some User", "email": "someuser@somedomain.com"},
+ "url": "https://gitlab.com/someuser/some-test-project/commit/770830e7ca132856991e6db4f7fc0f4dbe20bd5f",
+ "message": "Update Dockerfile",
+ "removed": [],
+ "modified": ["Dockerfile"],
+ "id": "770830e7ca132856991e6db4f7fc0f4dbe20bd5f",
+ }
- return None
+ return None
- assertSchema('gitlab_webhook_known_issue', expected, gl_webhook, lookup_commit=lookup_commit)
+ assertSchema(
+ "gitlab_webhook_known_issue", expected, gl_webhook, lookup_commit=lookup_commit
+ )
def test_gitlab_webhook_for_other():
- assertSkipped('gitlab_webhook_other', gl_webhook)
+ assertSkipped("gitlab_webhook_other", gl_webhook)
def test_gitlab_webhook_payload_with_lookup():
- expected = {
- 'commit': u'fb88379ee45de28a0a4590fddcbd8eff8b36026e',
- 'ref': u'refs/heads/master',
- 'git_url': u'git@gitlab.com:jsmith/somerepo.git',
- 'commit_info': {
- 'url': u'https://gitlab.com/jsmith/somerepo/commit/fb88379ee45de28a0a4590fddcbd8eff8b36026e',
- 'date': u'2015-08-13T19:33:18+00:00',
- 'message': u'Fix link\n',
- 'author': {
- 'username': 'jsmith',
- 'url': 'http://gitlab.com/jsmith',
- 'avatar_url': 'http://some/avatar/url',
- },
- },
- }
-
- def lookup_user(_):
- return {
- 'username': 'jsmith',
- 'html_url': 'http://gitlab.com/jsmith',
- 'avatar_url': 'http://some/avatar/url',
+ expected = {
+ "commit": u"fb88379ee45de28a0a4590fddcbd8eff8b36026e",
+ "ref": u"refs/heads/master",
+ "git_url": u"git@gitlab.com:jsmith/somerepo.git",
+ "commit_info": {
+ "url": u"https://gitlab.com/jsmith/somerepo/commit/fb88379ee45de28a0a4590fddcbd8eff8b36026e",
+ "date": u"2015-08-13T19:33:18+00:00",
+ "message": u"Fix link\n",
+ "author": {
+ "username": "jsmith",
+ "url": "http://gitlab.com/jsmith",
+ "avatar_url": "http://some/avatar/url",
+ },
+ },
}
- assertSchema('gitlab_webhook', expected, gl_webhook, lookup_user=lookup_user)
+ def lookup_user(_):
+ return {
+ "username": "jsmith",
+ "html_url": "http://gitlab.com/jsmith",
+ "avatar_url": "http://some/avatar/url",
+ }
+
+ assertSchema("gitlab_webhook", expected, gl_webhook, lookup_user=lookup_user)
def test_github_webhook_payload_deleted_commit():
- expected = {
- 'commit': u'456806b662cb903a0febbaed8344f3ed42f27bab',
- 'commit_info': {
- 'author': {
- 'username': u'jsmith'
- },
- 'committer': {
- 'username': u'jsmith'
- },
- 'date': u'2015-12-08T18:07:03-05:00',
- 'message': (u'Merge pull request #1044 from jsmith/errerror\n\n' +
- 'Assign the exception to a variable to log it'),
- 'url': u'https://github.com/jsmith/somerepo/commit/456806b662cb903a0febbaed8344f3ed42f27bab'
- },
- 'git_url': u'git@github.com:jsmith/somerepo.git',
- 'ref': u'refs/heads/master',
- 'default_branch': u'master',
- }
+ expected = {
+ "commit": u"456806b662cb903a0febbaed8344f3ed42f27bab",
+ "commit_info": {
+ "author": {"username": u"jsmith"},
+ "committer": {"username": u"jsmith"},
+ "date": u"2015-12-08T18:07:03-05:00",
+ "message": (
+ u"Merge pull request #1044 from jsmith/errerror\n\n"
+ + "Assign the exception to a variable to log it"
+ ),
+ "url": u"https://github.com/jsmith/somerepo/commit/456806b662cb903a0febbaed8344f3ed42f27bab",
+ },
+ "git_url": u"git@github.com:jsmith/somerepo.git",
+ "ref": u"refs/heads/master",
+ "default_branch": u"master",
+ }
- def lookup_user(_):
- return None
+ def lookup_user(_):
+ return None
- assertSchema('github_webhook_deletedcommit', expected, gh_webhook, lookup_user=lookup_user)
+ assertSchema(
+ "github_webhook_deletedcommit", expected, gh_webhook, lookup_user=lookup_user
+ )
def test_github_webhook_known_issue():
- def lookup_user(_):
- return None
+ def lookup_user(_):
+ return None
- assertSkipped('github_webhook_knownissue', gh_webhook, lookup_user=lookup_user)
+ assertSkipped("github_webhook_knownissue", gh_webhook, lookup_user=lookup_user)
def test_bitbucket_webhook_known_issue():
- assertSkipped('bitbucket_knownissue', bb_webhook)
+ assertSkipped("bitbucket_knownissue", bb_webhook)
diff --git a/buildtrigger/test/test_triggerutil.py b/buildtrigger/test/test_triggerutil.py
index 15f1bec10..6a1b6ce28 100644
--- a/buildtrigger/test/test_triggerutil.py
+++ b/buildtrigger/test/test_triggerutil.py
@@ -4,22 +4,43 @@ import pytest
from buildtrigger.triggerutil import matches_ref
-@pytest.mark.parametrize('ref, filt, matches', [
- ('ref/heads/master', '.+', True),
- ('ref/heads/master', 'heads/.+', True),
- ('ref/heads/master', 'heads/master', True),
- ('ref/heads/slash/branch', 'heads/slash/branch', True),
- ('ref/heads/slash/branch', 'heads/.+', True),
- ('ref/heads/foobar', 'heads/master', False),
- ('ref/heads/master', 'tags/master', False),
-
- ('ref/heads/master', '(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)', True),
- ('ref/heads/alpha', '(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)', True),
- ('ref/heads/beta', '(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)', True),
- ('ref/heads/gamma', '(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)', True),
-
- ('ref/heads/delta', '(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)', False),
-])
+@pytest.mark.parametrize(
+ "ref, filt, matches",
+ [
+ ("ref/heads/master", ".+", True),
+ ("ref/heads/master", "heads/.+", True),
+ ("ref/heads/master", "heads/master", True),
+ ("ref/heads/slash/branch", "heads/slash/branch", True),
+ ("ref/heads/slash/branch", "heads/.+", True),
+ ("ref/heads/foobar", "heads/master", False),
+ ("ref/heads/master", "tags/master", False),
+ (
+ "ref/heads/master",
+ "(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)",
+ True,
+ ),
+ (
+ "ref/heads/alpha",
+ "(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)",
+ True,
+ ),
+ (
+ "ref/heads/beta",
+ "(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)",
+ True,
+ ),
+ (
+ "ref/heads/gamma",
+ "(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)",
+ True,
+ ),
+ (
+ "ref/heads/delta",
+ "(((heads/alpha)|(heads/beta))|(heads/gamma))|(heads/master)",
+ False,
+ ),
+ ],
+)
def test_matches_ref(ref, filt, matches):
- assert matches_ref(ref, re.compile(filt)) == matches
+ assert matches_ref(ref, re.compile(filt)) == matches
diff --git a/buildtrigger/triggerutil.py b/buildtrigger/triggerutil.py
index 5c459e53e..c24effb3f 100644
--- a/buildtrigger/triggerutil.py
+++ b/buildtrigger/triggerutil.py
@@ -3,128 +3,146 @@ import io
import logging
import re
+
class TriggerException(Exception):
- pass
+ pass
+
class TriggerAuthException(TriggerException):
- pass
+ pass
+
class InvalidPayloadException(TriggerException):
- pass
+ pass
+
class BuildArchiveException(TriggerException):
- pass
+ pass
+
class InvalidServiceException(TriggerException):
- pass
+ pass
+
class TriggerActivationException(TriggerException):
- pass
+ pass
+
class TriggerDeactivationException(TriggerException):
- pass
+ pass
+
class TriggerStartException(TriggerException):
- pass
+ pass
+
class ValidationRequestException(TriggerException):
- pass
+ pass
+
class SkipRequestException(TriggerException):
- pass
+ pass
+
class EmptyRepositoryException(TriggerException):
- pass
+ pass
+
class RepositoryReadException(TriggerException):
- pass
+ pass
+
class TriggerProviderException(TriggerException):
- pass
+ pass
+
logger = logging.getLogger(__name__)
+
def determine_build_ref(run_parameters, get_branch_sha, get_tag_sha, default_branch):
- run_parameters = run_parameters or {}
+ run_parameters = run_parameters or {}
- kind = ''
- value = ''
+ kind = ""
+ value = ""
- if 'refs' in run_parameters and run_parameters['refs']:
- kind = run_parameters['refs']['kind']
- value = run_parameters['refs']['name']
- elif 'branch_name' in run_parameters:
- kind = 'branch'
- value = run_parameters['branch_name']
+ if "refs" in run_parameters and run_parameters["refs"]:
+ kind = run_parameters["refs"]["kind"]
+ value = run_parameters["refs"]["name"]
+ elif "branch_name" in run_parameters:
+ kind = "branch"
+ value = run_parameters["branch_name"]
- kind = kind or 'branch'
- value = value or default_branch or 'master'
+ kind = kind or "branch"
+ value = value or default_branch or "master"
- ref = 'refs/tags/' + value if kind == 'tag' else 'refs/heads/' + value
- commit_sha = get_tag_sha(value) if kind == 'tag' else get_branch_sha(value)
- return (commit_sha, ref)
+ ref = "refs/tags/" + value if kind == "tag" else "refs/heads/" + value
+ commit_sha = get_tag_sha(value) if kind == "tag" else get_branch_sha(value)
+ return (commit_sha, ref)
def find_matching_branches(config, branches):
- if 'branchtag_regex' in config:
- try:
- regex = re.compile(config['branchtag_regex'])
- return [branch for branch in branches
- if matches_ref('refs/heads/' + branch, regex)]
- except:
- pass
+ if "branchtag_regex" in config:
+ try:
+ regex = re.compile(config["branchtag_regex"])
+ return [
+ branch
+ for branch in branches
+ if matches_ref("refs/heads/" + branch, regex)
+ ]
+ except:
+ pass
- return branches
+ return branches
def should_skip_commit(metadata):
- if 'commit_info' in metadata:
- message = metadata['commit_info']['message']
- return '[skip build]' in message or '[build skip]' in message
- return False
+ if "commit_info" in metadata:
+ message = metadata["commit_info"]["message"]
+ return "[skip build]" in message or "[build skip]" in message
+ return False
def raise_if_skipped_build(prepared_build, config):
- """ Raises a SkipRequestException if the given build should be skipped. """
- # Check to ensure we have metadata.
- if not prepared_build.metadata:
- logger.debug('Skipping request due to missing metadata for prepared build')
- raise SkipRequestException()
+ """ Raises a SkipRequestException if the given build should be skipped. """
+ # Check to ensure we have metadata.
+ if not prepared_build.metadata:
+ logger.debug("Skipping request due to missing metadata for prepared build")
+ raise SkipRequestException()
- # Check the branchtag regex.
- if 'branchtag_regex' in config:
- try:
- regex = re.compile(config['branchtag_regex'])
- except:
- regex = re.compile('.*')
+ # Check the branchtag regex.
+ if "branchtag_regex" in config:
+ try:
+ regex = re.compile(config["branchtag_regex"])
+ except:
+ regex = re.compile(".*")
- if not matches_ref(prepared_build.metadata.get('ref'), regex):
- raise SkipRequestException()
+ if not matches_ref(prepared_build.metadata.get("ref"), regex):
+ raise SkipRequestException()
- # Check the commit message.
- if should_skip_commit(prepared_build.metadata):
- logger.debug('Skipping request due to commit message request')
- raise SkipRequestException()
+ # Check the commit message.
+ if should_skip_commit(prepared_build.metadata):
+ logger.debug("Skipping request due to commit message request")
+ raise SkipRequestException()
def matches_ref(ref, regex):
- match_string = ref.split('/', 1)[1]
- if not regex:
- return False
+ match_string = ref.split("/", 1)[1]
+ if not regex:
+ return False
- m = regex.match(match_string)
- if not m:
- return False
+ m = regex.match(match_string)
+ if not m:
+ return False
- return len(m.group(0)) == len(match_string)
+ return len(m.group(0)) == len(match_string)
def raise_unsupported():
- raise io.UnsupportedOperation
+ raise io.UnsupportedOperation
def get_trigger_config(trigger):
- try:
- return json.loads(trigger.config)
- except ValueError:
- return {}
+ try:
+ return json.loads(trigger.config)
+ except ValueError:
+ return {}
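For reference, matches_ref above drops the leading path component of the ref and then requires the filter regex to consume the whole remainder, which is why a prefix-only match is rejected. A minimal sketch of that behavior, exercising only the function shown in this diff (illustration, not part of the patch):

import re

from buildtrigger.triggerutil import matches_ref

# The regex must cover everything after the leading "refs/" component.
assert matches_ref("refs/heads/master", re.compile("heads/master"))

# A prefix match is not enough: "heads/master" does not cover "heads/master-old".
assert not matches_ref("refs/heads/master-old", re.compile("heads/master"))

# A missing filter never matches.
assert not matches_ref("refs/heads/master", None)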
diff --git a/conf/gunicorn_local.py b/conf/gunicorn_local.py
index b33558ef2..ab9afc0ec 100644
--- a/conf/gunicorn_local.py
+++ b/conf/gunicorn_local.py
@@ -1,5 +1,6 @@
import sys
import os
+
sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging
@@ -10,18 +11,24 @@ from util.workers import get_worker_count
logconfig = logfile_path(debug=True)
-bind = '0.0.0.0:5000'
-workers = get_worker_count('local', 2, minimum=2, maximum=8)
-worker_class = 'gevent'
+bind = "0.0.0.0:5000"
+workers = get_worker_count("local", 2, minimum=2, maximum=8)
+worker_class = "gevent"
daemon = False
-pythonpath = '.'
+pythonpath = "."
preload_app = True
+
def post_fork(server, worker):
- # Reset the Random library to ensure it won't raise the "PID check failed." error after
- # gunicorn forks.
- Random.atfork()
+ # Reset the Random library to ensure it won't raise the "PID check failed." error after
+ # gunicorn forks.
+ Random.atfork()
+
def when_ready(server):
- logger = logging.getLogger(__name__)
- logger.debug('Starting local gunicorn with %s workers and %s worker class', workers, worker_class)
+ logger = logging.getLogger(__name__)
+ logger.debug(
+ "Starting local gunicorn with %s workers and %s worker class",
+ workers,
+ worker_class,
+ )
diff --git a/conf/gunicorn_registry.py b/conf/gunicorn_registry.py
index 23590ba45..c072c740f 100644
--- a/conf/gunicorn_registry.py
+++ b/conf/gunicorn_registry.py
@@ -1,5 +1,6 @@
import sys
import os
+
sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging
@@ -10,19 +11,23 @@ from util.workers import get_worker_count
logconfig = logfile_path(debug=False)
-bind = 'unix:/tmp/gunicorn_registry.sock'
-workers = get_worker_count('registry', 4, minimum=8, maximum=64)
-worker_class = 'gevent'
-pythonpath = '.'
+bind = "unix:/tmp/gunicorn_registry.sock"
+workers = get_worker_count("registry", 4, minimum=8, maximum=64)
+worker_class = "gevent"
+pythonpath = "."
preload_app = True
def post_fork(server, worker):
- # Reset the Random library to ensure it won't raise the "PID check failed." error after
- # gunicorn forks.
- Random.atfork()
+ # Reset the Random library to ensure it won't raise the "PID check failed." error after
+ # gunicorn forks.
+ Random.atfork()
+
def when_ready(server):
- logger = logging.getLogger(__name__)
- logger.debug('Starting registry gunicorn with %s workers and %s worker class', workers,
- worker_class)
+ logger = logging.getLogger(__name__)
+ logger.debug(
+ "Starting registry gunicorn with %s workers and %s worker class",
+ workers,
+ worker_class,
+ )
diff --git a/conf/gunicorn_secscan.py b/conf/gunicorn_secscan.py
index daea39c38..788d79808 100644
--- a/conf/gunicorn_secscan.py
+++ b/conf/gunicorn_secscan.py
@@ -1,5 +1,6 @@
import sys
import os
+
sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging
@@ -10,19 +11,23 @@ from util.workers import get_worker_count
logconfig = logfile_path(debug=False)
-bind = 'unix:/tmp/gunicorn_secscan.sock'
-workers = get_worker_count('secscan', 2, minimum=2, maximum=4)
-worker_class = 'gevent'
-pythonpath = '.'
+bind = "unix:/tmp/gunicorn_secscan.sock"
+workers = get_worker_count("secscan", 2, minimum=2, maximum=4)
+worker_class = "gevent"
+pythonpath = "."
preload_app = True
def post_fork(server, worker):
- # Reset the Random library to ensure it won't raise the "PID check failed." error after
- # gunicorn forks.
- Random.atfork()
+ # Reset the Random library to ensure it won't raise the "PID check failed." error after
+ # gunicorn forks.
+ Random.atfork()
+
def when_ready(server):
- logger = logging.getLogger(__name__)
- logger.debug('Starting secscan gunicorn with %s workers and %s worker class', workers,
- worker_class)
+ logger = logging.getLogger(__name__)
+ logger.debug(
+ "Starting secscan gunicorn with %s workers and %s worker class",
+ workers,
+ worker_class,
+ )
diff --git a/conf/gunicorn_verbs.py b/conf/gunicorn_verbs.py
index 9502f7563..2e6482384 100644
--- a/conf/gunicorn_verbs.py
+++ b/conf/gunicorn_verbs.py
@@ -1,5 +1,6 @@
import sys
import os
+
sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging
@@ -10,18 +11,21 @@ from util.workers import get_worker_count
logconfig = logfile_path(debug=False)
-bind = 'unix:/tmp/gunicorn_verbs.sock'
-workers = get_worker_count('verbs', 2, minimum=2, maximum=32)
-pythonpath = '.'
+bind = "unix:/tmp/gunicorn_verbs.sock"
+workers = get_worker_count("verbs", 2, minimum=2, maximum=32)
+pythonpath = "."
preload_app = True
timeout = 2000 # Because sync workers
def post_fork(server, worker):
- # Reset the Random library to ensure it won't raise the "PID check failed." error after
- # gunicorn forks.
- Random.atfork()
+ # Reset the Random library to ensure it won't raise the "PID check failed." error after
+ # gunicorn forks.
+ Random.atfork()
+
def when_ready(server):
- logger = logging.getLogger(__name__)
- logger.debug('Starting verbs gunicorn with %s workers and sync worker class', workers)
+ logger = logging.getLogger(__name__)
+ logger.debug(
+ "Starting verbs gunicorn with %s workers and sync worker class", workers
+ )
diff --git a/conf/gunicorn_web.py b/conf/gunicorn_web.py
index 8bd1abaa0..2461861c2 100644
--- a/conf/gunicorn_web.py
+++ b/conf/gunicorn_web.py
@@ -1,5 +1,6 @@
import sys
import os
+
sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging
@@ -11,18 +12,23 @@ from util.workers import get_worker_count
logconfig = logfile_path(debug=False)
-bind = 'unix:/tmp/gunicorn_web.sock'
-workers = get_worker_count('web', 2, minimum=2, maximum=32)
-worker_class = 'gevent'
-pythonpath = '.'
+bind = "unix:/tmp/gunicorn_web.sock"
+workers = get_worker_count("web", 2, minimum=2, maximum=32)
+worker_class = "gevent"
+pythonpath = "."
preload_app = True
+
def post_fork(server, worker):
- # Reset the Random library to ensure it won't raise the "PID check failed." error after
- # gunicorn forks.
- Random.atfork()
+ # Reset the Random library to ensure it won't raise the "PID check failed." error after
+ # gunicorn forks.
+ Random.atfork()
+
def when_ready(server):
- logger = logging.getLogger(__name__)
- logger.debug('Starting web gunicorn with %s workers and %s worker class', workers,
- worker_class)
+ logger = logging.getLogger(__name__)
+ logger.debug(
+ "Starting web gunicorn with %s workers and %s worker class",
+ workers,
+ worker_class,
+ )
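All five gunicorn configs above install the same two server hooks: post_fork re-seeds the PRNG in each worker because preload_app creates shared state before the fork, and when_ready just logs the chosen worker setup. A minimal sketch of that pattern, assuming Random here is PyCrypto's Crypto.Random (its import sits above the hunks shown):

import logging

from Crypto import Random  # assumed import; not visible in the hunks above

preload_app = True  # the app (and its RNG state) is created once, before workers fork


def post_fork(server, worker):
    # Re-seed the PRNG per worker so the preloaded state is not shared across forks.
    Random.atfork()


def when_ready(server):
    logging.getLogger(__name__).debug("gunicorn master ready")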
diff --git a/conf/init/nginx_conf_create.py b/conf/init/nginx_conf_create.py
index 56a59a2d2..7b264a015 100644
--- a/conf/init/nginx_conf_create.py
+++ b/conf/init/nginx_conf_create.py
@@ -7,120 +7,130 @@ import jinja2
QUAYPATH = os.getenv("QUAYPATH", ".")
QUAYDIR = os.getenv("QUAYDIR", "/")
QUAYCONF_DIR = os.getenv("QUAYCONF", os.path.join(QUAYDIR, QUAYPATH, "conf"))
-STATIC_DIR = os.path.join(QUAYDIR, 'static')
+STATIC_DIR = os.path.join(QUAYDIR, "static")
-SSL_PROTOCOL_DEFAULTS = ['TLSv1', 'TLSv1.1', 'TLSv1.2']
+SSL_PROTOCOL_DEFAULTS = ["TLSv1", "TLSv1.1", "TLSv1.2"]
SSL_CIPHER_DEFAULTS = [
- 'ECDHE-RSA-AES128-GCM-SHA256',
- 'ECDHE-ECDSA-AES128-GCM-SHA256',
- 'ECDHE-RSA-AES256-GCM-SHA384',
- 'ECDHE-ECDSA-AES256-GCM-SHA384',
- 'DHE-RSA-AES128-GCM-SHA256',
- 'DHE-DSS-AES128-GCM-SHA256',
- 'kEDH+AESGCM',
- 'ECDHE-RSA-AES128-SHA256',
- 'ECDHE-ECDSA-AES128-SHA256',
- 'ECDHE-RSA-AES128-SHA',
- 'ECDHE-ECDSA-AES128-SHA',
- 'ECDHE-RSA-AES256-SHA384',
- 'ECDHE-ECDSA-AES256-SHA384',
- 'ECDHE-RSA-AES256-SHA',
- 'ECDHE-ECDSA-AES256-SHA',
- 'DHE-RSA-AES128-SHA256',
- 'DHE-RSA-AES128-SHA',
- 'DHE-DSS-AES128-SHA256',
- 'DHE-RSA-AES256-SHA256',
- 'DHE-DSS-AES256-SHA',
- 'DHE-RSA-AES256-SHA',
- 'AES128-GCM-SHA256',
- 'AES256-GCM-SHA384',
- 'AES128-SHA256',
- 'AES256-SHA256',
- 'AES128-SHA',
- 'AES256-SHA',
- 'AES',
- 'CAMELLIA',
- '!3DES',
- '!aNULL',
- '!eNULL',
- '!EXPORT',
- '!DES',
- '!RC4',
- '!MD5',
- '!PSK',
- '!aECDH',
- '!EDH-DSS-DES-CBC3-SHA',
- '!EDH-RSA-DES-CBC3-SHA',
- '!KRB5-DES-CBC3-SHA',
+ "ECDHE-RSA-AES128-GCM-SHA256",
+ "ECDHE-ECDSA-AES128-GCM-SHA256",
+ "ECDHE-RSA-AES256-GCM-SHA384",
+ "ECDHE-ECDSA-AES256-GCM-SHA384",
+ "DHE-RSA-AES128-GCM-SHA256",
+ "DHE-DSS-AES128-GCM-SHA256",
+ "kEDH+AESGCM",
+ "ECDHE-RSA-AES128-SHA256",
+ "ECDHE-ECDSA-AES128-SHA256",
+ "ECDHE-RSA-AES128-SHA",
+ "ECDHE-ECDSA-AES128-SHA",
+ "ECDHE-RSA-AES256-SHA384",
+ "ECDHE-ECDSA-AES256-SHA384",
+ "ECDHE-RSA-AES256-SHA",
+ "ECDHE-ECDSA-AES256-SHA",
+ "DHE-RSA-AES128-SHA256",
+ "DHE-RSA-AES128-SHA",
+ "DHE-DSS-AES128-SHA256",
+ "DHE-RSA-AES256-SHA256",
+ "DHE-DSS-AES256-SHA",
+ "DHE-RSA-AES256-SHA",
+ "AES128-GCM-SHA256",
+ "AES256-GCM-SHA384",
+ "AES128-SHA256",
+ "AES256-SHA256",
+ "AES128-SHA",
+ "AES256-SHA",
+ "AES",
+ "CAMELLIA",
+ "!3DES",
+ "!aNULL",
+ "!eNULL",
+ "!EXPORT",
+ "!DES",
+ "!RC4",
+ "!MD5",
+ "!PSK",
+ "!aECDH",
+ "!EDH-DSS-DES-CBC3-SHA",
+ "!EDH-RSA-DES-CBC3-SHA",
+ "!KRB5-DES-CBC3-SHA",
]
-def write_config(filename, **kwargs):
- with open(filename + ".jnj") as f:
- template = jinja2.Template(f.read())
- rendered = template.render(kwargs)
- with open(filename, 'w') as f:
- f.write(rendered)
+def write_config(filename, **kwargs):
+ with open(filename + ".jnj") as f:
+ template = jinja2.Template(f.read())
+ rendered = template.render(kwargs)
+
+ with open(filename, "w") as f:
+ f.write(rendered)
def generate_nginx_config(config):
- """
+ """
Generates nginx config from the app config
"""
- config = config or {}
- use_https = os.path.exists(os.path.join(QUAYCONF_DIR, 'stack/ssl.key'))
- use_old_certs = os.path.exists(os.path.join(QUAYCONF_DIR, 'stack/ssl.old.key'))
- v1_only_domain = config.get('V1_ONLY_DOMAIN', None)
- enable_rate_limits = config.get('FEATURE_RATE_LIMITS', False)
- ssl_protocols = config.get('SSL_PROTOCOLS', SSL_PROTOCOL_DEFAULTS)
- ssl_ciphers = config.get('SSL_CIPHERS', SSL_CIPHER_DEFAULTS)
+ config = config or {}
+ use_https = os.path.exists(os.path.join(QUAYCONF_DIR, "stack/ssl.key"))
+ use_old_certs = os.path.exists(os.path.join(QUAYCONF_DIR, "stack/ssl.old.key"))
+ v1_only_domain = config.get("V1_ONLY_DOMAIN", None)
+ enable_rate_limits = config.get("FEATURE_RATE_LIMITS", False)
+ ssl_protocols = config.get("SSL_PROTOCOLS", SSL_PROTOCOL_DEFAULTS)
+ ssl_ciphers = config.get("SSL_CIPHERS", SSL_CIPHER_DEFAULTS)
- write_config(os.path.join(QUAYCONF_DIR, 'nginx/nginx.conf'), use_https=use_https,
- use_old_certs=use_old_certs,
- enable_rate_limits=enable_rate_limits,
- v1_only_domain=v1_only_domain,
- ssl_protocols=ssl_protocols,
- ssl_ciphers=':'.join(ssl_ciphers))
+ write_config(
+ os.path.join(QUAYCONF_DIR, "nginx/nginx.conf"),
+ use_https=use_https,
+ use_old_certs=use_old_certs,
+ enable_rate_limits=enable_rate_limits,
+ v1_only_domain=v1_only_domain,
+ ssl_protocols=ssl_protocols,
+ ssl_ciphers=":".join(ssl_ciphers),
+ )
def generate_server_config(config):
- """
+ """
Generates server config from the app config
"""
- config = config or {}
- tuf_server = config.get('TUF_SERVER', None)
- tuf_host = config.get('TUF_HOST', None)
- signing_enabled = config.get('FEATURE_SIGNING', False)
- maximum_layer_size = config.get('MAXIMUM_LAYER_SIZE', '20G')
- enable_rate_limits = config.get('FEATURE_RATE_LIMITS', False)
+ config = config or {}
+ tuf_server = config.get("TUF_SERVER", None)
+ tuf_host = config.get("TUF_HOST", None)
+ signing_enabled = config.get("FEATURE_SIGNING", False)
+ maximum_layer_size = config.get("MAXIMUM_LAYER_SIZE", "20G")
+ enable_rate_limits = config.get("FEATURE_RATE_LIMITS", False)
- write_config(
- os.path.join(QUAYCONF_DIR, 'nginx/server-base.conf'), tuf_server=tuf_server, tuf_host=tuf_host,
- signing_enabled=signing_enabled, maximum_layer_size=maximum_layer_size,
- enable_rate_limits=enable_rate_limits,
- static_dir=STATIC_DIR)
+ write_config(
+ os.path.join(QUAYCONF_DIR, "nginx/server-base.conf"),
+ tuf_server=tuf_server,
+ tuf_host=tuf_host,
+ signing_enabled=signing_enabled,
+ maximum_layer_size=maximum_layer_size,
+ enable_rate_limits=enable_rate_limits,
+ static_dir=STATIC_DIR,
+ )
def generate_rate_limiting_config(config):
- """
+ """
Generates rate limiting config from the app config
"""
- config = config or {}
- non_rate_limited_namespaces = config.get('NON_RATE_LIMITED_NAMESPACES') or set()
- enable_rate_limits = config.get('FEATURE_RATE_LIMITS', False)
- write_config(
- os.path.join(QUAYCONF_DIR, 'nginx/rate-limiting.conf'),
- non_rate_limited_namespaces=non_rate_limited_namespaces,
- enable_rate_limits=enable_rate_limits,
- static_dir=STATIC_DIR)
+ config = config or {}
+ non_rate_limited_namespaces = config.get("NON_RATE_LIMITED_NAMESPACES") or set()
+ enable_rate_limits = config.get("FEATURE_RATE_LIMITS", False)
+ write_config(
+ os.path.join(QUAYCONF_DIR, "nginx/rate-limiting.conf"),
+ non_rate_limited_namespaces=non_rate_limited_namespaces,
+ enable_rate_limits=enable_rate_limits,
+ static_dir=STATIC_DIR,
+ )
+
if __name__ == "__main__":
- if os.path.exists(os.path.join(QUAYCONF_DIR, 'stack/config.yaml')):
- with open(os.path.join(QUAYCONF_DIR, 'stack/config.yaml'), 'r') as f:
- config = yaml.load(f)
- else:
- config = None
+ if os.path.exists(os.path.join(QUAYCONF_DIR, "stack/config.yaml")):
+ with open(os.path.join(QUAYCONF_DIR, "stack/config.yaml"), "r") as f:
+ config = yaml.load(f)
+ else:
+ config = None
- generate_rate_limiting_config(config)
- generate_server_config(config)
- generate_nginx_config(config)
+ generate_rate_limiting_config(config)
+ generate_server_config(config)
+ generate_nginx_config(config)
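For reference, write_config above renders the Jinja2 template that sits next to the target file (same name plus a .jnj suffix) and writes the result back, and generate_nginx_config joins the cipher list with ":" before rendering. A small sketch of that round trip with a made-up template (file names and template body are illustrative only):

import jinja2

# Hypothetical template, mirroring the "<filename>.jnj" convention used above.
with open("nginx.conf.jnj", "w") as f:
    f.write("ssl_protocols {{ ssl_protocols|join(' ') }};\nssl_ciphers {{ ssl_ciphers }};\n")

with open("nginx.conf.jnj") as f:
    template = jinja2.Template(f.read())
rendered = template.render(
    ssl_protocols=["TLSv1.2"], ssl_ciphers="AES128-GCM-SHA256:AES256-GCM-SHA384"
)

with open("nginx.conf", "w") as f:
    f.write(rendered)  # -> "ssl_protocols TLSv1.2;" and "ssl_ciphers AES128-GCM-SHA256:AES256-GCM-SHA384;"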
diff --git a/conf/init/supervisord_conf_create.py b/conf/init/supervisord_conf_create.py
index 50f5cabbf..0463d6dfd 100644
--- a/conf/init/supervisord_conf_create.py
+++ b/conf/init/supervisord_conf_create.py
@@ -12,136 +12,74 @@ QUAY_OVERRIDE_SERVICES = os.getenv("QUAY_OVERRIDE_SERVICES", [])
def default_services():
- return {
- "blobuploadcleanupworker": {
- "autostart": "true"
- },
- "buildlogsarchiver": {
- "autostart": "true"
- },
- "builder": {
- "autostart": "true"
- },
- "chunkcleanupworker": {
- "autostart": "true"
- },
- "expiredappspecifictokenworker": {
- "autostart": "true"
- },
- "exportactionlogsworker": {
- "autostart": "true"
- },
- "gcworker": {
- "autostart": "true"
- },
- "globalpromstats": {
- "autostart": "true"
- },
- "labelbackfillworker": {
- "autostart": "true"
- },
- "logrotateworker": {
- "autostart": "true"
- },
- "namespacegcworker": {
- "autostart": "true"
- },
- "notificationworker": {
- "autostart": "true"
- },
- "queuecleanupworker": {
- "autostart": "true"
- },
- "repositoryactioncounter": {
- "autostart": "true"
- },
- "security_notification_worker": {
- "autostart": "true"
- },
- "securityworker": {
- "autostart": "true"
- },
- "storagereplication": {
- "autostart": "true"
- },
- "tagbackfillworker": {
- "autostart": "true"
- },
- "teamsyncworker": {
- "autostart": "true"
- },
- "dnsmasq": {
- "autostart": "true"
- },
- "gunicorn-registry": {
- "autostart": "true"
- },
- "gunicorn-secscan": {
- "autostart": "true"
- },
- "gunicorn-verbs": {
- "autostart": "true"
- },
- "gunicorn-web": {
- "autostart": "true"
- },
- "ip-resolver-update-worker": {
- "autostart": "true"
- },
- "jwtproxy": {
- "autostart": "true"
- },
- "memcache": {
- "autostart": "true"
- },
- "nginx": {
- "autostart": "true"
- },
- "prometheus-aggregator": {
- "autostart": "true"
- },
- "servicekey": {
- "autostart": "true"
- },
- "repomirrorworker": {
- "autostart": "false"
+ return {
+ "blobuploadcleanupworker": {"autostart": "true"},
+ "buildlogsarchiver": {"autostart": "true"},
+ "builder": {"autostart": "true"},
+ "chunkcleanupworker": {"autostart": "true"},
+ "expiredappspecifictokenworker": {"autostart": "true"},
+ "exportactionlogsworker": {"autostart": "true"},
+ "gcworker": {"autostart": "true"},
+ "globalpromstats": {"autostart": "true"},
+ "labelbackfillworker": {"autostart": "true"},
+ "logrotateworker": {"autostart": "true"},
+ "namespacegcworker": {"autostart": "true"},
+ "notificationworker": {"autostart": "true"},
+ "queuecleanupworker": {"autostart": "true"},
+ "repositoryactioncounter": {"autostart": "true"},
+ "security_notification_worker": {"autostart": "true"},
+ "securityworker": {"autostart": "true"},
+ "storagereplication": {"autostart": "true"},
+ "tagbackfillworker": {"autostart": "true"},
+ "teamsyncworker": {"autostart": "true"},
+ "dnsmasq": {"autostart": "true"},
+ "gunicorn-registry": {"autostart": "true"},
+ "gunicorn-secscan": {"autostart": "true"},
+ "gunicorn-verbs": {"autostart": "true"},
+ "gunicorn-web": {"autostart": "true"},
+ "ip-resolver-update-worker": {"autostart": "true"},
+ "jwtproxy": {"autostart": "true"},
+ "memcache": {"autostart": "true"},
+ "nginx": {"autostart": "true"},
+ "prometheus-aggregator": {"autostart": "true"},
+ "servicekey": {"autostart": "true"},
+ "repomirrorworker": {"autostart": "false"},
}
-}
def generate_supervisord_config(filename, config):
- with open(filename + ".jnj") as f:
- template = jinja2.Template(f.read())
- rendered = template.render(config=config)
+ with open(filename + ".jnj") as f:
+ template = jinja2.Template(f.read())
+ rendered = template.render(config=config)
- with open(filename, 'w') as f:
- f.write(rendered)
+ with open(filename, "w") as f:
+ f.write(rendered)
def limit_services(config, enabled_services):
- if enabled_services == []:
- return
+ if enabled_services == []:
+ return
- for service in config.keys():
- if service in enabled_services:
- config[service]["autostart"] = "true"
- else:
- config[service]["autostart"] = "false"
+ for service in config.keys():
+ if service in enabled_services:
+ config[service]["autostart"] = "true"
+ else:
+ config[service]["autostart"] = "false"
def override_services(config, override_services):
- if override_services == []:
- return
+ if override_services == []:
+ return
- for service in config.keys():
- if service + "=true" in override_services:
- config[service]["autostart"] = "true"
- elif service + "=false" in override_services:
- config[service]["autostart"] = "false"
+ for service in config.keys():
+ if service + "=true" in override_services:
+ config[service]["autostart"] = "true"
+ elif service + "=false" in override_services:
+ config[service]["autostart"] = "false"
if __name__ == "__main__":
- config = default_services()
- limit_services(config, QUAY_SERVICES)
- override_services(config, QUAY_OVERRIDE_SERVICES)
- generate_supervisord_config(os.path.join(QUAYCONF_DIR, 'supervisord.conf'), config)
+ config = default_services()
+ limit_services(config, QUAY_SERVICES)
+ override_services(config, QUAY_OVERRIDE_SERVICES)
+ generate_supervisord_config(os.path.join(QUAYCONF_DIR, "supervisord.conf"), config)
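limit_services and override_services above compose in order: the first disables every service not named in QUAY_SERVICES (unless that list is empty), and the second flips individual services through "name=true" / "name=false" entries. A short sketch of that composition, assuming the three helpers are importable from the module above:

# Adjust the import to however conf/init ends up on sys.path; shown here for illustration.
from supervisord_conf_create import default_services, limit_services, override_services

config = default_services()  # everything autostarts except repomirrorworker
limit_services(config, ["nginx", "gunicorn-web"])  # only these two stay "true"
override_services(config, ["repomirrorworker=true"])  # explicit per-service flip

assert config["nginx"]["autostart"] == "true"
assert config["gcworker"]["autostart"] == "false"
assert config["repomirrorworker"]["autostart"] == "true"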
diff --git a/conf/init/test/test_supervisord_conf_create.py b/conf/init/test/test_supervisord_conf_create.py
index 8972b2e39..75c7313d4 100644
--- a/conf/init/test/test_supervisord_conf_create.py
+++ b/conf/init/test/test_supervisord_conf_create.py
@@ -6,17 +6,23 @@ import jinja2
from ..supervisord_conf_create import QUAYCONF_DIR, default_services, limit_services
+
def render_supervisord_conf(config):
- with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../supervisord.conf.jnj")) as f:
- template = jinja2.Template(f.read())
- return template.render(config=config)
+ with open(
+ os.path.join(
+ os.path.dirname(os.path.abspath(__file__)), "../../supervisord.conf.jnj"
+ )
+ ) as f:
+ template = jinja2.Template(f.read())
+ return template.render(config=config)
+
def test_supervisord_conf_create_defaults():
- config = default_services()
- limit_services(config, [])
- rendered = render_supervisord_conf(config)
+ config = default_services()
+ limit_services(config, [])
+ rendered = render_supervisord_conf(config)
- expected = """[supervisord]
+ expected = """[supervisord]
nodaemon=true
[unix_http_server]
@@ -392,14 +398,15 @@ stderr_logfile_maxbytes=0
stdout_events_enabled = true
stderr_events_enabled = true
# EOF NO NEWLINE"""
- assert rendered == expected
+ assert rendered == expected
+
def test_supervisord_conf_create_all_overrides():
- config = default_services()
- limit_services(config, "servicekey,prometheus-aggregator")
- rendered = render_supervisord_conf(config)
+ config = default_services()
+ limit_services(config, "servicekey,prometheus-aggregator")
+ rendered = render_supervisord_conf(config)
- expected = """[supervisord]
+ expected = """[supervisord]
nodaemon=true
[unix_http_server]
@@ -775,4 +782,4 @@ stderr_logfile_maxbytes=0
stdout_events_enabled = true
stderr_events_enabled = true
# EOF NO NEWLINE"""
- assert rendered == expected
+ assert rendered == expected
diff --git a/config.py b/config.py
index ae742ece8..df693df58 100644
--- a/config.py
+++ b/config.py
@@ -7,603 +7,682 @@ from _init import ROOT_DIR, CONF_DIR
def build_requests_session():
- sess = requests.Session()
- adapter = requests.adapters.HTTPAdapter(pool_connections=100,
- pool_maxsize=100)
- sess.mount('http://', adapter)
- sess.mount('https://', adapter)
- return sess
+ sess = requests.Session()
+ adapter = requests.adapters.HTTPAdapter(pool_connections=100, pool_maxsize=100)
+ sess.mount("http://", adapter)
+ sess.mount("https://", adapter)
+ return sess
# The set of configuration key names that will be accessible in the client. Since these
# values are sent to the frontend, DO NOT PLACE ANY SECRETS OR KEYS in this list.
-CLIENT_WHITELIST = ['SERVER_HOSTNAME', 'PREFERRED_URL_SCHEME', 'MIXPANEL_KEY',
- 'STRIPE_PUBLISHABLE_KEY', 'ENTERPRISE_LOGO_URL', 'SENTRY_PUBLIC_DSN',
- 'AUTHENTICATION_TYPE', 'REGISTRY_TITLE', 'REGISTRY_TITLE_SHORT',
- 'CONTACT_INFO', 'AVATAR_KIND', 'LOCAL_OAUTH_HANDLER',
- 'SETUP_COMPLETE', 'DEBUG', 'MARKETO_MUNCHKIN_ID',
- 'STATIC_SITE_BUCKET', 'RECAPTCHA_SITE_KEY', 'CHANNEL_COLORS',
- 'TAG_EXPIRATION_OPTIONS', 'INTERNAL_OIDC_SERVICE_ID',
- 'SEARCH_RESULTS_PER_PAGE', 'SEARCH_MAX_RESULT_PAGE_COUNT', 'BRANDING']
+CLIENT_WHITELIST = [
+ "SERVER_HOSTNAME",
+ "PREFERRED_URL_SCHEME",
+ "MIXPANEL_KEY",
+ "STRIPE_PUBLISHABLE_KEY",
+ "ENTERPRISE_LOGO_URL",
+ "SENTRY_PUBLIC_DSN",
+ "AUTHENTICATION_TYPE",
+ "REGISTRY_TITLE",
+ "REGISTRY_TITLE_SHORT",
+ "CONTACT_INFO",
+ "AVATAR_KIND",
+ "LOCAL_OAUTH_HANDLER",
+ "SETUP_COMPLETE",
+ "DEBUG",
+ "MARKETO_MUNCHKIN_ID",
+ "STATIC_SITE_BUCKET",
+ "RECAPTCHA_SITE_KEY",
+ "CHANNEL_COLORS",
+ "TAG_EXPIRATION_OPTIONS",
+ "INTERNAL_OIDC_SERVICE_ID",
+ "SEARCH_RESULTS_PER_PAGE",
+ "SEARCH_MAX_RESULT_PAGE_COUNT",
+ "BRANDING",
+]
def frontend_visible_config(config_dict):
- visible_dict = {}
- for name in CLIENT_WHITELIST:
- if name.lower().find('secret') >= 0:
- raise Exception('Cannot whitelist secrets: %s' % name)
+ visible_dict = {}
+ for name in CLIENT_WHITELIST:
+ if name.lower().find("secret") >= 0:
+ raise Exception("Cannot whitelist secrets: %s" % name)
- if name in config_dict:
- visible_dict[name] = config_dict.get(name, None)
- if 'ENTERPRISE_LOGO_URL' in config_dict:
- visible_dict['BRANDING'] = visible_dict.get('BRANDING', {})
- visible_dict['BRANDING']['logo'] = config_dict['ENTERPRISE_LOGO_URL']
+ if name in config_dict:
+ visible_dict[name] = config_dict.get(name, None)
+ if "ENTERPRISE_LOGO_URL" in config_dict:
+ visible_dict["BRANDING"] = visible_dict.get("BRANDING", {})
+ visible_dict["BRANDING"]["logo"] = config_dict["ENTERPRISE_LOGO_URL"]
- return visible_dict
+ return visible_dict
# Configuration that should not be changed by end users
class ImmutableConfig(object):
- # Requests based HTTP client with a large request pool
- HTTPCLIENT = build_requests_session()
+ # Requests based HTTP client with a large request pool
+ HTTPCLIENT = build_requests_session()
- # Status tag config
- STATUS_TAGS = {}
- for tag_name in ['building', 'failed', 'none', 'ready', 'cancelled']:
- tag_path = os.path.join(ROOT_DIR, 'buildstatus', tag_name + '.svg')
- with open(tag_path) as tag_svg:
- STATUS_TAGS[tag_name] = tag_svg.read()
+ # Status tag config
+ STATUS_TAGS = {}
+ for tag_name in ["building", "failed", "none", "ready", "cancelled"]:
+ tag_path = os.path.join(ROOT_DIR, "buildstatus", tag_name + ".svg")
+ with open(tag_path) as tag_svg:
+ STATUS_TAGS[tag_name] = tag_svg.read()
- # Reverse DNS prefixes that are reserved for internal use on labels and should not be allowable
- # to be set via the API.
- DEFAULT_LABEL_KEY_RESERVED_PREFIXES = ['com.docker.', 'io.docker.', 'org.dockerproject.',
- 'org.opencontainers.', 'io.cncf.',
- 'io.kubernetes.', 'io.k8s.',
- 'io.quay', 'com.coreos', 'com.tectonic',
- 'internal', 'quay']
+ # Reverse DNS prefixes that are reserved for internal use on labels and should not be allowable
+ # to be set via the API.
+ DEFAULT_LABEL_KEY_RESERVED_PREFIXES = [
+ "com.docker.",
+ "io.docker.",
+ "org.dockerproject.",
+ "org.opencontainers.",
+ "io.cncf.",
+ "io.kubernetes.",
+ "io.k8s.",
+ "io.quay",
+ "com.coreos",
+ "com.tectonic",
+ "internal",
+ "quay",
+ ]
- # Colors for local avatars.
- AVATAR_COLORS = ['#969696', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c', '#98df8a', '#d62728',
- '#ff9896', '#9467bd', '#c5b0d5', '#8c564b', '#c49c94', '#e377c2', '#f7b6d2',
- '#7f7f7f', '#c7c7c7', '#bcbd22', '#1f77b4', '#17becf', '#9edae5', '#393b79',
- '#5254a3', '#6b6ecf', '#9c9ede', '#9ecae1', '#31a354', '#b5cf6b', '#a1d99b',
- '#8c6d31', '#ad494a', '#e7ba52', '#a55194']
+ # Colors for local avatars.
+ AVATAR_COLORS = [
+ "#969696",
+ "#aec7e8",
+ "#ff7f0e",
+ "#ffbb78",
+ "#2ca02c",
+ "#98df8a",
+ "#d62728",
+ "#ff9896",
+ "#9467bd",
+ "#c5b0d5",
+ "#8c564b",
+ "#c49c94",
+ "#e377c2",
+ "#f7b6d2",
+ "#7f7f7f",
+ "#c7c7c7",
+ "#bcbd22",
+ "#1f77b4",
+ "#17becf",
+ "#9edae5",
+ "#393b79",
+ "#5254a3",
+ "#6b6ecf",
+ "#9c9ede",
+ "#9ecae1",
+ "#31a354",
+ "#b5cf6b",
+ "#a1d99b",
+ "#8c6d31",
+ "#ad494a",
+ "#e7ba52",
+ "#a55194",
+ ]
- # Colors for channels.
- CHANNEL_COLORS = ['#969696', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c', '#98df8a', '#d62728',
- '#ff9896', '#9467bd', '#c5b0d5', '#8c564b', '#c49c94', '#e377c2', '#f7b6d2',
- '#7f7f7f', '#c7c7c7', '#bcbd22', '#1f77b4', '#17becf', '#9edae5', '#393b79',
- '#5254a3', '#6b6ecf', '#9c9ede', '#9ecae1', '#31a354', '#b5cf6b', '#a1d99b',
- '#8c6d31', '#ad494a', '#e7ba52', '#a55194']
+ # Colors for channels.
+ CHANNEL_COLORS = [
+ "#969696",
+ "#aec7e8",
+ "#ff7f0e",
+ "#ffbb78",
+ "#2ca02c",
+ "#98df8a",
+ "#d62728",
+ "#ff9896",
+ "#9467bd",
+ "#c5b0d5",
+ "#8c564b",
+ "#c49c94",
+ "#e377c2",
+ "#f7b6d2",
+ "#7f7f7f",
+ "#c7c7c7",
+ "#bcbd22",
+ "#1f77b4",
+ "#17becf",
+ "#9edae5",
+ "#393b79",
+ "#5254a3",
+ "#6b6ecf",
+ "#9c9ede",
+ "#9ecae1",
+ "#31a354",
+ "#b5cf6b",
+ "#a1d99b",
+ "#8c6d31",
+ "#ad494a",
+ "#e7ba52",
+ "#a55194",
+ ]
- PROPAGATE_EXCEPTIONS = True
+ PROPAGATE_EXCEPTIONS = True
class DefaultConfig(ImmutableConfig):
- # Flask config
- JSONIFY_PRETTYPRINT_REGULAR = False
- SESSION_COOKIE_SECURE = False
-
- SESSION_COOKIE_HTTPONLY = True
- SESSION_COOKIE_SAMESITE = 'Lax'
+ # Flask config
+ JSONIFY_PRETTYPRINT_REGULAR = False
+ SESSION_COOKIE_SECURE = False
- LOGGING_LEVEL = 'DEBUG'
- SEND_FILE_MAX_AGE_DEFAULT = 0
- PREFERRED_URL_SCHEME = 'http'
- SERVER_HOSTNAME = 'localhost:5000'
+ SESSION_COOKIE_HTTPONLY = True
+ SESSION_COOKIE_SAMESITE = "Lax"
- REGISTRY_TITLE = 'Project Quay'
- REGISTRY_TITLE_SHORT = 'Project Quay'
+ LOGGING_LEVEL = "DEBUG"
+ SEND_FILE_MAX_AGE_DEFAULT = 0
+ PREFERRED_URL_SCHEME = "http"
+ SERVER_HOSTNAME = "localhost:5000"
- CONTACT_INFO = []
+ REGISTRY_TITLE = "Project Quay"
+ REGISTRY_TITLE_SHORT = "Project Quay"
- # Mail config
- MAIL_SERVER = ''
- MAIL_USE_TLS = True
- MAIL_PORT = 587
- MAIL_USERNAME = None
- MAIL_PASSWORD = None
- MAIL_DEFAULT_SENDER = 'example@projectquay.io'
- MAIL_FAIL_SILENTLY = False
- TESTING = True
+ CONTACT_INFO = []
- # DB config
- DB_URI = 'sqlite:///test/data/test.db'
- DB_CONNECTION_ARGS = {
- 'threadlocals': True,
- 'autorollback': True,
- }
+ # Mail config
+ MAIL_SERVER = ""
+ MAIL_USE_TLS = True
+ MAIL_PORT = 587
+ MAIL_USERNAME = None
+ MAIL_PASSWORD = None
+ MAIL_DEFAULT_SENDER = "example@projectquay.io"
+ MAIL_FAIL_SILENTLY = False
+ TESTING = True
- @staticmethod
- def create_transaction(db):
- return db.transaction()
+ # DB config
+ DB_URI = "sqlite:///test/data/test.db"
+ DB_CONNECTION_ARGS = {"threadlocals": True, "autorollback": True}
- DB_TRANSACTION_FACTORY = create_transaction
+ @staticmethod
+ def create_transaction(db):
+ return db.transaction()
- # If set to 'readonly', the entire registry is placed into read only mode and no write operations
- # may be performed against it.
- REGISTRY_STATE = 'normal'
+ DB_TRANSACTION_FACTORY = create_transaction
- # If set to true, TLS is used, but is terminated by an external service (such as a load balancer).
- # Note that PREFERRED_URL_SCHEME must be `https` when this flag is set or it can lead to undefined
- # behavior.
- EXTERNAL_TLS_TERMINATION = False
+ # If set to 'readonly', the entire registry is placed into read only mode and no write operations
+ # may be performed against it.
+ REGISTRY_STATE = "normal"
- # If true, CDN URLs will be used for our external dependencies, rather than the local
- # copies.
- USE_CDN = False
+ # If set to true, TLS is used, but is terminated by an external service (such as a load balancer).
+ # Note that PREFERRED_URL_SCHEME must be `https` when this flag is set or it can lead to undefined
+ # behavior.
+ EXTERNAL_TLS_TERMINATION = False
- # Authentication
- AUTHENTICATION_TYPE = 'Database'
+ # If true, CDN URLs will be used for our external dependencies, rather than the local
+ # copies.
+ USE_CDN = False
- # Build logs
- BUILDLOGS_REDIS = {'host': 'localhost'}
- BUILDLOGS_OPTIONS = []
+ # Authentication
+ AUTHENTICATION_TYPE = "Database"
- # Real-time user events
- USER_EVENTS_REDIS = {'host': 'localhost'}
+ # Build logs
+ BUILDLOGS_REDIS = {"host": "localhost"}
+ BUILDLOGS_OPTIONS = []
- # Stripe config
- BILLING_TYPE = 'FakeStripe'
+ # Real-time user events
+ USER_EVENTS_REDIS = {"host": "localhost"}
- # Analytics
- ANALYTICS_TYPE = 'FakeAnalytics'
+ # Stripe config
+ BILLING_TYPE = "FakeStripe"
- # Build Queue Metrics
- QUEUE_METRICS_TYPE = 'Null'
- QUEUE_WORKER_METRICS_REFRESH_SECONDS = 300
+ # Analytics
+ ANALYTICS_TYPE = "FakeAnalytics"
- # Exception logging
- EXCEPTION_LOG_TYPE = 'FakeSentry'
- SENTRY_DSN = None
- SENTRY_PUBLIC_DSN = None
+ # Build Queue Metrics
+ QUEUE_METRICS_TYPE = "Null"
+ QUEUE_WORKER_METRICS_REFRESH_SECONDS = 300
- # Github Config
- GITHUB_LOGIN_CONFIG = None
- GITHUB_TRIGGER_CONFIG = None
+ # Exception logging
+ EXCEPTION_LOG_TYPE = "FakeSentry"
+ SENTRY_DSN = None
+ SENTRY_PUBLIC_DSN = None
- # Google Config.
- GOOGLE_LOGIN_CONFIG = None
+ # Github Config
+ GITHUB_LOGIN_CONFIG = None
+ GITHUB_TRIGGER_CONFIG = None
- # Bitbucket Config.
- BITBUCKET_TRIGGER_CONFIG = None
+ # Google Config.
+ GOOGLE_LOGIN_CONFIG = None
- # Gitlab Config.
- GITLAB_TRIGGER_CONFIG = None
+ # Bitbucket Config.
+ BITBUCKET_TRIGGER_CONFIG = None
- NOTIFICATION_QUEUE_NAME = 'notification'
- DOCKERFILE_BUILD_QUEUE_NAME = 'dockerfilebuild'
- REPLICATION_QUEUE_NAME = 'imagestoragereplication'
- SECSCAN_NOTIFICATION_QUEUE_NAME = 'security_notification'
- CHUNK_CLEANUP_QUEUE_NAME = 'chunk_cleanup'
- NAMESPACE_GC_QUEUE_NAME = 'namespacegc'
- EXPORT_ACTION_LOGS_QUEUE_NAME = 'exportactionlogs'
+ # Gitlab Config.
+ GITLAB_TRIGGER_CONFIG = None
- # Super user config. Note: This MUST BE an empty list for the default config.
- SUPER_USERS = []
+ NOTIFICATION_QUEUE_NAME = "notification"
+ DOCKERFILE_BUILD_QUEUE_NAME = "dockerfilebuild"
+ REPLICATION_QUEUE_NAME = "imagestoragereplication"
+ SECSCAN_NOTIFICATION_QUEUE_NAME = "security_notification"
+ CHUNK_CLEANUP_QUEUE_NAME = "chunk_cleanup"
+ NAMESPACE_GC_QUEUE_NAME = "namespacegc"
+ EXPORT_ACTION_LOGS_QUEUE_NAME = "exportactionlogs"
- # Feature Flag: Whether sessions are permanent.
- FEATURE_PERMANENT_SESSIONS = True
+ # Super user config. Note: This MUST BE an empty list for the default config.
+ SUPER_USERS = []
- # Feature Flag: Whether super users are supported.
- FEATURE_SUPER_USERS = True
+ # Feature Flag: Whether sessions are permanent.
+ FEATURE_PERMANENT_SESSIONS = True
- # Feature Flag: Whether to allow anonymous users to browse and pull public repositories.
- FEATURE_ANONYMOUS_ACCESS = True
+ # Feature Flag: Whether super users are supported.
+ FEATURE_SUPER_USERS = True
- # Feature Flag: Whether billing is required.
- FEATURE_BILLING = False
+ # Feature Flag: Whether to allow anonymous users to browse and pull public repositories.
+ FEATURE_ANONYMOUS_ACCESS = True
- # Feature Flag: Whether user accounts automatically have usage log access.
- FEATURE_USER_LOG_ACCESS = False
+ # Feature Flag: Whether billing is required.
+ FEATURE_BILLING = False
- # Feature Flag: Whether GitHub login is supported.
- FEATURE_GITHUB_LOGIN = False
+ # Feature Flag: Whether user accounts automatically have usage log access.
+ FEATURE_USER_LOG_ACCESS = False
- # Feature Flag: Whether Google login is supported.
- FEATURE_GOOGLE_LOGIN = False
+ # Feature Flag: Whether GitHub login is supported.
+ FEATURE_GITHUB_LOGIN = False
- # Feature Flag: Whether to support GitHub build triggers.
- FEATURE_GITHUB_BUILD = False
+ # Feature Flag: Whether Google login is supported.
+ FEATURE_GOOGLE_LOGIN = False
- # Feature Flag: Whether to support Bitbucket build triggers.
- FEATURE_BITBUCKET_BUILD = False
+ # Feature Flag: Whether to support GitHub build triggers.
+ FEATURE_GITHUB_BUILD = False
- # Feature Flag: Whether to support GitLab build triggers.
- FEATURE_GITLAB_BUILD = False
+ # Feature Flag: Whether to support Bitbucket build triggers.
+ FEATURE_BITBUCKET_BUILD = False
- # Feature Flag: Dockerfile build support.
- FEATURE_BUILD_SUPPORT = True
+ # Feature Flag: Whether to support GitLab build triggers.
+ FEATURE_GITLAB_BUILD = False
- # Feature Flag: Whether emails are enabled.
- FEATURE_MAILING = True
+ # Feature Flag: Dockerfile build support.
+ FEATURE_BUILD_SUPPORT = True
- # Feature Flag: Whether users can be created (by non-super users).
- FEATURE_USER_CREATION = True
+ # Feature Flag: Whether emails are enabled.
+ FEATURE_MAILING = True
- # Feature Flag: Whether users being created must be invited by another user.
- # If FEATURE_USER_CREATION is off, this flag has no effect.
- FEATURE_INVITE_ONLY_USER_CREATION = False
+ # Feature Flag: Whether users can be created (by non-super users).
+ FEATURE_USER_CREATION = True
- # Feature Flag: Whether users can be renamed
- FEATURE_USER_RENAME = False
+ # Feature Flag: Whether users being created must be invited by another user.
+ # If FEATURE_USER_CREATION is off, this flag has no effect.
+ FEATURE_INVITE_ONLY_USER_CREATION = False
- # Feature Flag: Whether non-encrypted passwords (as opposed to encrypted tokens) can be used for
- # basic auth.
- FEATURE_REQUIRE_ENCRYPTED_BASIC_AUTH = False
+ # Feature Flag: Whether users can be renamed
+ FEATURE_USER_RENAME = False
- # Feature Flag: Whether to automatically replicate between storage engines.
- FEATURE_STORAGE_REPLICATION = False
+ # Feature Flag: Whether non-encrypted passwords (as opposed to encrypted tokens) can be used for
+ # basic auth.
+ FEATURE_REQUIRE_ENCRYPTED_BASIC_AUTH = False
- # Feature Flag: Whether users can directly login to the UI.
- FEATURE_DIRECT_LOGIN = True
+ # Feature Flag: Whether to automatically replicate between storage engines.
+ FEATURE_STORAGE_REPLICATION = False
- # Feature Flag: Whether the v2/ endpoint is visible
- FEATURE_ADVERTISE_V2 = True
+ # Feature Flag: Whether users can directly login to the UI.
+ FEATURE_DIRECT_LOGIN = True
- # Semver spec for which Docker versions we will blacklist
- # Documentation: http://pythonhosted.org/semantic_version/reference.html#semantic_version.Spec
- BLACKLIST_V2_SPEC = '<1.6.0'
+ # Feature Flag: Whether the v2/ endpoint is visible
+ FEATURE_ADVERTISE_V2 = True
- # Feature Flag: Whether to restrict V1 pushes to the whitelist.
- FEATURE_RESTRICTED_V1_PUSH = False
- V1_PUSH_WHITELIST = []
+ # Semver spec for which Docker versions we will blacklist
+ # Documentation: http://pythonhosted.org/semantic_version/reference.html#semantic_version.Spec
+ BLACKLIST_V2_SPEC = "<1.6.0"
- # Feature Flag: Whether or not to rotate old action logs to storage.
- FEATURE_ACTION_LOG_ROTATION = False
+ # Feature Flag: Whether to restrict V1 pushes to the whitelist.
+ FEATURE_RESTRICTED_V1_PUSH = False
+ V1_PUSH_WHITELIST = []
- # Feature Flag: Whether to enable conversion to ACIs.
- FEATURE_ACI_CONVERSION = False
+ # Feature Flag: Whether or not to rotate old action logs to storage.
+ FEATURE_ACTION_LOG_ROTATION = False
- # Feature Flag: Whether to allow for "namespace-less" repositories when pulling and pushing from
- # Docker.
- FEATURE_LIBRARY_SUPPORT = True
+ # Feature Flag: Whether to enable conversion to ACIs.
+ FEATURE_ACI_CONVERSION = False
- # Feature Flag: Whether to require invitations when adding a user to a team.
- FEATURE_REQUIRE_TEAM_INVITE = True
+ # Feature Flag: Whether to allow for "namespace-less" repositories when pulling and pushing from
+ # Docker.
+ FEATURE_LIBRARY_SUPPORT = True
- # Feature Flag: Whether to proxy all direct download URLs in storage via the registry's nginx.
- FEATURE_PROXY_STORAGE = False
+ # Feature Flag: Whether to require invitations when adding a user to a team.
+ FEATURE_REQUIRE_TEAM_INVITE = True
- # Feature Flag: Whether to collect and support user metadata.
- FEATURE_USER_METADATA = False
+ # Feature Flag: Whether to proxy all direct download URLs in storage via the registry's nginx.
+ FEATURE_PROXY_STORAGE = False
- # Feature Flag: Whether to support signing
- FEATURE_SIGNING = False
+ # Feature Flag: Whether to collect and support user metadata.
+ FEATURE_USER_METADATA = False
- # Feature Flag: Whether to enable support for App repositories.
- FEATURE_APP_REGISTRY = False
+ # Feature Flag: Whether to support signing
+ FEATURE_SIGNING = False
- # Feature Flag: Whether app registry is in a read-only mode.
- FEATURE_READONLY_APP_REGISTRY = False
+ # Feature Flag: Whether to enable support for App repositories.
+ FEATURE_APP_REGISTRY = False
- # Feature Flag: If set to true, the _catalog endpoint returns public repositories. Otherwise,
- # only private repositories can be returned.
- FEATURE_PUBLIC_CATALOG = False
+ # Feature Flag: Whether app registry is in a read-only mode.
+ FEATURE_READONLY_APP_REGISTRY = False
- # Feature Flag: If set to true, build logs may be read by those with read access to the repo,
- # rather than only write access or admin access.
- FEATURE_READER_BUILD_LOGS = False
+ # Feature Flag: If set to true, the _catalog endpoint returns public repositories. Otherwise,
+ # only private repositories can be returned.
+ FEATURE_PUBLIC_CATALOG = False
- # Feature Flag: If set to true, autocompletion will apply to partial usernames.
- FEATURE_PARTIAL_USER_AUTOCOMPLETE = True
+ # Feature Flag: If set to true, build logs may be read by those with read access to the repo,
+ # rather than only write access or admin access.
+ FEATURE_READER_BUILD_LOGS = False
- # Feature Flag: If set to true, users can confirm (and modify) their initial usernames when
- # logging in via OIDC or a non-database internal auth provider.
- FEATURE_USERNAME_CONFIRMATION = True
+ # Feature Flag: If set to true, autocompletion will apply to partial usernames.
+ FEATURE_PARTIAL_USER_AUTOCOMPLETE = True
- # If a namespace is defined in the public namespace list, then it will appear on *all*
- # user's repository list pages, regardless of whether that user is a member of the namespace.
- # Typically, this is used by an enterprise customer in configuring a set of "well-known"
- # namespaces.
- PUBLIC_NAMESPACES = []
+ # Feature Flag: If set to true, users can confirm (and modify) their initial usernames when
+ # logging in via OIDC or a non-database internal auth provider.
+ FEATURE_USERNAME_CONFIRMATION = True
- # The namespace to use for library repositories.
- # Note: This must remain 'library' until Docker removes their hard-coded namespace for libraries.
- # See: https://github.com/docker/docker/blob/master/registry/session.go#L320
- LIBRARY_NAMESPACE = 'library'
+ # If a namespace is defined in the public namespace list, then it will appear on *all*
+ # user's repository list pages, regardless of whether that user is a member of the namespace.
+ # Typically, this is used by an enterprise customer in configuring a set of "well-known"
+ # namespaces.
+ PUBLIC_NAMESPACES = []
- BUILD_MANAGER = ('enterprise', {})
+ # The namespace to use for library repositories.
+ # Note: This must remain 'library' until Docker removes their hard-coded namespace for libraries.
+ # See: https://github.com/docker/docker/blob/master/registry/session.go#L320
+ LIBRARY_NAMESPACE = "library"
- DISTRIBUTED_STORAGE_CONFIG = {
- 'local_eu': ['LocalStorage', {'storage_path': 'test/data/registry/eu'}],
- 'local_us': ['LocalStorage', {'storage_path': 'test/data/registry/us'}],
- }
+ BUILD_MANAGER = ("enterprise", {})
- DISTRIBUTED_STORAGE_PREFERENCE = ['local_us']
- DISTRIBUTED_STORAGE_DEFAULT_LOCATIONS = ['local_us']
+ DISTRIBUTED_STORAGE_CONFIG = {
+ "local_eu": ["LocalStorage", {"storage_path": "test/data/registry/eu"}],
+ "local_us": ["LocalStorage", {"storage_path": "test/data/registry/us"}],
+ }
- # Health checker.
- HEALTH_CHECKER = ('LocalHealthCheck', {})
+ DISTRIBUTED_STORAGE_PREFERENCE = ["local_us"]
+ DISTRIBUTED_STORAGE_DEFAULT_LOCATIONS = ["local_us"]
- # Userfiles
- USERFILES_LOCATION = 'local_us'
- USERFILES_PATH = 'userfiles/'
+ # Health checker.
+ HEALTH_CHECKER = ("LocalHealthCheck", {})
- # Build logs archive
- LOG_ARCHIVE_LOCATION = 'local_us'
- LOG_ARCHIVE_PATH = 'logarchive/'
+ # Userfiles
+ USERFILES_LOCATION = "local_us"
+ USERFILES_PATH = "userfiles/"
- # Action logs archive
- ACTION_LOG_ARCHIVE_LOCATION = 'local_us'
- ACTION_LOG_ARCHIVE_PATH = 'actionlogarchive/'
- ACTION_LOG_ROTATION_THRESHOLD = '30d'
+ # Build logs archive
+ LOG_ARCHIVE_LOCATION = "local_us"
+ LOG_ARCHIVE_PATH = "logarchive/"
- # Allow registry pulls when unable to write to the audit log
- ALLOW_PULLS_WITHOUT_STRICT_LOGGING = False
+ # Action logs archive
+ ACTION_LOG_ARCHIVE_LOCATION = "local_us"
+ ACTION_LOG_ARCHIVE_PATH = "actionlogarchive/"
+ ACTION_LOG_ROTATION_THRESHOLD = "30d"
- # Temporary tag expiration in seconds, this may actually be longer based on GC policy
- PUSH_TEMP_TAG_EXPIRATION_SEC = 60 * 60 # One hour per layer
+ # Allow registry pulls when unable to write to the audit log
+ ALLOW_PULLS_WITHOUT_STRICT_LOGGING = False
- # Signed registry grant token expiration in seconds
- SIGNED_GRANT_EXPIRATION_SEC = 60 * 60 * 24 # One day to complete a push/pull
+ # Temporary tag expiration in seconds, this may actually be longer based on GC policy
+ PUSH_TEMP_TAG_EXPIRATION_SEC = 60 * 60 # One hour per layer
- # Registry v2 JWT Auth config
- REGISTRY_JWT_AUTH_MAX_FRESH_S = 60 * 60 + 60 # At most signed one hour, accounting for clock skew
+ # Signed registry grant token expiration in seconds
+ SIGNED_GRANT_EXPIRATION_SEC = 60 * 60 * 24 # One day to complete a push/pull
- # The URL endpoint to which we redirect OAuth when generating a token locally.
- LOCAL_OAUTH_HANDLER = '/oauth/localapp'
+ # Registry v2 JWT Auth config
+ REGISTRY_JWT_AUTH_MAX_FRESH_S = (
+ 60 * 60 + 60
+ ) # At most signed one hour, accounting for clock skew
- # The various avatar background colors.
- AVATAR_KIND = 'local'
+ # The URL endpoint to which we redirect OAuth when generating a token locally.
+ LOCAL_OAUTH_HANDLER = "/oauth/localapp"
- # Custom branding
- BRANDING = {
- 'logo': '/static/img/quay-horizontal-color.svg',
- 'footer_img': None,
- 'footer_url': None,
- }
+ # The various avatar background colors.
+ AVATAR_KIND = "local"
- # How often the Garbage Collection worker runs.
- GARBAGE_COLLECTION_FREQUENCY = 30 # seconds
+ # Custom branding
+ BRANDING = {
+ "logo": "/static/img/quay-horizontal-color.svg",
+ "footer_img": None,
+ "footer_url": None,
+ }
- # How long notifications will try to send before timing out.
- NOTIFICATION_SEND_TIMEOUT = 10
+ # How often the Garbage Collection worker runs.
+ GARBAGE_COLLECTION_FREQUENCY = 30 # seconds
- # Security scanner
- FEATURE_SECURITY_SCANNER = False
- FEATURE_SECURITY_NOTIFICATIONS = False
+ # How long notifications will try to send before timing out.
+ NOTIFICATION_SEND_TIMEOUT = 10
- # The endpoint for the security scanner.
- SECURITY_SCANNER_ENDPOINT = 'http://192.168.99.101:6060'
+ # Security scanner
+ FEATURE_SECURITY_SCANNER = False
+ FEATURE_SECURITY_NOTIFICATIONS = False
- # The number of seconds between indexing intervals in the security scanner
- SECURITY_SCANNER_INDEXING_INTERVAL = 30
+ # The endpoint for the security scanner.
+ SECURITY_SCANNER_ENDPOINT = "http://192.168.99.101:6060"
- # If specified, the security scanner will only index images newer than the provided ID.
- SECURITY_SCANNER_INDEXING_MIN_ID = None
+ # The number of seconds between indexing intervals in the security scanner
+ SECURITY_SCANNER_INDEXING_INTERVAL = 30
- # If specified, the endpoint to be used for all POST calls to the security scanner.
- SECURITY_SCANNER_ENDPOINT_BATCH = None
+ # If specified, the security scanner will only index images newer than the provided ID.
+ SECURITY_SCANNER_INDEXING_MIN_ID = None
- # If specified, GET requests that return non-200 will be retried at the following instances.
- SECURITY_SCANNER_READONLY_FAILOVER_ENDPOINTS = []
+ # If specified, the endpoint to be used for all POST calls to the security scanner.
+ SECURITY_SCANNER_ENDPOINT_BATCH = None
- # The indexing engine version running inside the security scanner.
- SECURITY_SCANNER_ENGINE_VERSION_TARGET = 3
+ # If specified, GET requests that return non-200 will be retried at the following instances.
+ SECURITY_SCANNER_READONLY_FAILOVER_ENDPOINTS = []
- # The version of the API to use for the security scanner.
- SECURITY_SCANNER_API_VERSION = 'v1'
+ # The indexing engine version running inside the security scanner.
+ SECURITY_SCANNER_ENGINE_VERSION_TARGET = 3
- # API call timeout for the security scanner.
- SECURITY_SCANNER_API_TIMEOUT_SECONDS = 10
+ # The version of the API to use for the security scanner.
+ SECURITY_SCANNER_API_VERSION = "v1"
- # POST call timeout for the security scanner.
- SECURITY_SCANNER_API_TIMEOUT_POST_SECONDS = 480
+ # API call timeout for the security scanner.
+ SECURITY_SCANNER_API_TIMEOUT_SECONDS = 10
- # The issuer name for the security scanner.
- SECURITY_SCANNER_ISSUER_NAME = 'security_scanner'
+ # POST call timeout for the security scanner.
+ SECURITY_SCANNER_API_TIMEOUT_POST_SECONDS = 480
- # Repository mirror
- FEATURE_REPO_MIRROR = False
+ # The issuer name for the security scanner.
+ SECURITY_SCANNER_ISSUER_NAME = "security_scanner"
- # The number of seconds between indexing intervals in the repository mirror
- REPO_MIRROR_INTERVAL = 30
+ # Repository mirror
+ FEATURE_REPO_MIRROR = False
- # Require HTTPS and verify certificates of Quay registry during mirror.
- REPO_MIRROR_TLS_VERIFY = True
+ # The number of seconds between indexing intervals in the repository mirror
+ REPO_MIRROR_INTERVAL = 30
- # Replaces the SERVER_HOSTNAME as the destination for mirroring.
- REPO_MIRROR_SERVER_HOSTNAME = None
+ # Require HTTPS and verify certificates of Quay registry during mirror.
+ REPO_MIRROR_TLS_VERIFY = True
- # JWTProxy Settings
- # The address (sans schema) to proxy outgoing requests through the jwtproxy
- # to be signed
- JWTPROXY_SIGNER = 'localhost:8081'
+ # Replaces the SERVER_HOSTNAME as the destination for mirroring.
+ REPO_MIRROR_SERVER_HOSTNAME = None
- # The audience that jwtproxy should verify on incoming requests
- # If None, will be calculated off of the SERVER_HOSTNAME (default)
- JWTPROXY_AUDIENCE = None
+ # JWTProxy Settings
+ # The address (sans schema) to proxy outgoing requests through the jwtproxy
+ # to be signed
+ JWTPROXY_SIGNER = "localhost:8081"
- # Torrent management flags
- FEATURE_BITTORRENT = False
- BITTORRENT_PIECE_SIZE = 512 * 1024
- BITTORRENT_ANNOUNCE_URL = 'https://localhost:6881/announce'
- BITTORRENT_FILENAME_PEPPER = str(uuid4())
- BITTORRENT_WEBSEED_LIFETIME = 3600
+ # The audience that jwtproxy should verify on incoming requests
+ # If None, will be calculated off of the SERVER_HOSTNAME (default)
+ JWTPROXY_AUDIENCE = None
- # "Secret" key for generating encrypted paging tokens. Only needed to be secret to
- # hide the ID range for production (in which this value is overridden). Should *not*
- # be relied upon for secure encryption otherwise.
- # This value is a Fernet key and should be 32bytes URL-safe base64 encoded.
- PAGE_TOKEN_KEY = '0OYrc16oBuksR8T3JGB-xxYSlZ2-7I_zzqrLzggBJ58='
+ # Torrent management flags
+ FEATURE_BITTORRENT = False
+ BITTORRENT_PIECE_SIZE = 512 * 1024
+ BITTORRENT_ANNOUNCE_URL = "https://localhost:6881/announce"
+ BITTORRENT_FILENAME_PEPPER = str(uuid4())
+ BITTORRENT_WEBSEED_LIFETIME = 3600
- # The timeout for service key approval.
- UNAPPROVED_SERVICE_KEY_TTL_SEC = 60 * 60 * 24 # One day
+ # "Secret" key for generating encrypted paging tokens. Only needed to be secret to
+ # hide the ID range for production (in which this value is overridden). Should *not*
+ # be relied upon for secure encryption otherwise.
+ # This value is a Fernet key and should be 32bytes URL-safe base64 encoded.
+ PAGE_TOKEN_KEY = "0OYrc16oBuksR8T3JGB-xxYSlZ2-7I_zzqrLzggBJ58="
- # How long to wait before GCing an expired service key.
- EXPIRED_SERVICE_KEY_TTL_SEC = 60 * 60 * 24 * 7 # One week
+ # The timeout for service key approval.
+ UNAPPROVED_SERVICE_KEY_TTL_SEC = 60 * 60 * 24 # One day
- # The ID of the user account in the database to be used for service audit logs. If none, the
- # lowest user in the database will be used.
- SERVICE_LOG_ACCOUNT_ID = None
+ # How long to wait before GCing an expired service key.
+ EXPIRED_SERVICE_KEY_TTL_SEC = 60 * 60 * 24 * 7 # One week
- # The service key ID for the instance service.
- # NOTE: If changed, jwtproxy_conf.yaml.jnj must also be updated.
- INSTANCE_SERVICE_KEY_SERVICE = 'quay'
+ # The ID of the user account in the database to be used for service audit logs. If none, the
+ # lowest user in the database will be used.
+ SERVICE_LOG_ACCOUNT_ID = None
- # The location of the key ID file generated for this instance.
- INSTANCE_SERVICE_KEY_KID_LOCATION = os.path.join(CONF_DIR, 'quay.kid')
+ # The service key ID for the instance service.
+ # NOTE: If changed, jwtproxy_conf.yaml.jnj must also be updated.
+ INSTANCE_SERVICE_KEY_SERVICE = "quay"
- # The location of the private key generated for this instance.
- # NOTE: If changed, jwtproxy_conf.yaml.jnj must also be updated.
- INSTANCE_SERVICE_KEY_LOCATION = os.path.join(CONF_DIR, 'quay.pem')
+ # The location of the key ID file generated for this instance.
+ INSTANCE_SERVICE_KEY_KID_LOCATION = os.path.join(CONF_DIR, "quay.kid")
- # This instance's service key expiration in minutes.
- INSTANCE_SERVICE_KEY_EXPIRATION = 120
+ # The location of the private key generated for this instance.
+ # NOTE: If changed, jwtproxy_conf.yaml.jnj must also be updated.
+ INSTANCE_SERVICE_KEY_LOCATION = os.path.join(CONF_DIR, "quay.pem")
- # Number of minutes between expiration refreshes. Should be the expiration / 2 minus
- # some additional window time.
- INSTANCE_SERVICE_KEY_REFRESH = 55
+ # This instance's service key expiration in minutes.
+ INSTANCE_SERVICE_KEY_EXPIRATION = 120
- # The whitelist of client IDs for OAuth applications that allow for direct login.
- DIRECT_OAUTH_CLIENTID_WHITELIST = []
+ # Number of minutes between expiration refreshes. Should be the expiration / 2 minus
+ # some additional window time.
+ INSTANCE_SERVICE_KEY_REFRESH = 55
- # URL that specifies the location of the prometheus stats aggregator.
- PROMETHEUS_AGGREGATOR_URL = 'http://localhost:9092'
+ # The whitelist of client IDs for OAuth applications that allow for direct login.
+ DIRECT_OAUTH_CLIENTID_WHITELIST = []
- # Namespace prefix for all prometheus metrics.
- PROMETHEUS_NAMESPACE = 'quay'
+ # URL that specifies the location of the prometheus stats aggregator.
+ PROMETHEUS_AGGREGATOR_URL = "http://localhost:9092"
- # Overridable list of reverse DNS prefixes that are reserved for internal use on labels.
- LABEL_KEY_RESERVED_PREFIXES = []
+ # Namespace prefix for all prometheus metrics.
+ PROMETHEUS_NAMESPACE = "quay"
- # Delays workers from starting until a random point in time between 0 and their regular interval.
- STAGGER_WORKERS = True
+ # Overridable list of reverse DNS prefixes that are reserved for internal use on labels.
+ LABEL_KEY_RESERVED_PREFIXES = []
- # Location of the static marketing site.
- STATIC_SITE_BUCKET = None
+ # Delays workers from starting until a random point in time between 0 and their regular interval.
+ STAGGER_WORKERS = True
- # Site key and secret key for using recaptcha.
- FEATURE_RECAPTCHA = False
- RECAPTCHA_SITE_KEY = None
- RECAPTCHA_SECRET_KEY = None
+ # Location of the static marketing site.
+ STATIC_SITE_BUCKET = None
- # Server where TUF metadata can be found
- TUF_SERVER = None
+ # Site key and secret key for using recaptcha.
+ FEATURE_RECAPTCHA = False
+ RECAPTCHA_SITE_KEY = None
+ RECAPTCHA_SECRET_KEY = None
- # Prefix to add to metadata e.g. //
- TUF_GUN_PREFIX = None
+ # Server where TUF metadata can be found
+ TUF_SERVER = None
- # Maximum size allowed for layers in the registry.
- MAXIMUM_LAYER_SIZE = '20G'
+ # Prefix to add to metadata e.g. //
+ TUF_GUN_PREFIX = None
- # Feature Flag: Whether team syncing from the backing auth is enabled.
- FEATURE_TEAM_SYNCING = False
- TEAM_RESYNC_STALE_TIME = '30m'
- TEAM_SYNC_WORKER_FREQUENCY = 60 # seconds
+ # Maximum size allowed for layers in the registry.
+ MAXIMUM_LAYER_SIZE = "20G"
- # Feature Flag: If enabled, non-superusers can setup team syncing.
- FEATURE_NONSUPERUSER_TEAM_SYNCING_SETUP = False
+ # Feature Flag: Whether team syncing from the backing auth is enabled.
+ FEATURE_TEAM_SYNCING = False
+ TEAM_RESYNC_STALE_TIME = "30m"
+ TEAM_SYNC_WORKER_FREQUENCY = 60 # seconds
- # The default configurable tag expiration time for time machine.
- DEFAULT_TAG_EXPIRATION = '2w'
+ # Feature Flag: If enabled, non-superusers can setup team syncing.
+ FEATURE_NONSUPERUSER_TEAM_SYNCING_SETUP = False
- # The options to present in namespace settings for the tag expiration. If empty, no option
- # will be given and the default will be displayed read-only.
- TAG_EXPIRATION_OPTIONS = ['0s', '1d', '1w', '2w', '4w']
+ # The default configurable tag expiration time for time machine.
+ DEFAULT_TAG_EXPIRATION = "2w"
- # Feature Flag: Whether users can view and change their tag expiration.
- FEATURE_CHANGE_TAG_EXPIRATION = True
+ # The options to present in namespace settings for the tag expiration. If empty, no option
+ # will be given and the default will be displayed read-only.
+ TAG_EXPIRATION_OPTIONS = ["0s", "1d", "1w", "2w", "4w"]
- # Defines a secret for enabling the health-check endpoint's debug information.
- ENABLE_HEALTH_DEBUG_SECRET = None
+ # Feature Flag: Whether users can view and change their tag expiration.
+ FEATURE_CHANGE_TAG_EXPIRATION = True
- # The lifetime for a user recovery token before it becomes invalid.
- USER_RECOVERY_TOKEN_LIFETIME = '30m'
+ # Defines a secret for enabling the health-check endpoint's debug information.
+ ENABLE_HEALTH_DEBUG_SECRET = None
- # If specified, when app specific passwords expire by default.
- APP_SPECIFIC_TOKEN_EXPIRATION = None
+ # The lifetime for a user recovery token before it becomes invalid.
+ USER_RECOVERY_TOKEN_LIFETIME = "30m"
- # Feature Flag: If enabled, users can create and use app specific tokens to login via the CLI.
- FEATURE_APP_SPECIFIC_TOKENS = True
+ # If specified, when app specific passwords expire by default.
+ APP_SPECIFIC_TOKEN_EXPIRATION = None
- # How long expired app specific tokens should remain visible to users before being automatically
- # deleted. Set to None to turn off garbage collection.
- EXPIRED_APP_SPECIFIC_TOKEN_GC = '1d'
+ # Feature Flag: If enabled, users can create and use app specific tokens to login via the CLI.
+ FEATURE_APP_SPECIFIC_TOKENS = True
- # The size of pages returned by the Docker V2 API.
- V2_PAGINATION_SIZE = 50
+ # How long expired app specific tokens should remain visible to users before being automatically
+ # deleted. Set to None to turn off garbage collection.
+ EXPIRED_APP_SPECIFIC_TOKEN_GC = "1d"
- # If enabled, ensures that API calls are made with the X-Requested-With header
- # when called from a browser.
- BROWSER_API_CALLS_XHR_ONLY = True
+ # The size of pages returned by the Docker V2 API.
+ V2_PAGINATION_SIZE = 50
- # If set to a non-None integer value, the default number of maximum builds for a namespace.
- DEFAULT_NAMESPACE_MAXIMUM_BUILD_COUNT = None
+ # If enabled, ensures that API calls are made with the X-Requested-With header
+ # when called from a browser.
+ BROWSER_API_CALLS_XHR_ONLY = True
- # If set to a non-None integer value, the default number of maximum builds for a namespace whose
- # creator IP is deemed a threat.
- THREAT_NAMESPACE_MAXIMUM_BUILD_COUNT = None
+ # If set to a non-None integer value, the default number of maximum builds for a namespace.
+ DEFAULT_NAMESPACE_MAXIMUM_BUILD_COUNT = None
- # The API Key to use when requesting IP information.
- IP_DATA_API_KEY = None
+ # If set to a non-None integer value, the default number of maximum builds for a namespace whose
+ # creator IP is deemed a threat.
+ THREAT_NAMESPACE_MAXIMUM_BUILD_COUNT = None
- # For Billing Support Only: The number of allowed builds on a namespace that has been billed
- # successfully.
- BILLED_NAMESPACE_MAXIMUM_BUILD_COUNT = None
+ # The API Key to use when requesting IP information.
+ IP_DATA_API_KEY = None
- # Configuration for the data model cache.
- DATA_MODEL_CACHE_CONFIG = {
- 'engine': 'memcached',
- 'endpoint': ('127.0.0.1', 18080),
- }
+ # For Billing Support Only: The number of allowed builds on a namespace that has been billed
+ # successfully.
+ BILLED_NAMESPACE_MAXIMUM_BUILD_COUNT = None
- # Defines the number of successive failures of a build trigger's build before the trigger is
- # automatically disabled.
- SUCCESSIVE_TRIGGER_FAILURE_DISABLE_THRESHOLD = 100
+ # Configuration for the data model cache.
+ DATA_MODEL_CACHE_CONFIG = {"engine": "memcached", "endpoint": ("127.0.0.1", 18080)}
- # Defines the number of successive internal errors of a build trigger's build before the
- # trigger is automatically disabled.
- SUCCESSIVE_TRIGGER_INTERNAL_ERROR_DISABLE_THRESHOLD = 5
+ # Defines the number of successive failures of a build trigger's build before the trigger is
+ # automatically disabled.
+ SUCCESSIVE_TRIGGER_FAILURE_DISABLE_THRESHOLD = 100
- # Defines the delay required (in seconds) before the last_accessed field of a user/robot or access
- # token will be updated after the previous update.
- LAST_ACCESSED_UPDATE_THRESHOLD_S = 60
+ # Defines the number of successive internal errors of a build trigger's build before the
+ # trigger is automatically disabled.
+ SUCCESSIVE_TRIGGER_INTERNAL_ERROR_DISABLE_THRESHOLD = 5
- # Defines the number of results per page used to show search results
- SEARCH_RESULTS_PER_PAGE = 10
+ # Defines the delay required (in seconds) before the last_accessed field of a user/robot or access
+ # token will be updated after the previous update.
+ LAST_ACCESSED_UPDATE_THRESHOLD_S = 60
- # Defines the maximum number of pages the user can paginate before they are limited
- SEARCH_MAX_RESULT_PAGE_COUNT = 10
+ # Defines the number of results per page used to show search results
+ SEARCH_RESULTS_PER_PAGE = 10
- # Feature Flag: Whether to record when users were last accessed.
- FEATURE_USER_LAST_ACCESSED = True
+ # Defines the maximum number of pages the user can paginate before they are limited
+ SEARCH_MAX_RESULT_PAGE_COUNT = 10
- # Feature Flag: Whether to allow users to retrieve aggregated log counts.
- FEATURE_AGGREGATED_LOG_COUNT_RETRIEVAL = True
+ # Feature Flag: Whether to record when users were last accessed.
+ FEATURE_USER_LAST_ACCESSED = True
- # Feature Flag: Whether rate limiting is enabled.
- FEATURE_RATE_LIMITS = False
+ # Feature Flag: Whether to allow users to retrieve aggregated log counts.
+ FEATURE_AGGREGATED_LOG_COUNT_RETRIEVAL = True
- # Feature Flag: Whether to support log exporting.
- FEATURE_LOG_EXPORT = True
+ # Feature Flag: Whether rate limiting is enabled.
+ FEATURE_RATE_LIMITS = False
- # Maximum number of action logs pages that can be returned via the API.
- ACTION_LOG_MAX_PAGE = None
+ # Feature Flag: Whether to support log exporting.
+ FEATURE_LOG_EXPORT = True
- # Log model
- LOGS_MODEL = 'database'
- LOGS_MODEL_CONFIG = {}
+ # Maximum number of action logs pages that can be returned via the API.
+ ACTION_LOG_MAX_PAGE = None
- # Namespace in which all audit logging is disabled.
- DISABLED_FOR_AUDIT_LOGS = []
+ # Log model
+ LOGS_MODEL = "database"
+ LOGS_MODEL_CONFIG = {}
- # Namespace in which pull audit logging is disabled.
- DISABLED_FOR_PULL_LOGS = []
+ # Namespace in which all audit logging is disabled.
+ DISABLED_FOR_AUDIT_LOGS = []
- # Feature Flag: Whether pull logs are disabled for free namespace.
- FEATURE_DISABLE_PULL_LOGS_FOR_FREE_NAMESPACES = False
+ # Namespace in which pull audit logging is disabled.
+ DISABLED_FOR_PULL_LOGS = []
- # Feature Flag: If set to true, no account using blacklisted email addresses will be allowed
- # to be created.
- FEATURE_BLACKLISTED_EMAILS = False
+ # Feature Flag: Whether pull logs are disabled for free namespace.
+ FEATURE_DISABLE_PULL_LOGS_FOR_FREE_NAMESPACES = False
- # The list of domains, including subdomains, for which any *new* User with a matching
- # email address will be denied creation. This option is only used if
- # FEATURE_BLACKLISTED_EMAILS is enabled.
- BLACKLISTED_EMAIL_DOMAINS = []
+ # Feature Flag: If set to true, no account using blacklisted email addresses will be allowed
+ # to be created.
+ FEATURE_BLACKLISTED_EMAILS = False
- # Feature Flag: Whether garbage collection is enabled.
- FEATURE_GARBAGE_COLLECTION = True
+ # The list of domains, including subdomains, for which any *new* User with a matching
+ # email address will be denied creation. This option is only used if
+ # FEATURE_BLACKLISTED_EMAILS is enabled.
+ BLACKLISTED_EMAIL_DOMAINS = []
+
+ # Feature Flag: Whether garbage collection is enabled.
+ FEATURE_GARBAGE_COLLECTION = True
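The DefaultConfig block above notes that PAGE_TOKEN_KEY is a Fernet key, 32 bytes and URL-safe base64-encoded. Below is a minimal sketch of that key format using the cryptography package; the helper names (generate_page_token_key, encode_page_token, decode_page_token) are invented for illustration and are not Quay's paging-token code.

# Illustration only: generating and using a Fernet key of the kind PAGE_TOKEN_KEY expects.
import json

from cryptography.fernet import Fernet


def generate_page_token_key():
    # Returns a 32-byte, URL-safe base64-encoded key, matching the comment above.
    return Fernet.generate_key()


def encode_page_token(key, payload):
    # Hypothetical helper: encrypt a small paging payload into an opaque token.
    return Fernet(key).encrypt(json.dumps(payload).encode("utf-8"))


def decode_page_token(key, token):
    return json.loads(Fernet(key).decrypt(token).decode("utf-8"))


if __name__ == "__main__":
    key = generate_page_token_key()
    token = encode_page_token(key, {"start_id": 1234})
    assert decode_page_token(key, token) == {"start_id": 1234}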
diff --git a/config_app/_init_config.py b/config_app/_init_config.py
index 8b0533570..bd2eb5826 100644
--- a/config_app/_init_config.py
+++ b/config_app/_init_config.py
@@ -7,31 +7,31 @@ import subprocess
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
CONF_DIR = os.getenv("QUAYCONF", os.path.join(ROOT_DIR, "conf/"))
-STATIC_DIR = os.path.join(ROOT_DIR, 'static/')
-STATIC_LDN_DIR = os.path.join(STATIC_DIR, 'ldn/')
-STATIC_FONTS_DIR = os.path.join(STATIC_DIR, 'fonts/')
-TEMPLATE_DIR = os.path.join(ROOT_DIR, 'templates/')
-IS_KUBERNETES = 'KUBERNETES_SERVICE_HOST' in os.environ
+STATIC_DIR = os.path.join(ROOT_DIR, "static/")
+STATIC_LDN_DIR = os.path.join(STATIC_DIR, "ldn/")
+STATIC_FONTS_DIR = os.path.join(STATIC_DIR, "fonts/")
+TEMPLATE_DIR = os.path.join(ROOT_DIR, "templates/")
+IS_KUBERNETES = "KUBERNETES_SERVICE_HOST" in os.environ
def _get_version_number_changelog():
- try:
- with open(os.path.join(ROOT_DIR, 'CHANGELOG.md')) as f:
- return re.search(r'(v[0-9]+\.[0-9]+\.[0-9]+)', f.readline()).group(0)
- except IOError:
- return ''
+ try:
+ with open(os.path.join(ROOT_DIR, "CHANGELOG.md")) as f:
+ return re.search(r"(v[0-9]+\.[0-9]+\.[0-9]+)", f.readline()).group(0)
+ except IOError:
+ return ""
def _get_git_sha():
- if os.path.exists("GIT_HEAD"):
- with open(os.path.join(ROOT_DIR, "GIT_HEAD")) as f:
- return f.read()
- else:
- try:
- return subprocess.check_output(["git", "rev-parse", "HEAD"]).strip()[0:8]
- except (OSError, subprocess.CalledProcessError):
- pass
- return "unknown"
+ if os.path.exists("GIT_HEAD"):
+ with open(os.path.join(ROOT_DIR, "GIT_HEAD")) as f:
+ return f.read()
+ else:
+ try:
+ return subprocess.check_output(["git", "rev-parse", "HEAD"]).strip()[0:8]
+ except (OSError, subprocess.CalledProcessError):
+ pass
+ return "unknown"
__version__ = _get_version_number_changelog()
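_get_version_number_changelog() and _get_git_sha() above derive version metadata at import time; the latter shells out to git and keeps the first eight characters of the SHA. A hedged sketch of the same lookup with explicit text decoding (subprocess.check_output() returns bytes on Python 3), shown only as an illustration, not as a change to this module:

import subprocess


def short_git_sha(length=8):
    # Same idea as _get_git_sha(): ask git for HEAD and keep a short prefix.
    try:
        out = subprocess.check_output(["git", "rev-parse", "HEAD"])
        return out.decode("utf-8").strip()[:length]
    except (OSError, subprocess.CalledProcessError):
        return "unknown"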
diff --git a/config_app/c_app.py b/config_app/c_app.py
index 0df198dd1..38847d6c9 100644
--- a/config_app/c_app.py
+++ b/config_app/c_app.py
@@ -15,28 +15,29 @@ app = Flask(__name__)
logger = logging.getLogger(__name__)
-OVERRIDE_CONFIG_DIRECTORY = os.path.join(ROOT_DIR, 'config_app/conf/stack')
-INIT_SCRIPTS_LOCATION = '/conf/init/'
+OVERRIDE_CONFIG_DIRECTORY = os.path.join(ROOT_DIR, "config_app/conf/stack")
+INIT_SCRIPTS_LOCATION = "/conf/init/"
-is_testing = 'TEST' in os.environ
+is_testing = "TEST" in os.environ
is_kubernetes = IS_KUBERNETES
-logger.debug('Configuration is on a kubernetes deployment: %s' % IS_KUBERNETES)
+logger.debug("Configuration is on a kubernetes deployment: %s" % IS_KUBERNETES)
-config_provider = get_config_provider(OVERRIDE_CONFIG_DIRECTORY, 'config.yaml', 'config.py',
- testing=is_testing)
+config_provider = get_config_provider(
+ OVERRIDE_CONFIG_DIRECTORY, "config.yaml", "config.py", testing=is_testing
+)
if is_testing:
- from test.testconfig import TestConfig
+ from test.testconfig import TestConfig
- logger.debug('Loading test config.')
- app.config.from_object(TestConfig())
+ logger.debug("Loading test config.")
+ app.config.from_object(TestConfig())
else:
- from config import DefaultConfig
+ from config import DefaultConfig
- logger.debug('Loading default config.')
- app.config.from_object(DefaultConfig())
- app.teardown_request(database.close_db_filter)
+ logger.debug("Loading default config.")
+ app.config.from_object(DefaultConfig())
+ app.teardown_request(database.close_db_filter)
# Load the override config via the provider.
config_provider.update_app_config(app.config)
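c_app.py layers its configuration: defaults come from a config class via app.config.from_object(), and the stack overrides are merged on top by config_provider.update_app_config(). A minimal sketch of that layering, with a plain dict standing in for the provider; the class and override values here are made up for illustration:

from flask import Flask


class SketchDefaults(object):
    # Stand-in defaults; Quay's real DefaultConfig is far larger.
    TESTING = False
    PREFERRED_URL_SCHEME = "http"


def apply_overrides(app_config, overrides):
    # Plays the role of config_provider.update_app_config(), which reads conf/stack.
    app_config.update(overrides)


app = Flask(__name__)
app.config.from_object(SketchDefaults())
apply_overrides(app.config, {"PREFERRED_URL_SCHEME": "https"})
assert app.config["PREFERRED_URL_SCHEME"] == "https"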
diff --git a/config_app/conf/gunicorn_local.py b/config_app/conf/gunicorn_local.py
index d0ea0a758..add20b457 100644
--- a/config_app/conf/gunicorn_local.py
+++ b/config_app/conf/gunicorn_local.py
@@ -1,5 +1,6 @@
import sys
import os
+
sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging
@@ -9,18 +10,24 @@ from config_app.config_util.log import logfile_path
logconfig = logfile_path(debug=True)
-bind = '0.0.0.0:5000'
+bind = "0.0.0.0:5000"
workers = 1
-worker_class = 'gevent'
+worker_class = "gevent"
daemon = False
-pythonpath = '.'
+pythonpath = "."
preload_app = True
+
def post_fork(server, worker):
- # Reset the Random library to ensure it won't raise the "PID check failed." error after
- # gunicorn forks.
- Random.atfork()
+ # Reset the Random library to ensure it won't raise the "PID check failed." error after
+ # gunicorn forks.
+ Random.atfork()
+
def when_ready(server):
- logger = logging.getLogger(__name__)
- logger.debug('Starting local gunicorn with %s workers and %s worker class', workers, worker_class)
+ logger = logging.getLogger(__name__)
+ logger.debug(
+ "Starting local gunicorn with %s workers and %s worker class",
+ workers,
+ worker_class,
+ )
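gunicorn_local.py is a gunicorn settings module: module-level names (bind, workers, worker_class, preload_app) plus the post_fork and when_ready server hooks. A pared-down sketch in the same shape, omitting the PyCrypto Random.atfork() reset; the app path in the comment is taken from config_app/config_application.py in this diff:

# gunicorn_minimal.py -- run with:
#   gunicorn -c gunicorn_minimal.py config_app.config_application:application
import logging

bind = "0.0.0.0:5000"
workers = 1
worker_class = "gevent"
preload_app = True


def post_fork(server, worker):
    # Per-worker setup after fork; the real config uses this hook for Random.atfork().
    logging.getLogger(__name__).debug("worker %s forked", worker.pid)


def when_ready(server):
    logging.getLogger(__name__).debug(
        "gunicorn ready with %s %s worker(s)", workers, worker_class
    )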
diff --git a/config_app/conf/gunicorn_web.py b/config_app/conf/gunicorn_web.py
index 14225fe72..107d8c395 100644
--- a/config_app/conf/gunicorn_web.py
+++ b/config_app/conf/gunicorn_web.py
@@ -1,5 +1,6 @@
import sys
import os
+
sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
import logging
@@ -10,17 +11,23 @@ from config_app.config_util.log import logfile_path
logconfig = logfile_path(debug=True)
-bind = 'unix:/tmp/gunicorn_web.sock'
+bind = "unix:/tmp/gunicorn_web.sock"
workers = 1
-worker_class = 'gevent'
-pythonpath = '.'
+worker_class = "gevent"
+pythonpath = "."
preload_app = True
+
def post_fork(server, worker):
- # Reset the Random library to ensure it won't raise the "PID check failed." error after
- # gunicorn forks.
- Random.atfork()
+ # Reset the Random library to ensure it won't raise the "PID check failed." error after
+ # gunicorn forks.
+ Random.atfork()
+
def when_ready(server):
- logger = logging.getLogger(__name__)
- logger.debug('Starting local gunicorn with %s workers and %s worker class', workers, worker_class)
+ logger = logging.getLogger(__name__)
+ logger.debug(
+ "Starting local gunicorn with %s workers and %s worker class",
+ workers,
+ worker_class,
+ )
diff --git a/config_app/config_application.py b/config_app/config_application.py
index 43676e354..a6e5d9fa3 100644
--- a/config_app/config_application.py
+++ b/config_app/config_application.py
@@ -3,6 +3,6 @@ from config_app.c_app import app as application
# Bind all of the blueprints
import config_web
-if __name__ == '__main__':
- logging.config.fileConfig(logfile_path(debug=True), disable_existing_loggers=False)
- application.run(port=5000, debug=True, threaded=True, host='0.0.0.0')
+if __name__ == "__main__":
+ logging.config.fileConfig(logfile_path(debug=True), disable_existing_loggers=False)
+ application.run(port=5000, debug=True, threaded=True, host="0.0.0.0")
diff --git a/config_app/config_endpoints/api/__init__.py b/config_app/config_endpoints/api/__init__.py
index c80fc1c9c..0620fed63 100644
--- a/config_app/config_endpoints/api/__init__.py
+++ b/config_app/config_endpoints/api/__init__.py
@@ -13,141 +13,153 @@ from config_app.c_app import app, IS_KUBERNETES
from config_app.config_endpoints.exception import InvalidResponse, InvalidRequest
logger = logging.getLogger(__name__)
-api_bp = Blueprint('api', __name__)
+api_bp = Blueprint("api", __name__)
-CROSS_DOMAIN_HEADERS = ['Authorization', 'Content-Type', 'X-Requested-With']
+CROSS_DOMAIN_HEADERS = ["Authorization", "Content-Type", "X-Requested-With"]
class ApiExceptionHandlingApi(Api):
- pass
+ pass
- @crossdomain(origin='*', headers=CROSS_DOMAIN_HEADERS)
- def handle_error(self, error):
- return super(ApiExceptionHandlingApi, self).handle_error(error)
+ @crossdomain(origin="*", headers=CROSS_DOMAIN_HEADERS)
+ def handle_error(self, error):
+ return super(ApiExceptionHandlingApi, self).handle_error(error)
api = ApiExceptionHandlingApi()
api.init_app(api_bp)
+
def log_action(kind, user_or_orgname, metadata=None, repo=None, repo_name=None):
- if not metadata:
- metadata = {}
+ if not metadata:
+ metadata = {}
- if repo:
- repo_name = repo.name
+ if repo:
+ repo_name = repo.name
+
+ model.log.log_action(
+ kind, user_or_orgname, repo_name, user_or_orgname, request.remote_addr, metadata
+ )
- model.log.log_action(kind, user_or_orgname, repo_name, user_or_orgname, request.remote_addr, metadata)
def format_date(date):
- """ Output an RFC822 date format. """
- if date is None:
- return None
- return formatdate(timegm(date.utctimetuple()))
-
+ """ Output an RFC822 date format. """
+ if date is None:
+ return None
+ return formatdate(timegm(date.utctimetuple()))
def resource(*urls, **kwargs):
- def wrapper(api_resource):
- if not api_resource:
- return None
+ def wrapper(api_resource):
+ if not api_resource:
+ return None
- api_resource.registered = True
- api.add_resource(api_resource, *urls, **kwargs)
- return api_resource
+ api_resource.registered = True
+ api.add_resource(api_resource, *urls, **kwargs)
+ return api_resource
- return wrapper
+ return wrapper
class ApiResource(Resource):
- registered = False
- method_decorators = []
+ registered = False
+ method_decorators = []
- def options(self):
- return None, 200
+ def options(self):
+ return None, 200
def add_method_metadata(name, value):
- def modifier(func):
- if func is None:
- return None
+ def modifier(func):
+ if func is None:
+ return None
- if '__api_metadata' not in dir(func):
- func.__api_metadata = {}
- func.__api_metadata[name] = value
- return func
+ if "__api_metadata" not in dir(func):
+ func.__api_metadata = {}
+ func.__api_metadata[name] = value
+ return func
- return modifier
+ return modifier
def method_metadata(func, name):
- if func is None:
- return None
+ if func is None:
+ return None
- if '__api_metadata' in dir(func):
- return func.__api_metadata.get(name, None)
- return None
+ if "__api_metadata" in dir(func):
+ return func.__api_metadata.get(name, None)
+ return None
def no_cache(f):
- @wraps(f)
- def add_no_cache(*args, **kwargs):
- response = f(*args, **kwargs)
- if response is not None:
- response.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate'
- return response
- return add_no_cache
+ @wraps(f)
+ def add_no_cache(*args, **kwargs):
+ response = f(*args, **kwargs)
+ if response is not None:
+ response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
+ return response
+
+ return add_no_cache
def define_json_response(schema_name):
- def wrapper(func):
- @add_method_metadata('response_schema', schema_name)
- @wraps(func)
- def wrapped(self, *args, **kwargs):
- schema = self.schemas[schema_name]
- resp = func(self, *args, **kwargs)
+ def wrapper(func):
+ @add_method_metadata("response_schema", schema_name)
+ @wraps(func)
+ def wrapped(self, *args, **kwargs):
+ schema = self.schemas[schema_name]
+ resp = func(self, *args, **kwargs)
- if app.config['TESTING']:
- try:
- validate(resp, schema)
- except ValidationError as ex:
- raise InvalidResponse(ex.message)
+ if app.config["TESTING"]:
+ try:
+ validate(resp, schema)
+ except ValidationError as ex:
+ raise InvalidResponse(ex.message)
- return resp
- return wrapped
- return wrapper
+ return resp
+
+ return wrapped
+
+ return wrapper
def validate_json_request(schema_name, optional=False):
- def wrapper(func):
- @add_method_metadata('request_schema', schema_name)
- @wraps(func)
- def wrapped(self, *args, **kwargs):
- schema = self.schemas[schema_name]
- try:
- json_data = request.get_json()
- if json_data is None:
- if not optional:
- raise InvalidRequest('Missing JSON body')
- else:
- validate(json_data, schema)
- return func(self, *args, **kwargs)
- except ValidationError as ex:
- raise InvalidRequest(ex.message)
- return wrapped
- return wrapper
+ def wrapper(func):
+ @add_method_metadata("request_schema", schema_name)
+ @wraps(func)
+ def wrapped(self, *args, **kwargs):
+ schema = self.schemas[schema_name]
+ try:
+ json_data = request.get_json()
+ if json_data is None:
+ if not optional:
+ raise InvalidRequest("Missing JSON body")
+ else:
+ validate(json_data, schema)
+ return func(self, *args, **kwargs)
+ except ValidationError as ex:
+ raise InvalidRequest(ex.message)
+
+ return wrapped
+
+ return wrapper
+
def kubernetes_only(f):
- """ Aborts the request with a 400 if the app is not running on kubernetes """
- @wraps(f)
- def abort_if_not_kube(*args, **kwargs):
- if not IS_KUBERNETES:
- abort(400)
+ """ Aborts the request with a 400 if the app is not running on kubernetes """
- return f(*args, **kwargs)
- return abort_if_not_kube
+ @wraps(f)
+ def abort_if_not_kube(*args, **kwargs):
+ if not IS_KUBERNETES:
+ abort(400)
-nickname = partial(add_method_metadata, 'nickname')
+ return f(*args, **kwargs)
+
+ return abort_if_not_kube
+
+
+nickname = partial(add_method_metadata, "nickname")
import config_app.config_endpoints.api.discovery
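The module above builds its API surface from small decorators: add_method_metadata() tags view methods, method_metadata() reads the tags back, and nickname is just partial(add_method_metadata, "nickname"). A standalone sketch of that metadata pattern, reduced to plain functions outside of Flask-RESTful:

from functools import partial


def add_method_metadata(name, value):
    def modifier(func):
        if func is None:
            return None
        if "__api_metadata" not in dir(func):
            func.__api_metadata = {}
        func.__api_metadata[name] = value
        return func

    return modifier


def method_metadata(func, name):
    if func is None or "__api_metadata" not in dir(func):
        return None
    return func.__api_metadata.get(name)


nickname = partial(add_method_metadata, "nickname")


@nickname("scExample")
def get_example():
    return {"ok": True}


assert method_metadata(get_example, "nickname") == "scExample"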
diff --git a/config_app/config_endpoints/api/discovery.py b/config_app/config_endpoints/api/discovery.py
index 183963ea3..eef5536e5 100644
--- a/config_app/config_endpoints/api/discovery.py
+++ b/config_app/config_endpoints/api/discovery.py
@@ -5,250 +5,253 @@ from collections import OrderedDict
from config_app.c_app import app
from config_app.config_endpoints.api import method_metadata
-from config_app.config_endpoints.common import fully_qualified_name, PARAM_REGEX, TYPE_CONVERTER
+from config_app.config_endpoints.common import (
+ fully_qualified_name,
+ PARAM_REGEX,
+ TYPE_CONVERTER,
+)
logger = logging.getLogger(__name__)
def generate_route_data():
- include_internal = True
- compact = True
+ include_internal = True
+ compact = True
- def swagger_parameter(name, description, kind='path', param_type='string', required=True,
- enum=None, schema=None):
- # https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#parameterObject
- parameter_info = {
- 'name': name,
- 'in': kind,
- 'required': required
- }
+ def swagger_parameter(
+ name,
+ description,
+ kind="path",
+ param_type="string",
+ required=True,
+ enum=None,
+ schema=None,
+ ):
+ # https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#parameterObject
+ parameter_info = {"name": name, "in": kind, "required": required}
- if schema:
- parameter_info['schema'] = {
- '$ref': '#/definitions/%s' % schema
- }
- else:
- parameter_info['type'] = param_type
-
- if enum is not None and len(list(enum)) > 0:
- parameter_info['enum'] = list(enum)
-
- return parameter_info
-
- paths = {}
- models = {}
- tags = []
- tags_added = set()
- operation_ids = set()
-
- for rule in app.url_map.iter_rules():
- endpoint_method = app.view_functions[rule.endpoint]
-
- # Verify that we have a view class for this API method.
- if not 'view_class' in dir(endpoint_method):
- continue
-
- view_class = endpoint_method.view_class
-
- # Hide the class if it is internal.
- internal = method_metadata(view_class, 'internal')
- if not include_internal and internal:
- continue
-
- # Build the tag.
- parts = fully_qualified_name(view_class).split('.')
- tag_name = parts[-2]
- if not tag_name in tags_added:
- tags_added.add(tag_name)
- tags.append({
- 'name': tag_name,
- 'description': (sys.modules[view_class.__module__].__doc__ or '').strip()
- })
-
- # Build the Swagger data for the path.
- swagger_path = PARAM_REGEX.sub(r'{\2}', rule.rule)
- full_name = fully_qualified_name(view_class)
- path_swagger = {
- 'x-name': full_name,
- 'x-path': swagger_path,
- 'x-tag': tag_name
- }
-
- related_user_res = method_metadata(view_class, 'related_user_resource')
- if related_user_res is not None:
- path_swagger['x-user-related'] = fully_qualified_name(related_user_res)
-
- paths[swagger_path] = path_swagger
-
- # Add any global path parameters.
- param_data_map = view_class.__api_path_params if '__api_path_params' in dir(
- view_class) else {}
- if param_data_map:
- path_parameters_swagger = []
- for path_parameter in param_data_map:
- description = param_data_map[path_parameter].get('description')
- path_parameters_swagger.append(swagger_parameter(path_parameter, description))
-
- path_swagger['parameters'] = path_parameters_swagger
-
- # Add the individual HTTP operations.
- method_names = list(rule.methods.difference(['HEAD', 'OPTIONS']))
- for method_name in method_names:
- # https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#operation-object
- method = getattr(view_class, method_name.lower(), None)
- if method is None:
- logger.debug('Unable to find method for %s in class %s', method_name, view_class)
- continue
-
- operationId = method_metadata(method, 'nickname')
- operation_swagger = {
- 'operationId': operationId,
- 'parameters': [],
- }
-
- if operationId is None:
- continue
-
- if operationId in operation_ids:
- raise Exception('Duplicate operation Id: %s' % operationId)
-
- operation_ids.add(operationId)
-
- # Mark the method as internal.
- internal = method_metadata(method, 'internal')
- if internal is not None:
- operation_swagger['x-internal'] = True
-
- if include_internal:
- requires_fresh_login = method_metadata(method, 'requires_fresh_login')
- if requires_fresh_login is not None:
- operation_swagger['x-requires-fresh-login'] = True
-
- # Add the path parameters.
- if rule.arguments:
- for path_parameter in rule.arguments:
- description = param_data_map.get(path_parameter, {}).get('description')
- operation_swagger['parameters'].append(
- swagger_parameter(path_parameter, description))
-
- # Add the query parameters.
- if '__api_query_params' in dir(method):
- for query_parameter_info in method.__api_query_params:
- name = query_parameter_info['name']
- description = query_parameter_info['help']
- param_type = TYPE_CONVERTER[query_parameter_info['type']]
- required = query_parameter_info['required']
-
- operation_swagger['parameters'].append(
- swagger_parameter(name, description, kind='query',
- param_type=param_type,
- required=required,
- enum=query_parameter_info['choices']))
-
- # Add the OAuth security block.
- # https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#securityRequirementObject
- scope = method_metadata(method, 'oauth2_scope')
- if scope and not compact:
- operation_swagger['security'] = [{'oauth2_implicit': [scope.scope]}]
-
- # Add the responses block.
- # https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#responsesObject
- response_schema_name = method_metadata(method, 'response_schema')
- if not compact:
- if response_schema_name:
- models[response_schema_name] = view_class.schemas[response_schema_name]
-
- models['ApiError'] = {
- 'type': 'object',
- 'properties': {
- 'status': {
- 'type': 'integer',
- 'description': 'Status code of the response.'
- },
- 'type': {
- 'type': 'string',
- 'description': 'Reference to the type of the error.'
- },
- 'detail': {
- 'type': 'string',
- 'description': 'Details about the specific instance of the error.'
- },
- 'title': {
- 'type': 'string',
- 'description': 'Unique error code to identify the type of error.'
- },
- 'error_message': {
- 'type': 'string',
- 'description': 'Deprecated; alias for detail'
- },
- 'error_type': {
- 'type': 'string',
- 'description': 'Deprecated; alias for detail'
- }
- },
- 'required': [
- 'status',
- 'type',
- 'title',
- ]
- }
-
- responses = {
- '400': {
- 'description': 'Bad Request',
- },
-
- '401': {
- 'description': 'Session required',
- },
-
- '403': {
- 'description': 'Unauthorized access',
- },
-
- '404': {
- 'description': 'Not found',
- },
- }
-
- for _, body in responses.items():
- body['schema'] = {'$ref': '#/definitions/ApiError'}
-
- if method_name == 'DELETE':
- responses['204'] = {
- 'description': 'Deleted'
- }
- elif method_name == 'POST':
- responses['201'] = {
- 'description': 'Successful creation'
- }
+ if schema:
+ parameter_info["schema"] = {"$ref": "#/definitions/%s" % schema}
else:
- responses['200'] = {
- 'description': 'Successful invocation'
- }
+ parameter_info["type"] = param_type
- if response_schema_name:
- responses['200']['schema'] = {
- '$ref': '#/definitions/%s' % response_schema_name
- }
+ if enum is not None and len(list(enum)) > 0:
+ parameter_info["enum"] = list(enum)
- operation_swagger['responses'] = responses
+ return parameter_info
- # Add the request block.
- request_schema_name = method_metadata(method, 'request_schema')
- if request_schema_name and not compact:
- models[request_schema_name] = view_class.schemas[request_schema_name]
+ paths = {}
+ models = {}
+ tags = []
+ tags_added = set()
+ operation_ids = set()
- operation_swagger['parameters'].append(
- swagger_parameter('body', 'Request body contents.', kind='body',
- schema=request_schema_name))
+ for rule in app.url_map.iter_rules():
+ endpoint_method = app.view_functions[rule.endpoint]
- # Add the operation to the parent path.
- if not internal or (internal and include_internal):
- path_swagger[method_name.lower()] = operation_swagger
+ # Verify that we have a view class for this API method.
+ if not "view_class" in dir(endpoint_method):
+ continue
- tags.sort(key=lambda t: t['name'])
- paths = OrderedDict(sorted(paths.items(), key=lambda p: p[1]['x-tag']))
+ view_class = endpoint_method.view_class
- if compact:
- return {'paths': paths}
+ # Hide the class if it is internal.
+ internal = method_metadata(view_class, "internal")
+ if not include_internal and internal:
+ continue
+
+ # Build the tag.
+ parts = fully_qualified_name(view_class).split(".")
+ tag_name = parts[-2]
+ if not tag_name in tags_added:
+ tags_added.add(tag_name)
+ tags.append(
+ {
+ "name": tag_name,
+ "description": (
+ sys.modules[view_class.__module__].__doc__ or ""
+ ).strip(),
+ }
+ )
+
+ # Build the Swagger data for the path.
+ swagger_path = PARAM_REGEX.sub(r"{\2}", rule.rule)
+ full_name = fully_qualified_name(view_class)
+ path_swagger = {"x-name": full_name, "x-path": swagger_path, "x-tag": tag_name}
+
+ related_user_res = method_metadata(view_class, "related_user_resource")
+ if related_user_res is not None:
+ path_swagger["x-user-related"] = fully_qualified_name(related_user_res)
+
+ paths[swagger_path] = path_swagger
+
+ # Add any global path parameters.
+ param_data_map = (
+ view_class.__api_path_params
+ if "__api_path_params" in dir(view_class)
+ else {}
+ )
+ if param_data_map:
+ path_parameters_swagger = []
+ for path_parameter in param_data_map:
+ description = param_data_map[path_parameter].get("description")
+ path_parameters_swagger.append(
+ swagger_parameter(path_parameter, description)
+ )
+
+ path_swagger["parameters"] = path_parameters_swagger
+
+ # Add the individual HTTP operations.
+ method_names = list(rule.methods.difference(["HEAD", "OPTIONS"]))
+ for method_name in method_names:
+ # https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#operation-object
+ method = getattr(view_class, method_name.lower(), None)
+ if method is None:
+ logger.debug(
+ "Unable to find method for %s in class %s", method_name, view_class
+ )
+ continue
+
+ operationId = method_metadata(method, "nickname")
+ operation_swagger = {"operationId": operationId, "parameters": []}
+
+ if operationId is None:
+ continue
+
+ if operationId in operation_ids:
+ raise Exception("Duplicate operation Id: %s" % operationId)
+
+ operation_ids.add(operationId)
+
+ # Mark the method as internal.
+ internal = method_metadata(method, "internal")
+ if internal is not None:
+ operation_swagger["x-internal"] = True
+
+ if include_internal:
+ requires_fresh_login = method_metadata(method, "requires_fresh_login")
+ if requires_fresh_login is not None:
+ operation_swagger["x-requires-fresh-login"] = True
+
+ # Add the path parameters.
+ if rule.arguments:
+ for path_parameter in rule.arguments:
+ description = param_data_map.get(path_parameter, {}).get(
+ "description"
+ )
+ operation_swagger["parameters"].append(
+ swagger_parameter(path_parameter, description)
+ )
+
+ # Add the query parameters.
+ if "__api_query_params" in dir(method):
+ for query_parameter_info in method.__api_query_params:
+ name = query_parameter_info["name"]
+ description = query_parameter_info["help"]
+ param_type = TYPE_CONVERTER[query_parameter_info["type"]]
+ required = query_parameter_info["required"]
+
+ operation_swagger["parameters"].append(
+ swagger_parameter(
+ name,
+ description,
+ kind="query",
+ param_type=param_type,
+ required=required,
+ enum=query_parameter_info["choices"],
+ )
+ )
+
+ # Add the OAuth security block.
+ # https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#securityRequirementObject
+ scope = method_metadata(method, "oauth2_scope")
+ if scope and not compact:
+ operation_swagger["security"] = [{"oauth2_implicit": [scope.scope]}]
+
+ # Add the responses block.
+ # https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#responsesObject
+ response_schema_name = method_metadata(method, "response_schema")
+ if not compact:
+ if response_schema_name:
+ models[response_schema_name] = view_class.schemas[
+ response_schema_name
+ ]
+
+ models["ApiError"] = {
+ "type": "object",
+ "properties": {
+ "status": {
+ "type": "integer",
+ "description": "Status code of the response.",
+ },
+ "type": {
+ "type": "string",
+ "description": "Reference to the type of the error.",
+ },
+ "detail": {
+ "type": "string",
+ "description": "Details about the specific instance of the error.",
+ },
+ "title": {
+ "type": "string",
+ "description": "Unique error code to identify the type of error.",
+ },
+ "error_message": {
+ "type": "string",
+ "description": "Deprecated; alias for detail",
+ },
+ "error_type": {
+ "type": "string",
+ "description": "Deprecated; alias for detail",
+ },
+ },
+ "required": ["status", "type", "title"],
+ }
+
+ responses = {
+ "400": {"description": "Bad Request"},
+ "401": {"description": "Session required"},
+ "403": {"description": "Unauthorized access"},
+ "404": {"description": "Not found"},
+ }
+
+ for _, body in responses.items():
+ body["schema"] = {"$ref": "#/definitions/ApiError"}
+
+ if method_name == "DELETE":
+ responses["204"] = {"description": "Deleted"}
+ elif method_name == "POST":
+ responses["201"] = {"description": "Successful creation"}
+ else:
+ responses["200"] = {"description": "Successful invocation"}
+
+ if response_schema_name:
+ responses["200"]["schema"] = {
+ "$ref": "#/definitions/%s" % response_schema_name
+ }
+
+ operation_swagger["responses"] = responses
+
+ # Add the request block.
+ request_schema_name = method_metadata(method, "request_schema")
+ if request_schema_name and not compact:
+ models[request_schema_name] = view_class.schemas[request_schema_name]
+
+ operation_swagger["parameters"].append(
+ swagger_parameter(
+ "body",
+ "Request body contents.",
+ kind="body",
+ schema=request_schema_name,
+ )
+ )
+
+ # Add the operation to the parent path.
+ if not internal or (internal and include_internal):
+ path_swagger[method_name.lower()] = operation_swagger
+
+ tags.sort(key=lambda t: t["name"])
+ paths = OrderedDict(sorted(paths.items(), key=lambda p: p[1]["x-tag"]))
+
+ if compact:
+ return {"paths": paths}
diff --git a/config_app/config_endpoints/api/kube_endpoints.py b/config_app/config_endpoints/api/kube_endpoints.py
index a7143412d..b71d2fe61 100644
--- a/config_app/config_endpoints/api/kube_endpoints.py
+++ b/config_app/config_endpoints/api/kube_endpoints.py
@@ -6,138 +6,152 @@ from config_app.config_util.config import get_config_as_kube_secret
from data.database import configure
from config_app.c_app import app, config_provider
-from config_app.config_endpoints.api import resource, ApiResource, nickname, kubernetes_only, validate_json_request
-from config_app.config_util.k8saccessor import KubernetesAccessorSingleton, K8sApiException
+from config_app.config_endpoints.api import (
+ resource,
+ ApiResource,
+ nickname,
+ kubernetes_only,
+ validate_json_request,
+)
+from config_app.config_util.k8saccessor import (
+ KubernetesAccessorSingleton,
+ K8sApiException,
+)
logger = logging.getLogger(__name__)
-@resource('/v1/kubernetes/deployments/')
+
+@resource("/v1/kubernetes/deployments/")
class SuperUserKubernetesDeployment(ApiResource):
- """ Resource for the getting the status of Red Hat Quay deployments and cycling them """
- schemas = {
- 'ValidateDeploymentNames': {
- 'type': 'object',
- 'description': 'Validates deployment names for cycling',
- 'required': [
- 'deploymentNames'
- ],
- 'properties': {
- 'deploymentNames': {
- 'type': 'array',
- 'description': 'The names of the deployments to cycle'
- },
- },
+ """ Resource for the getting the status of Red Hat Quay deployments and cycling them """
+
+ schemas = {
+ "ValidateDeploymentNames": {
+ "type": "object",
+ "description": "Validates deployment names for cycling",
+ "required": ["deploymentNames"],
+ "properties": {
+ "deploymentNames": {
+ "type": "array",
+ "description": "The names of the deployments to cycle",
+ }
+ },
+ }
}
- }
- @kubernetes_only
- @nickname('scGetNumDeployments')
- def get(self):
- return KubernetesAccessorSingleton.get_instance().get_qe_deployments()
+ @kubernetes_only
+ @nickname("scGetNumDeployments")
+ def get(self):
+ return KubernetesAccessorSingleton.get_instance().get_qe_deployments()
- @kubernetes_only
- @validate_json_request('ValidateDeploymentNames')
- @nickname('scCycleQEDeployments')
- def put(self):
- deployment_names = request.get_json()['deploymentNames']
- return KubernetesAccessorSingleton.get_instance().cycle_qe_deployments(deployment_names)
+ @kubernetes_only
+ @validate_json_request("ValidateDeploymentNames")
+ @nickname("scCycleQEDeployments")
+ def put(self):
+ deployment_names = request.get_json()["deploymentNames"]
+ return KubernetesAccessorSingleton.get_instance().cycle_qe_deployments(
+ deployment_names
+ )
-@resource('/v1/kubernetes/deployment/<deployment>/status')
+@resource("/v1/kubernetes/deployment/<deployment>/status")
class QEDeploymentRolloutStatus(ApiResource):
- @kubernetes_only
- @nickname('scGetDeploymentRolloutStatus')
- def get(self, deployment):
- deployment_rollout_status = KubernetesAccessorSingleton.get_instance().get_deployment_rollout_status(deployment)
- return {
- 'status': deployment_rollout_status.status,
- 'message': deployment_rollout_status.message,
- }
+ @kubernetes_only
+ @nickname("scGetDeploymentRolloutStatus")
+ def get(self, deployment):
+ deployment_rollout_status = KubernetesAccessorSingleton.get_instance().get_deployment_rollout_status(
+ deployment
+ )
+ return {
+ "status": deployment_rollout_status.status,
+ "message": deployment_rollout_status.message,
+ }
-@resource('/v1/kubernetes/deployments/rollback')
+@resource("/v1/kubernetes/deployments/rollback")
class QEDeploymentRollback(ApiResource):
- """ Resource for rolling back deployments """
- schemas = {
- 'ValidateDeploymentNames': {
- 'type': 'object',
- 'description': 'Validates deployment names for rolling back',
- 'required': [
- 'deploymentNames'
- ],
- 'properties': {
- 'deploymentNames': {
- 'type': 'array',
- 'description': 'The names of the deployments to rollback'
- },
- },
- }
- }
+ """ Resource for rolling back deployments """
- @kubernetes_only
- @nickname('scRollbackDeployments')
- @validate_json_request('ValidateDeploymentNames')
- def post(self):
- """
+ schemas = {
+ "ValidateDeploymentNames": {
+ "type": "object",
+ "description": "Validates deployment names for rolling back",
+ "required": ["deploymentNames"],
+ "properties": {
+ "deploymentNames": {
+ "type": "array",
+ "description": "The names of the deployments to rollback",
+ }
+ },
+ }
+ }
+
+ @kubernetes_only
+ @nickname("scRollbackDeployments")
+ @validate_json_request("ValidateDeploymentNames")
+ def post(self):
+ """
Returns the config to its original state and rolls back deployments
:return:
"""
- deployment_names = request.get_json()['deploymentNames']
+ deployment_names = request.get_json()["deploymentNames"]
- # To roll back a deployment, we must do 2 things:
- # 1. Roll back the config secret to its old value (discarding changes we made in this session)
- # 2. Trigger a rollback to the previous revision, so that the pods will be restarted with
- # the old config
- old_secret = get_config_as_kube_secret(config_provider.get_old_config_dir())
- kube_accessor = KubernetesAccessorSingleton.get_instance()
- kube_accessor.replace_qe_secret(old_secret)
+ # To roll back a deployment, we must do 2 things:
+ # 1. Roll back the config secret to its old value (discarding changes we made in this session)
+ # 2. Trigger a rollback to the previous revision, so that the pods will be restarted with
+ # the old config
+ old_secret = get_config_as_kube_secret(config_provider.get_old_config_dir())
+ kube_accessor = KubernetesAccessorSingleton.get_instance()
+ kube_accessor.replace_qe_secret(old_secret)
- try:
- for name in deployment_names:
- kube_accessor.rollback_deployment(name)
- except K8sApiException as e:
- logger.exception('Failed to rollback deployment.')
- return make_response(e.message, 503)
+ try:
+ for name in deployment_names:
+ kube_accessor.rollback_deployment(name)
+ except K8sApiException as e:
+ logger.exception("Failed to rollback deployment.")
+ return make_response(e.message, 503)
- return make_response('Ok', 204)
+ return make_response("Ok", 204)
-@resource('/v1/kubernetes/config')
+@resource("/v1/kubernetes/config")
class SuperUserKubernetesConfiguration(ApiResource):
- """ Resource for saving the config files to kubernetes secrets. """
+ """ Resource for saving the config files to kubernetes secrets. """
- @kubernetes_only
- @nickname('scDeployConfiguration')
- def post(self):
- try:
- new_secret = get_config_as_kube_secret(config_provider.get_config_dir_path())
- KubernetesAccessorSingleton.get_instance().replace_qe_secret(new_secret)
- except K8sApiException as e:
- logger.exception('Failed to deploy qe config secret to kubernetes.')
- return make_response(e.message, 503)
+ @kubernetes_only
+ @nickname("scDeployConfiguration")
+ def post(self):
+ try:
+ new_secret = get_config_as_kube_secret(
+ config_provider.get_config_dir_path()
+ )
+ KubernetesAccessorSingleton.get_instance().replace_qe_secret(new_secret)
+ except K8sApiException as e:
+ logger.exception("Failed to deploy qe config secret to kubernetes.")
+ return make_response(e.message, 503)
- return make_response('Ok', 201)
+ return make_response("Ok", 201)
-@resource('/v1/kubernetes/config/populate')
+@resource("/v1/kubernetes/config/populate")
class KubernetesConfigurationPopulator(ApiResource):
- """ Resource for populating the local configuration from the cluster's kubernetes secrets. """
+ """ Resource for populating the local configuration from the cluster's kubernetes secrets. """
- @kubernetes_only
- @nickname('scKubePopulateConfig')
- def post(self):
- # Get a clean transient directory to write the config into
- config_provider.new_config_dir()
+ @kubernetes_only
+ @nickname("scKubePopulateConfig")
+ def post(self):
+ # Get a clean transient directory to write the config into
+ config_provider.new_config_dir()
- kube_accessor = KubernetesAccessorSingleton.get_instance()
- kube_accessor.save_secret_to_directory(config_provider.get_config_dir_path())
- config_provider.create_copy_of_config_dir()
+ kube_accessor = KubernetesAccessorSingleton.get_instance()
+ kube_accessor.save_secret_to_directory(config_provider.get_config_dir_path())
+ config_provider.create_copy_of_config_dir()
- # We update the db configuration to connect to their specified one
- # (Note, even if this DB isn't valid, it won't affect much in the config app, since we'll report an error,
- # and all of the options create a new clean dir, so we'll never pollute configs)
- combined = dict(**app.config)
- combined.update(config_provider.get_config())
- configure(combined)
+ # We update the db configuration to connect to their specified one
+ # (Note, even if this DB isn't valid, it won't affect much in the config app, since we'll report an error,
+ # and all of the options create a new clean dir, so we'll never pollute configs)
+ combined = dict(**app.config)
+ combined.update(config_provider.get_config())
+ configure(combined)
- return 200
+ return 200
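SuperUserKubernetesConfiguration and KubernetesConfigurationPopulator above move the config directory in and out of a Kubernetes secret through get_config_as_kube_secret() (defined in config_app/config_util/config.py, outside this diff). The sketch below shows one plausible payload shape, assuming the standard Kubernetes convention of base64-encoding each file under its filename; it is a stand-in for illustration, not Quay's implementation:

import base64
import os


def config_dir_as_secret_data(config_dir):
    # Map filename -> base64-encoded file contents, the shape a Secret's "data" field expects.
    data = {}
    for name in os.listdir(config_dir):
        path = os.path.join(config_dir, name)
        if os.path.isfile(path):
            with open(path, "rb") as f:
                data[name] = base64.b64encode(f.read()).decode("ascii")
    return data


# Example result (value is base64 of "SERVER_HOSTNAME: quay.example.com"):
#   {"config.yaml": "U0VSVkVSX0hPU1ROQU1FOiBxdWF5LmV4YW1wbGUuY29t"}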
diff --git a/config_app/config_endpoints/api/suconfig.py b/config_app/config_endpoints/api/suconfig.py
index 810d4a229..29c3545b7 100644
--- a/config_app/config_endpoints/api/suconfig.py
+++ b/config_app/config_endpoints/api/suconfig.py
@@ -2,301 +2,283 @@ import logging
from flask import abort, request
-from config_app.config_endpoints.api.suconfig_models_pre_oci import pre_oci_model as model
-from config_app.config_endpoints.api import resource, ApiResource, nickname, validate_json_request
-from config_app.c_app import (app, config_provider, superusers, ip_resolver,
- instance_keys, INIT_SCRIPTS_LOCATION)
+from config_app.config_endpoints.api.suconfig_models_pre_oci import (
+ pre_oci_model as model,
+)
+from config_app.config_endpoints.api import (
+ resource,
+ ApiResource,
+ nickname,
+ validate_json_request,
+)
+from config_app.c_app import (
+ app,
+ config_provider,
+ superusers,
+ ip_resolver,
+ instance_keys,
+ INIT_SCRIPTS_LOCATION,
+)
from data.database import configure
from data.runmigration import run_alembic_migration
from util.config.configutil import add_enterprise_config_defaults
-from util.config.validator import validate_service_for_config, ValidatorContext, \
- is_valid_config_upload_filename
+from util.config.validator import (
+ validate_service_for_config,
+ ValidatorContext,
+ is_valid_config_upload_filename,
+)
logger = logging.getLogger(__name__)
def database_is_valid():
- """ Returns whether the database, as configured, is valid. """
- return model.is_valid()
+ """ Returns whether the database, as configured, is valid. """
+ return model.is_valid()
def database_has_users():
- """ Returns whether the database has any users defined. """
- return model.has_users()
+ """ Returns whether the database has any users defined. """
+ return model.has_users()
-@resource('/v1/superuser/config')
+@resource("/v1/superuser/config")
class SuperUserConfig(ApiResource):
- """ Resource for fetching and updating the current configuration, if any. """
- schemas = {
- 'UpdateConfig': {
- 'type': 'object',
- 'description': 'Updates the YAML config file',
- 'required': [
- 'config',
- ],
- 'properties': {
- 'config': {
- 'type': 'object'
- },
- 'password': {
- 'type': 'string'
- },
- },
- },
- }
+ """ Resource for fetching and updating the current configuration, if any. """
- @nickname('scGetConfig')
- def get(self):
- """ Returns the currently defined configuration, if any. """
- config_object = config_provider.get_config()
- return {
- 'config': config_object
+ schemas = {
+ "UpdateConfig": {
+ "type": "object",
+ "description": "Updates the YAML config file",
+ "required": ["config"],
+ "properties": {
+ "config": {"type": "object"},
+ "password": {"type": "string"},
+ },
+ }
}
- @nickname('scUpdateConfig')
- @validate_json_request('UpdateConfig')
- def put(self):
- """ Updates the config override file. """
- # Note: This method is called to set the database configuration before super users exists,
- # so we also allow it to be called if there is no valid registry configuration setup.
- config_object = request.get_json()['config']
+ @nickname("scGetConfig")
+ def get(self):
+ """ Returns the currently defined configuration, if any. """
+ config_object = config_provider.get_config()
+ return {"config": config_object}
- # Add any enterprise defaults missing from the config.
- add_enterprise_config_defaults(config_object, app.config['SECRET_KEY'])
+ @nickname("scUpdateConfig")
+ @validate_json_request("UpdateConfig")
+ def put(self):
+ """ Updates the config override file. """
+ # Note: This method is called to set the database configuration before super users exists,
+ # so we also allow it to be called if there is no valid registry configuration setup.
+ config_object = request.get_json()["config"]
- # Write the configuration changes to the config override file.
- config_provider.save_config(config_object)
+ # Add any enterprise defaults missing from the config.
+ add_enterprise_config_defaults(config_object, app.config["SECRET_KEY"])
- # now try to connect to the db provided in their config to validate it works
- combined = dict(**app.config)
- combined.update(config_provider.get_config())
- configure(combined, testing=app.config['TESTING'])
+ # Write the configuration changes to the config override file.
+ config_provider.save_config(config_object)
- return {
- 'exists': True,
- 'config': config_object
- }
+ # now try to connect to the db provided in their config to validate it works
+ combined = dict(**app.config)
+ combined.update(config_provider.get_config())
+ configure(combined, testing=app.config["TESTING"])
+
+ return {"exists": True, "config": config_object}
-@resource('/v1/superuser/registrystatus')
+@resource("/v1/superuser/registrystatus")
class SuperUserRegistryStatus(ApiResource):
- """ Resource for determining the status of the registry, such as if config exists,
+ """ Resource for determining the status of the registry, such as if config exists,
if a database is configured, and if it has any defined users.
"""
- @nickname('scRegistryStatus')
- def get(self):
- """ Returns the status of the registry. """
- # If there is no config file, we need to setup the database.
- if not config_provider.config_exists():
- return {
- 'status': 'config-db'
- }
+ @nickname("scRegistryStatus")
+ def get(self):
+ """ Returns the status of the registry. """
+ # If there is no config file, we need to setup the database.
+ if not config_provider.config_exists():
+ return {"status": "config-db"}
- # If the database isn't yet valid, then we need to set it up.
- if not database_is_valid():
- return {
- 'status': 'setup-db'
- }
+ # If the database isn't yet valid, then we need to set it up.
+ if not database_is_valid():
+ return {"status": "setup-db"}
- config = config_provider.get_config()
- if config and config.get('SETUP_COMPLETE'):
- return {
- 'status': 'config'
- }
+ config = config_provider.get_config()
+ if config and config.get("SETUP_COMPLETE"):
+ return {"status": "config"}
- return {
- 'status': 'create-superuser' if not database_has_users() else 'config'
- }
+ return {"status": "create-superuser" if not database_has_users() else "config"}
class _AlembicLogHandler(logging.Handler):
- def __init__(self):
- super(_AlembicLogHandler, self).__init__()
- self.records = []
+ def __init__(self):
+ super(_AlembicLogHandler, self).__init__()
+ self.records = []
- def emit(self, record):
- self.records.append({
- 'level': record.levelname,
- 'message': record.getMessage()
- })
+ def emit(self, record):
+ self.records.append({"level": record.levelname, "message": record.getMessage()})
def _reload_config():
- combined = dict(**app.config)
- combined.update(config_provider.get_config())
- configure(combined)
- return combined
+ combined = dict(**app.config)
+ combined.update(config_provider.get_config())
+ configure(combined)
+ return combined
-@resource('/v1/superuser/setupdb')
+@resource("/v1/superuser/setupdb")
class SuperUserSetupDatabase(ApiResource):
- """ Resource for invoking alembic to setup the database. """
+ """ Resource for invoking alembic to setup the database. """
- @nickname('scSetupDatabase')
- def get(self):
- """ Invokes the alembic upgrade process. """
- # Note: This method is called after the database configured is saved, but before the
- # database has any tables. Therefore, we only allow it to be run in that unique case.
- if config_provider.config_exists() and not database_is_valid():
- combined = _reload_config()
+ @nickname("scSetupDatabase")
+ def get(self):
+ """ Invokes the alembic upgrade process. """
+ # Note: This method is called after the database configured is saved, but before the
+ # database has any tables. Therefore, we only allow it to be run in that unique case.
+ if config_provider.config_exists() and not database_is_valid():
+ combined = _reload_config()
- app.config['DB_URI'] = combined['DB_URI']
- db_uri = app.config['DB_URI']
- escaped_db_uri = db_uri.replace('%', '%%')
+ app.config["DB_URI"] = combined["DB_URI"]
+ db_uri = app.config["DB_URI"]
+ escaped_db_uri = db_uri.replace("%", "%%")
- log_handler = _AlembicLogHandler()
+ log_handler = _AlembicLogHandler()
- try:
- run_alembic_migration(escaped_db_uri, log_handler, setup_app=False)
- except Exception as ex:
- return {
- 'error': str(ex)
- }
+ try:
+ run_alembic_migration(escaped_db_uri, log_handler, setup_app=False)
+ except Exception as ex:
+ return {"error": str(ex)}
- return {
- 'logs': log_handler.records
- }
+ return {"logs": log_handler.records}
- abort(403)
+ abort(403)
-@resource('/v1/superuser/config/createsuperuser')
+@resource("/v1/superuser/config/createsuperuser")
class SuperUserCreateInitialSuperUser(ApiResource):
- """ Resource for creating the initial super user. """
- schemas = {
- 'CreateSuperUser': {
- 'type': 'object',
- 'description': 'Information for creating the initial super user',
- 'required': [
- 'username',
- 'password',
- 'email'
- ],
- 'properties': {
- 'username': {
- 'type': 'string',
- 'description': 'The username for the superuser'
- },
- 'password': {
- 'type': 'string',
- 'description': 'The password for the superuser'
- },
- 'email': {
- 'type': 'string',
- 'description': 'The e-mail address for the superuser'
- },
- },
- },
- }
+ """ Resource for creating the initial super user. """
- @nickname('scCreateInitialSuperuser')
- @validate_json_request('CreateSuperUser')
- def post(self):
- """ Creates the initial super user, updates the underlying configuration and
+ schemas = {
+ "CreateSuperUser": {
+ "type": "object",
+ "description": "Information for creating the initial super user",
+ "required": ["username", "password", "email"],
+ "properties": {
+ "username": {
+ "type": "string",
+ "description": "The username for the superuser",
+ },
+ "password": {
+ "type": "string",
+ "description": "The password for the superuser",
+ },
+ "email": {
+ "type": "string",
+ "description": "The e-mail address for the superuser",
+ },
+ },
+ }
+ }
+
+ @nickname("scCreateInitialSuperuser")
+ @validate_json_request("CreateSuperUser")
+ def post(self):
+ """ Creates the initial super user, updates the underlying configuration and
sets the current session to have that super user. """
- _reload_config()
+ _reload_config()
- # Special security check: This method is only accessible when:
- # - There is a valid config YAML file.
- # - There are currently no users in the database (clean install)
- #
- # We do this special security check because at the point this method is called, the database
- # is clean but does not (yet) have any super users for our permissions code to check against.
- if config_provider.config_exists() and not database_has_users():
- data = request.get_json()
- username = data['username']
- password = data['password']
- email = data['email']
+ # Special security check: This method is only accessible when:
+ # - There is a valid config YAML file.
+ # - There are currently no users in the database (clean install)
+ #
+ # We do this special security check because at the point this method is called, the database
+ # is clean but does not (yet) have any super users for our permissions code to check against.
+ if config_provider.config_exists() and not database_has_users():
+ data = request.get_json()
+ username = data["username"]
+ password = data["password"]
+ email = data["email"]
- # Create the user in the database.
- superuser_uuid = model.create_superuser(username, password, email)
+ # Create the user in the database.
+ superuser_uuid = model.create_superuser(username, password, email)
- # Add the user to the config.
- config_object = config_provider.get_config()
- config_object['SUPER_USERS'] = [username]
- config_provider.save_config(config_object)
+ # Add the user to the config.
+ config_object = config_provider.get_config()
+ config_object["SUPER_USERS"] = [username]
+ config_provider.save_config(config_object)
- # Update the in-memory config for the new superuser.
- superusers.register_superuser(username)
+ # Update the in-memory config for the new superuser.
+ superusers.register_superuser(username)
- return {
- 'status': True
- }
+ return {"status": True}
- abort(403)
+ abort(403)
-@resource('/v1/superuser/config/validate/')
+@resource("/v1/superuser/config/validate/")
class SuperUserConfigValidate(ApiResource):
- """ Resource for validating a block of configuration against an external service. """
- schemas = {
- 'ValidateConfig': {
- 'type': 'object',
- 'description': 'Validates configuration',
- 'required': [
- 'config'
- ],
- 'properties': {
- 'config': {
- 'type': 'object'
- },
- 'password': {
- 'type': 'string',
- 'description': 'The users password, used for auth validation'
+ """ Resource for validating a block of configuration against an external service. """
+
+ schemas = {
+ "ValidateConfig": {
+ "type": "object",
+ "description": "Validates configuration",
+ "required": ["config"],
+ "properties": {
+ "config": {"type": "object"},
+ "password": {
+ "type": "string",
+ "description": "The users password, used for auth validation",
+ },
+ },
}
- },
- },
- }
+ }
- @nickname('scValidateConfig')
- @validate_json_request('ValidateConfig')
- def post(self, service):
- """ Validates the given config for the given service. """
- # Note: This method is called to validate the database configuration before super users exists,
- # so we also allow it to be called if there is no valid registry configuration setup. Note that
- # this is also safe since this method does not access any information not given in the request.
- config = request.get_json()['config']
- validator_context = ValidatorContext.from_app(app, config,
- request.get_json().get('password', ''),
- instance_keys=instance_keys,
- ip_resolver=ip_resolver,
- config_provider=config_provider,
- init_scripts_location=INIT_SCRIPTS_LOCATION)
+ @nickname("scValidateConfig")
+ @validate_json_request("ValidateConfig")
+ def post(self, service):
+ """ Validates the given config for the given service. """
+ # Note: This method is called to validate the database configuration before super users exist,
+ # so we also allow it to be called if there is no valid registry configuration set up. Note that
+ # this is also safe since this method does not access any information not given in the request.
+ config = request.get_json()["config"]
+ validator_context = ValidatorContext.from_app(
+ app,
+ config,
+ request.get_json().get("password", ""),
+ instance_keys=instance_keys,
+ ip_resolver=ip_resolver,
+ config_provider=config_provider,
+ init_scripts_location=INIT_SCRIPTS_LOCATION,
+ )
- return validate_service_for_config(service, validator_context)
+ return validate_service_for_config(service, validator_context)
-@resource('/v1/superuser/config/file/')
+@resource("/v1/superuser/config/file/")
class SuperUserConfigFile(ApiResource):
- """ Resource for fetching the status of config files and overriding them. """
+ """ Resource for fetching the status of config files and overriding them. """
- @nickname('scConfigFileExists')
- def get(self, filename):
- """ Returns whether the configuration file with the given name exists. """
- if not is_valid_config_upload_filename(filename):
- abort(404)
+ @nickname("scConfigFileExists")
+ def get(self, filename):
+ """ Returns whether the configuration file with the given name exists. """
+ if not is_valid_config_upload_filename(filename):
+ abort(404)
- return {
- 'exists': config_provider.volume_file_exists(filename)
- }
+ return {"exists": config_provider.volume_file_exists(filename)}
- @nickname('scUpdateConfigFile')
- def post(self, filename):
- """ Updates the configuration file with the given name. """
- if not is_valid_config_upload_filename(filename):
- abort(404)
+ @nickname("scUpdateConfigFile")
+ def post(self, filename):
+ """ Updates the configuration file with the given name. """
+ if not is_valid_config_upload_filename(filename):
+ abort(404)
- # Note: This method can be called before the configuration exists
- # to upload the database SSL cert.
- uploaded_file = request.files['file']
- if not uploaded_file:
- abort(400)
+ # Note: This method can be called before the configuration exists
+ # to upload the database SSL cert.
+ uploaded_file = request.files["file"]
+ if not uploaded_file:
+ abort(400)
- config_provider.save_volume_file(filename, uploaded_file)
- return {
- 'status': True
- }
+ config_provider.save_volume_file(filename, uploaded_file)
+ return {"status": True}
diff --git a/config_app/config_endpoints/api/suconfig_models_interface.py b/config_app/config_endpoints/api/suconfig_models_interface.py
index 9f8cbd0cb..d41a97d11 100644
--- a/config_app/config_endpoints/api/suconfig_models_interface.py
+++ b/config_app/config_endpoints/api/suconfig_models_interface.py
@@ -4,36 +4,36 @@ from six import add_metaclass
@add_metaclass(ABCMeta)
class SuperuserConfigDataInterface(object):
- """
+ """
Interface that represents all data store interactions required by the superuser config API.
"""
- @abstractmethod
- def is_valid(self):
- """
+ @abstractmethod
+ def is_valid(self):
+ """
Returns true if the configured database is valid.
"""
- @abstractmethod
- def has_users(self):
- """
+ @abstractmethod
+ def has_users(self):
+ """
Returns true if there are any users defined.
"""
- @abstractmethod
- def create_superuser(self, username, password, email):
- """
+ @abstractmethod
+ def create_superuser(self, username, password, email):
+ """
Creates a new superuser with the given username, password and email. Returns the user's UUID.
"""
- @abstractmethod
- def has_federated_login(self, username, service_name):
- """
+ @abstractmethod
+ def has_federated_login(self, username, service_name):
+ """
Returns true if the matching user has a federated login under the matching service.
"""
- @abstractmethod
- def attach_federated_login(self, username, service_name, federated_username):
- """
+ @abstractmethod
+ def attach_federated_login(self, username, service_name, federated_username):
+ """
Attaches a federated login to the matching user, under the given service.
"""
diff --git a/config_app/config_endpoints/api/suconfig_models_pre_oci.py b/config_app/config_endpoints/api/suconfig_models_pre_oci.py
index fbc238078..9e512e88a 100644
--- a/config_app/config_endpoints/api/suconfig_models_pre_oci.py
+++ b/config_app/config_endpoints/api/suconfig_models_pre_oci.py
@@ -1,37 +1,39 @@
from data import model
from data.database import User
-from config_app.config_endpoints.api.suconfig_models_interface import SuperuserConfigDataInterface
+from config_app.config_endpoints.api.suconfig_models_interface import (
+ SuperuserConfigDataInterface,
+)
class PreOCIModel(SuperuserConfigDataInterface):
- # Note: this method is different than has_users: the user select will throw if the user
- # table does not exist, whereas has_users assumes the table is valid
- def is_valid(self):
- try:
- list(User.select().limit(1))
- return True
- except:
- return False
+ # Note: this method is different than has_users: the user select will throw if the user
+ # table does not exist, whereas has_users assumes the table is valid
+ def is_valid(self):
+ try:
+ list(User.select().limit(1))
+ return True
+ except:
+ return False
- def has_users(self):
- return bool(list(User.select().limit(1)))
+ def has_users(self):
+ return bool(list(User.select().limit(1)))
- def create_superuser(self, username, password, email):
- return model.user.create_user(username, password, email, auto_verify=True).uuid
+ def create_superuser(self, username, password, email):
+ return model.user.create_user(username, password, email, auto_verify=True).uuid
- def has_federated_login(self, username, service_name):
- user = model.user.get_user(username)
- if user is None:
- return False
+ def has_federated_login(self, username, service_name):
+ user = model.user.get_user(username)
+ if user is None:
+ return False
- return bool(model.user.lookup_federated_login(user, service_name))
+ return bool(model.user.lookup_federated_login(user, service_name))
- def attach_federated_login(self, username, service_name, federated_username):
- user = model.user.get_user(username)
- if user is None:
- return False
+ def attach_federated_login(self, username, service_name, federated_username):
+ user = model.user.get_user(username)
+ if user is None:
+ return False
- model.user.attach_federated_login(user, service_name, federated_username)
+ model.user.attach_federated_login(user, service_name, federated_username)
pre_oci_model = PreOCIModel()
diff --git a/config_app/config_endpoints/api/superuser.py b/config_app/config_endpoints/api/superuser.py
index 7e5adccb5..db5d1d81c 100644
--- a/config_app/config_endpoints/api/superuser.py
+++ b/config_app/config_endpoints/api/superuser.py
@@ -12,7 +12,13 @@ from data.model import ServiceKeyDoesNotExist
from util.config.validator import EXTRA_CA_DIRECTORY
from config_app.config_endpoints.exception import InvalidRequest
-from config_app.config_endpoints.api import resource, ApiResource, nickname, log_action, validate_json_request
+from config_app.config_endpoints.api import (
+ resource,
+ ApiResource,
+ nickname,
+ log_action,
+ validate_json_request,
+)
from config_app.config_endpoints.api.superuser_models_pre_oci import pre_oci_model
from config_app.config_util.ssl import load_certificate, CertInvalidException
from config_app.c_app import app, config_provider, INIT_SCRIPTS_LOCATION
@@ -21,228 +27,233 @@ from config_app.c_app import app, config_provider, INIT_SCRIPTS_LOCATION
logger = logging.getLogger(__name__)
-@resource('/v1/superuser/customcerts/')
+@resource("/v1/superuser/customcerts/")
class SuperUserCustomCertificate(ApiResource):
- """ Resource for managing a custom certificate. """
+ """ Resource for managing a custom certificate. """
- @nickname('uploadCustomCertificate')
- def post(self, certpath):
- uploaded_file = request.files['file']
- if not uploaded_file:
- raise InvalidRequest('Missing certificate file')
+ @nickname("uploadCustomCertificate")
+ def post(self, certpath):
+ uploaded_file = request.files["file"]
+ if not uploaded_file:
+ raise InvalidRequest("Missing certificate file")
- # Save the certificate.
- certpath = pathvalidate.sanitize_filename(certpath)
- if not certpath.endswith('.crt'):
- raise InvalidRequest('Invalid certificate file: must have suffix `.crt`')
+ # Save the certificate.
+ certpath = pathvalidate.sanitize_filename(certpath)
+ if not certpath.endswith(".crt"):
+ raise InvalidRequest("Invalid certificate file: must have suffix `.crt`")
- logger.debug('Saving custom certificate %s', certpath)
- cert_full_path = config_provider.get_volume_path(EXTRA_CA_DIRECTORY, certpath)
- config_provider.save_volume_file(cert_full_path, uploaded_file)
- logger.debug('Saved custom certificate %s', certpath)
+ logger.debug("Saving custom certificate %s", certpath)
+ cert_full_path = config_provider.get_volume_path(EXTRA_CA_DIRECTORY, certpath)
+ config_provider.save_volume_file(cert_full_path, uploaded_file)
+ logger.debug("Saved custom certificate %s", certpath)
- # Validate the certificate.
- try:
- logger.debug('Loading custom certificate %s', certpath)
- with config_provider.get_volume_file(cert_full_path) as f:
- load_certificate(f.read())
- except CertInvalidException:
- logger.exception('Got certificate invalid error for cert %s', certpath)
- return '', 204
- except IOError:
- logger.exception('Got IO error for cert %s', certpath)
- return '', 204
+ # Validate the certificate.
+ try:
+ logger.debug("Loading custom certificate %s", certpath)
+ with config_provider.get_volume_file(cert_full_path) as f:
+ load_certificate(f.read())
+ except CertInvalidException:
+ logger.exception("Got certificate invalid error for cert %s", certpath)
+ return "", 204
+ except IOError:
+ logger.exception("Got IO error for cert %s", certpath)
+ return "", 204
- # Call the update script with config dir location to install the certificate immediately.
- if not app.config['TESTING']:
- cert_dir = os.path.join(config_provider.get_config_dir_path(), EXTRA_CA_DIRECTORY)
- if subprocess.call([os.path.join(INIT_SCRIPTS_LOCATION, 'certs_install.sh')], env={ 'CERTDIR': cert_dir }) != 0:
- raise Exception('Could not install certificates')
+ # Call the update script with config dir location to install the certificate immediately.
+ if not app.config["TESTING"]:
+ cert_dir = os.path.join(
+ config_provider.get_config_dir_path(), EXTRA_CA_DIRECTORY
+ )
+ if (
+ subprocess.call(
+ [os.path.join(INIT_SCRIPTS_LOCATION, "certs_install.sh")],
+ env={"CERTDIR": cert_dir},
+ )
+ != 0
+ ):
+ raise Exception("Could not install certificates")
- return '', 204
+ return "", 204
- @nickname('deleteCustomCertificate')
- def delete(self, certpath):
- cert_full_path = config_provider.get_volume_path(EXTRA_CA_DIRECTORY, certpath)
- config_provider.remove_volume_file(cert_full_path)
- return '', 204
+ @nickname("deleteCustomCertificate")
+ def delete(self, certpath):
+ cert_full_path = config_provider.get_volume_path(EXTRA_CA_DIRECTORY, certpath)
+ config_provider.remove_volume_file(cert_full_path)
+ return "", 204
-@resource('/v1/superuser/customcerts')
+@resource("/v1/superuser/customcerts")
class SuperUserCustomCertificates(ApiResource):
- """ Resource for managing custom certificates. """
+ """ Resource for managing custom certificates. """
- @nickname('getCustomCertificates')
- def get(self):
- has_extra_certs_path = config_provider.volume_file_exists(EXTRA_CA_DIRECTORY)
- extra_certs_found = config_provider.list_volume_directory(EXTRA_CA_DIRECTORY)
- if extra_certs_found is None:
- return {
- 'status': 'file' if has_extra_certs_path else 'none',
- }
+ @nickname("getCustomCertificates")
+ def get(self):
+ has_extra_certs_path = config_provider.volume_file_exists(EXTRA_CA_DIRECTORY)
+ extra_certs_found = config_provider.list_volume_directory(EXTRA_CA_DIRECTORY)
+ if extra_certs_found is None:
+ return {"status": "file" if has_extra_certs_path else "none"}
- cert_views = []
- for extra_cert_path in extra_certs_found:
- try:
- cert_full_path = config_provider.get_volume_path(EXTRA_CA_DIRECTORY, extra_cert_path)
- with config_provider.get_volume_file(cert_full_path) as f:
- certificate = load_certificate(f.read())
- cert_views.append({
- 'path': extra_cert_path,
- 'names': list(certificate.names),
- 'expired': certificate.expired,
- })
- except CertInvalidException as cie:
- cert_views.append({
- 'path': extra_cert_path,
- 'error': cie.message,
- })
- except IOError as ioe:
- cert_views.append({
- 'path': extra_cert_path,
- 'error': ioe.message,
- })
+ cert_views = []
+ for extra_cert_path in extra_certs_found:
+ try:
+ cert_full_path = config_provider.get_volume_path(
+ EXTRA_CA_DIRECTORY, extra_cert_path
+ )
+ with config_provider.get_volume_file(cert_full_path) as f:
+ certificate = load_certificate(f.read())
+ cert_views.append(
+ {
+ "path": extra_cert_path,
+ "names": list(certificate.names),
+ "expired": certificate.expired,
+ }
+ )
+ except CertInvalidException as cie:
+ cert_views.append({"path": extra_cert_path, "error": cie.message})
+ except IOError as ioe:
+ cert_views.append({"path": extra_cert_path, "error": ioe.message})
- return {
- 'status': 'directory',
- 'certs': cert_views,
- }
+ return {"status": "directory", "certs": cert_views}
-@resource('/v1/superuser/keys')
+@resource("/v1/superuser/keys")
class SuperUserServiceKeyManagement(ApiResource):
- """ Resource for managing service keys."""
- schemas = {
- 'CreateServiceKey': {
- 'id': 'CreateServiceKey',
- 'type': 'object',
- 'description': 'Description of creation of a service key',
- 'required': ['service', 'expiration'],
- 'properties': {
- 'service': {
- 'type': 'string',
- 'description': 'The service authenticating with this key',
- },
- 'name': {
- 'type': 'string',
- 'description': 'The friendly name of a service key',
- },
- 'metadata': {
- 'type': 'object',
- 'description': 'The key/value pairs of this key\'s metadata',
- },
- 'notes': {
- 'type': 'string',
- 'description': 'If specified, the extra notes for the key',
- },
- 'expiration': {
- 'description': 'The expiration date as a unix timestamp',
- 'anyOf': [{'type': 'number'}, {'type': 'null'}],
- },
- },
- },
- }
+ """ Resource for managing service keys."""
- @nickname('listServiceKeys')
- def get(self):
- keys = pre_oci_model.list_all_service_keys()
-
- return jsonify({
- 'keys': [key.to_dict() for key in keys],
- })
-
- @nickname('createServiceKey')
- @validate_json_request('CreateServiceKey')
- def post(self):
- body = request.get_json()
-
- # Ensure we have a valid expiration date if specified.
- expiration_date = body.get('expiration', None)
- if expiration_date is not None:
- try:
- expiration_date = datetime.utcfromtimestamp(float(expiration_date))
- except ValueError as ve:
- raise InvalidRequest('Invalid expiration date: %s' % ve)
-
- if expiration_date <= datetime.now():
- raise InvalidRequest('Expiration date cannot be in the past')
-
- # Create the metadata for the key.
- metadata = body.get('metadata', {})
- metadata.update({
- 'created_by': 'Quay Superuser Panel',
- 'ip': request.remote_addr,
- })
-
- # Generate a key with a private key that we *never save*.
- (private_key, key_id) = pre_oci_model.generate_service_key(body['service'], expiration_date,
- metadata=metadata,
- name=body.get('name', ''))
- # Auto-approve the service key.
- pre_oci_model.approve_service_key(key_id, ServiceKeyApprovalType.SUPERUSER,
- notes=body.get('notes', ''))
-
- # Log the creation and auto-approval of the service key.
- key_log_metadata = {
- 'kid': key_id,
- 'preshared': True,
- 'service': body['service'],
- 'name': body.get('name', ''),
- 'expiration_date': expiration_date,
- 'auto_approved': True,
+ schemas = {
+ "CreateServiceKey": {
+ "id": "CreateServiceKey",
+ "type": "object",
+ "description": "Description of creation of a service key",
+ "required": ["service", "expiration"],
+ "properties": {
+ "service": {
+ "type": "string",
+ "description": "The service authenticating with this key",
+ },
+ "name": {
+ "type": "string",
+ "description": "The friendly name of a service key",
+ },
+ "metadata": {
+ "type": "object",
+ "description": "The key/value pairs of this key's metadata",
+ },
+ "notes": {
+ "type": "string",
+ "description": "If specified, the extra notes for the key",
+ },
+ "expiration": {
+ "description": "The expiration date as a unix timestamp",
+ "anyOf": [{"type": "number"}, {"type": "null"}],
+ },
+ },
+ }
}
- log_action('service_key_create', None, key_log_metadata)
- log_action('service_key_approve', None, key_log_metadata)
+ @nickname("listServiceKeys")
+ def get(self):
+ keys = pre_oci_model.list_all_service_keys()
- return jsonify({
- 'kid': key_id,
- 'name': body.get('name', ''),
- 'service': body['service'],
- 'public_key': private_key.publickey().exportKey('PEM'),
- 'private_key': private_key.exportKey('PEM'),
- })
+ return jsonify({"keys": [key.to_dict() for key in keys]})
-@resource('/v1/superuser/approvedkeys/')
+ @nickname("createServiceKey")
+ @validate_json_request("CreateServiceKey")
+ def post(self):
+ body = request.get_json()
+
+ # Ensure we have a valid expiration date if specified.
+ expiration_date = body.get("expiration", None)
+ if expiration_date is not None:
+ try:
+ expiration_date = datetime.utcfromtimestamp(float(expiration_date))
+ except ValueError as ve:
+ raise InvalidRequest("Invalid expiration date: %s" % ve)
+
+ if expiration_date <= datetime.now():
+ raise InvalidRequest("Expiration date cannot be in the past")
+
+ # Create the metadata for the key.
+ metadata = body.get("metadata", {})
+ metadata.update(
+ {"created_by": "Quay Superuser Panel", "ip": request.remote_addr}
+ )
+
+ # Generate a key with a private key that we *never save*.
+ (private_key, key_id) = pre_oci_model.generate_service_key(
+ body["service"],
+ expiration_date,
+ metadata=metadata,
+ name=body.get("name", ""),
+ )
+ # Auto-approve the service key.
+ pre_oci_model.approve_service_key(
+ key_id, ServiceKeyApprovalType.SUPERUSER, notes=body.get("notes", "")
+ )
+
+ # Log the creation and auto-approval of the service key.
+ key_log_metadata = {
+ "kid": key_id,
+ "preshared": True,
+ "service": body["service"],
+ "name": body.get("name", ""),
+ "expiration_date": expiration_date,
+ "auto_approved": True,
+ }
+
+ log_action("service_key_create", None, key_log_metadata)
+ log_action("service_key_approve", None, key_log_metadata)
+
+ return jsonify(
+ {
+ "kid": key_id,
+ "name": body.get("name", ""),
+ "service": body["service"],
+ "public_key": private_key.publickey().exportKey("PEM"),
+ "private_key": private_key.exportKey("PEM"),
+ }
+ )
+
+
+@resource("/v1/superuser/approvedkeys/")
class SuperUserServiceKeyApproval(ApiResource):
- """ Resource for approving service keys. """
+ """ Resource for approving service keys. """
- schemas = {
- 'ApproveServiceKey': {
- 'id': 'ApproveServiceKey',
- 'type': 'object',
- 'description': 'Information for approving service keys',
- 'properties': {
- 'notes': {
- 'type': 'string',
- 'description': 'Optional approval notes',
- },
- },
- },
- }
+ schemas = {
+ "ApproveServiceKey": {
+ "id": "ApproveServiceKey",
+ "type": "object",
+ "description": "Information for approving service keys",
+ "properties": {
+ "notes": {"type": "string", "description": "Optional approval notes"}
+ },
+ }
+ }
- @nickname('approveServiceKey')
- @validate_json_request('ApproveServiceKey')
- def post(self, kid):
- notes = request.get_json().get('notes', '')
- try:
- key = pre_oci_model.approve_service_key(kid, ServiceKeyApprovalType.SUPERUSER, notes=notes)
+ @nickname("approveServiceKey")
+ @validate_json_request("ApproveServiceKey")
+ def post(self, kid):
+ notes = request.get_json().get("notes", "")
+ try:
+ key = pre_oci_model.approve_service_key(
+ kid, ServiceKeyApprovalType.SUPERUSER, notes=notes
+ )
- # Log the approval of the service key.
- key_log_metadata = {
- 'kid': kid,
- 'service': key.service,
- 'name': key.name,
- 'expiration_date': key.expiration_date,
- }
+ # Log the approval of the service key.
+ key_log_metadata = {
+ "kid": kid,
+ "service": key.service,
+ "name": key.name,
+ "expiration_date": key.expiration_date,
+ }
- # Note: this may not actually be the current person modifying the config, but if they're in the config tool,
- # they have full access to the DB and could pretend to be any user, so pulling any superuser is likely fine
- super_user = app.config.get('SUPER_USERS', [None])[0]
- log_action('service_key_approve', super_user, key_log_metadata)
- except ServiceKeyDoesNotExist:
- raise NotFound()
- except ServiceKeyAlreadyApproved:
- pass
+ # Note: this may not actually be the current person modifying the config, but if they're in the config tool,
+ # they have full access to the DB and could pretend to be any user, so pulling any superuser is likely fine
+ super_user = app.config.get("SUPER_USERS", [None])[0]
+ log_action("service_key_approve", super_user, key_log_metadata)
+ except ServiceKeyDoesNotExist:
+ raise NotFound()
+ except ServiceKeyAlreadyApproved:
+ pass
- return make_response('', 201)
+ return make_response("", 201)
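
For the service key endpoints above, the CreateServiceKey schema requires "service" and "expiration" (a unix timestamp or null, and not in the past), while "name", "metadata", and "notes" are optional; the key is auto-approved and the private key is returned once and never saved server-side. A hedged request example follows, assuming the requests library; the base URL and the service name are illustrative, not taken from the codebase.

import time

import requests

BASE = "http://localhost:8080/api"  # illustrative base URL

body = {
    "service": "quay",  # illustrative value: the service authenticating with this key
    "name": "panel-generated-key",  # optional friendly name
    "expiration": time.time() + 30 * 24 * 3600,  # unix timestamp; must not be in the past
    "notes": "created from the config tool",  # optional notes recorded with the auto-approval
}
created = requests.post(BASE + "/v1/superuser/keys", json=body).json()
# The response carries the kid plus the PEM-encoded public and private keys;
# since the private key is never stored, this is the only chance to save it.
print(created["kid"], created["public_key"])
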
diff --git a/config_app/config_endpoints/api/superuser_models_interface.py b/config_app/config_endpoints/api/superuser_models_interface.py
index 53efc9aec..efd8a0f04 100644
--- a/config_app/config_endpoints/api/superuser_models_interface.py
+++ b/config_app/config_endpoints/api/superuser_models_interface.py
@@ -6,21 +6,33 @@ from config_app.config_endpoints.api import format_date
def user_view(user):
- return {
- 'name': user.username,
- 'kind': 'user',
- 'is_robot': user.robot,
- }
+ return {"name": user.username, "kind": "user", "is_robot": user.robot}
-class RepositoryBuild(namedtuple('RepositoryBuild',
- ['uuid', 'logs_archived', 'repository_namespace_user_username',
- 'repository_name',
- 'can_write', 'can_read', 'pull_robot', 'resource_key', 'trigger',
- 'display_name',
- 'started', 'job_config', 'phase', 'status', 'error',
- 'archive_url'])):
- """
+class RepositoryBuild(
+ namedtuple(
+ "RepositoryBuild",
+ [
+ "uuid",
+ "logs_archived",
+ "repository_namespace_user_username",
+ "repository_name",
+ "can_write",
+ "can_read",
+ "pull_robot",
+ "resource_key",
+ "trigger",
+ "display_name",
+ "started",
+ "job_config",
+ "phase",
+ "status",
+ "error",
+ "archive_url",
+ ],
+ )
+):
+ """
RepositoryBuild represents a build associated with a repository
:type uuid: string
:type logs_archived: boolean
@@ -40,42 +52,46 @@ class RepositoryBuild(namedtuple('RepositoryBuild',
:type archive_url: string
"""
- def to_dict(self):
+ def to_dict(self):
- resp = {
- 'id': self.uuid,
- 'phase': self.phase,
- 'started': format_date(self.started),
- 'display_name': self.display_name,
- 'status': self.status or {},
- 'subdirectory': self.job_config.get('build_subdir', ''),
- 'dockerfile_path': self.job_config.get('build_subdir', ''),
- 'context': self.job_config.get('context', ''),
- 'tags': self.job_config.get('docker_tags', []),
- 'manual_user': self.job_config.get('manual_user', None),
- 'is_writer': self.can_write,
- 'trigger': self.trigger.to_dict(),
- 'trigger_metadata': self.job_config.get('trigger_metadata', None) if self.can_read else None,
- 'resource_key': self.resource_key,
- 'pull_robot': user_view(self.pull_robot) if self.pull_robot else None,
- 'repository': {
- 'namespace': self.repository_namespace_user_username,
- 'name': self.repository_name
- },
- 'error': self.error,
- }
+ resp = {
+ "id": self.uuid,
+ "phase": self.phase,
+ "started": format_date(self.started),
+ "display_name": self.display_name,
+ "status": self.status or {},
+ "subdirectory": self.job_config.get("build_subdir", ""),
+ "dockerfile_path": self.job_config.get("build_subdir", ""),
+ "context": self.job_config.get("context", ""),
+ "tags": self.job_config.get("docker_tags", []),
+ "manual_user": self.job_config.get("manual_user", None),
+ "is_writer": self.can_write,
+ "trigger": self.trigger.to_dict(),
+ "trigger_metadata": self.job_config.get("trigger_metadata", None)
+ if self.can_read
+ else None,
+ "resource_key": self.resource_key,
+ "pull_robot": user_view(self.pull_robot) if self.pull_robot else None,
+ "repository": {
+ "namespace": self.repository_namespace_user_username,
+ "name": self.repository_name,
+ },
+ "error": self.error,
+ }
- if self.can_write:
- if self.resource_key is not None:
- resp['archive_url'] = self.archive_url
- elif self.job_config.get('archive_url', None):
- resp['archive_url'] = self.job_config['archive_url']
+ if self.can_write:
+ if self.resource_key is not None:
+ resp["archive_url"] = self.archive_url
+ elif self.job_config.get("archive_url", None):
+ resp["archive_url"] = self.job_config["archive_url"]
- return resp
+ return resp
-class Approval(namedtuple('Approval', ['approver', 'approval_type', 'approved_date', 'notes'])):
- """
+class Approval(
+ namedtuple("Approval", ["approver", "approval_type", "approved_date", "notes"])
+):
+ """
Approval represents whether a key has been approved or not
:type approver: User
:type approval_type: string
@@ -83,19 +99,32 @@ class Approval(namedtuple('Approval', ['approver', 'approval_type', 'approved_da
:type notes: string
"""
- def to_dict(self):
- return {
- 'approver': self.approver.to_dict() if self.approver else None,
- 'approval_type': self.approval_type,
- 'approved_date': self.approved_date,
- 'notes': self.notes,
- }
+ def to_dict(self):
+ return {
+ "approver": self.approver.to_dict() if self.approver else None,
+ "approval_type": self.approval_type,
+ "approved_date": self.approved_date,
+ "notes": self.notes,
+ }
class ServiceKey(
- namedtuple('ServiceKey', ['name', 'kid', 'service', 'jwk', 'metadata', 'created_date',
- 'expiration_date', 'rotation_duration', 'approval'])):
- """
+ namedtuple(
+ "ServiceKey",
+ [
+ "name",
+ "kid",
+ "service",
+ "jwk",
+ "metadata",
+ "created_date",
+ "expiration_date",
+ "rotation_duration",
+ "approval",
+ ],
+ )
+):
+ """
ServiceKey is an apostille signing key
:type name: string
:type kid: int
@@ -109,22 +138,22 @@ class ServiceKey(
"""
- def to_dict(self):
- return {
- 'name': self.name,
- 'kid': self.kid,
- 'service': self.service,
- 'jwk': self.jwk,
- 'metadata': self.metadata,
- 'created_date': self.created_date,
- 'expiration_date': self.expiration_date,
- 'rotation_duration': self.rotation_duration,
- 'approval': self.approval.to_dict() if self.approval is not None else None,
- }
+ def to_dict(self):
+ return {
+ "name": self.name,
+ "kid": self.kid,
+ "service": self.service,
+ "jwk": self.jwk,
+ "metadata": self.metadata,
+ "created_date": self.created_date,
+ "expiration_date": self.expiration_date,
+ "rotation_duration": self.rotation_duration,
+ "approval": self.approval.to_dict() if self.approval is not None else None,
+ }
-class User(namedtuple('User', ['username', 'email', 'verified', 'enabled', 'robot'])):
- """
+class User(namedtuple("User", ["username", "email", "verified", "enabled", "robot"])):
+ """
User represents a single user.
:type username: string
:type email: string
@@ -133,41 +162,38 @@ class User(namedtuple('User', ['username', 'email', 'verified', 'enabled', 'robo
:type robot: User
"""
- def to_dict(self):
- user_data = {
- 'kind': 'user',
- 'name': self.username,
- 'username': self.username,
- 'email': self.email,
- 'verified': self.verified,
- 'enabled': self.enabled,
- }
+ def to_dict(self):
+ user_data = {
+ "kind": "user",
+ "name": self.username,
+ "username": self.username,
+ "email": self.email,
+ "verified": self.verified,
+ "enabled": self.enabled,
+ }
- return user_data
+ return user_data
-class Organization(namedtuple('Organization', ['username', 'email'])):
- """
+class Organization(namedtuple("Organization", ["username", "email"])):
+ """
Organization represents a single org.
:type username: string
:type email: string
"""
- def to_dict(self):
- return {
- 'name': self.username,
- 'email': self.email,
- }
+ def to_dict(self):
+ return {"name": self.username, "email": self.email}
@add_metaclass(ABCMeta)
class SuperuserDataInterface(object):
- """
+ """
Interface that represents all data store interactions required by a superuser api.
"""
- @abstractmethod
- def list_all_service_keys(self):
- """
+ @abstractmethod
+ def list_all_service_keys(self):
+ """
Returns a list of service keys
"""
diff --git a/config_app/config_endpoints/api/superuser_models_pre_oci.py b/config_app/config_endpoints/api/superuser_models_pre_oci.py
index c35b94243..37864ceee 100644
--- a/config_app/config_endpoints/api/superuser_models_pre_oci.py
+++ b/config_app/config_endpoints/api/superuser_models_pre_oci.py
@@ -1,60 +1,85 @@
from data import model
-from config_app.config_endpoints.api.superuser_models_interface import (SuperuserDataInterface, User, ServiceKey,
- Approval)
+from config_app.config_endpoints.api.superuser_models_interface import (
+ SuperuserDataInterface,
+ User,
+ ServiceKey,
+ Approval,
+)
def _create_user(user):
- if user is None:
- return None
- return User(user.username, user.email, user.verified, user.enabled, user.robot)
+ if user is None:
+ return None
+ return User(user.username, user.email, user.verified, user.enabled, user.robot)
def _create_key(key):
- approval = None
- if key.approval is not None:
- approval = Approval(_create_user(key.approval.approver), key.approval.approval_type,
- key.approval.approved_date,
- key.approval.notes)
+ approval = None
+ if key.approval is not None:
+ approval = Approval(
+ _create_user(key.approval.approver),
+ key.approval.approval_type,
+ key.approval.approved_date,
+ key.approval.notes,
+ )
- return ServiceKey(key.name, key.kid, key.service, key.jwk, key.metadata, key.created_date,
- key.expiration_date,
- key.rotation_duration, approval)
+ return ServiceKey(
+ key.name,
+ key.kid,
+ key.service,
+ key.jwk,
+ key.metadata,
+ key.created_date,
+ key.expiration_date,
+ key.rotation_duration,
+ approval,
+ )
class ServiceKeyDoesNotExist(Exception):
- pass
+ pass
class ServiceKeyAlreadyApproved(Exception):
- pass
+ pass
class PreOCIModel(SuperuserDataInterface):
- """
+ """
PreOCIModel implements the data model for the SuperUser using a database schema
before it was changed to support the OCI specification.
"""
- def list_all_service_keys(self):
- keys = model.service_keys.list_all_keys()
- return [_create_key(key) for key in keys]
+ def list_all_service_keys(self):
+ keys = model.service_keys.list_all_keys()
+ return [_create_key(key) for key in keys]
- def approve_service_key(self, kid, approval_type, notes=''):
- try:
- key = model.service_keys.approve_service_key(kid, approval_type, notes=notes)
- return _create_key(key)
- except model.ServiceKeyDoesNotExist:
- raise ServiceKeyDoesNotExist
- except model.ServiceKeyAlreadyApproved:
- raise ServiceKeyAlreadyApproved
+ def approve_service_key(self, kid, approval_type, notes=""):
+ try:
+ key = model.service_keys.approve_service_key(
+ kid, approval_type, notes=notes
+ )
+ return _create_key(key)
+ except model.ServiceKeyDoesNotExist:
+ raise ServiceKeyDoesNotExist
+ except model.ServiceKeyAlreadyApproved:
+ raise ServiceKeyAlreadyApproved
- def generate_service_key(self, service, expiration_date, kid=None, name='', metadata=None,
- rotation_duration=None):
- (private_key, key) = model.service_keys.generate_service_key(service, expiration_date,
- metadata=metadata, name=name)
+ def generate_service_key(
+ self,
+ service,
+ expiration_date,
+ kid=None,
+ name="",
+ metadata=None,
+ rotation_duration=None,
+ ):
+ (private_key, key) = model.service_keys.generate_service_key(
+ service, expiration_date, metadata=metadata, name=name
+ )
- return private_key, key.kid
+ return private_key, key.kid
pre_oci_model = PreOCIModel()
diff --git a/config_app/config_endpoints/api/tar_config_loader.py b/config_app/config_endpoints/api/tar_config_loader.py
index 8944d9092..06b57cb85 100644
--- a/config_app/config_endpoints/api/tar_config_loader.py
+++ b/config_app/config_endpoints/api/tar_config_loader.py
@@ -10,53 +10,59 @@ from data.database import configure
from config_app.c_app import app, config_provider
from config_app.config_endpoints.api import resource, ApiResource, nickname
-from config_app.config_util.tar import tarinfo_filter_partial, strip_absolute_path_and_add_trailing_dir
+from config_app.config_util.tar import (
+ tarinfo_filter_partial,
+ strip_absolute_path_and_add_trailing_dir,
+)
-@resource('/v1/configapp/initialization')
+@resource("/v1/configapp/initialization")
class ConfigInitialization(ApiResource):
- """
+ """
Resource for dealing with any initialization logic for the config app
"""
- @nickname('scStartNewConfig')
- def post(self):
- config_provider.new_config_dir()
- return make_response('OK')
+ @nickname("scStartNewConfig")
+ def post(self):
+ config_provider.new_config_dir()
+ return make_response("OK")
-@resource('/v1/configapp/tarconfig')
+@resource("/v1/configapp/tarconfig")
class TarConfigLoader(ApiResource):
- """
+ """
Resource for dealing with configuration as a tarball,
including loading and generating functions
"""
- @nickname('scGetConfigTarball')
- def get(self):
- config_path = config_provider.get_config_dir_path()
- tar_dir_prefix = strip_absolute_path_and_add_trailing_dir(config_path)
- temp = tempfile.NamedTemporaryFile()
+ @nickname("scGetConfigTarball")
+ def get(self):
+ config_path = config_provider.get_config_dir_path()
+ tar_dir_prefix = strip_absolute_path_and_add_trailing_dir(config_path)
+ temp = tempfile.NamedTemporaryFile()
- with closing(tarfile.open(temp.name, mode="w|gz")) as tar:
- for name in os.listdir(config_path):
- tar.add(os.path.join(config_path, name), filter=tarinfo_filter_partial(tar_dir_prefix))
- return send_file(temp.name, mimetype='application/gzip')
+ with closing(tarfile.open(temp.name, mode="w|gz")) as tar:
+ for name in os.listdir(config_path):
+ tar.add(
+ os.path.join(config_path, name),
+ filter=tarinfo_filter_partial(tar_dir_prefix),
+ )
+ return send_file(temp.name, mimetype="application/gzip")
- @nickname('scUploadTarballConfig')
- def put(self):
- """ Loads tarball config into the config provider """
- # Generate a new empty dir to load the config into
- config_provider.new_config_dir()
- input_stream = request.stream
- with tarfile.open(mode="r|gz", fileobj=input_stream) as tar_stream:
- tar_stream.extractall(config_provider.get_config_dir_path())
+ @nickname("scUploadTarballConfig")
+ def put(self):
+ """ Loads tarball config into the config provider """
+ # Generate a new empty dir to load the config into
+ config_provider.new_config_dir()
+ input_stream = request.stream
+ with tarfile.open(mode="r|gz", fileobj=input_stream) as tar_stream:
+ tar_stream.extractall(config_provider.get_config_dir_path())
- config_provider.create_copy_of_config_dir()
+ config_provider.create_copy_of_config_dir()
- # now try to connect to the db provided in their config to validate it works
- combined = dict(**app.config)
- combined.update(config_provider.get_config())
- configure(combined)
+ # now try to connect to the db provided in their config to validate it works
+ combined = dict(**app.config)
+ combined.update(config_provider.get_config())
+ configure(combined)
- return make_response('OK')
+ return make_response("OK")
diff --git a/config_app/config_endpoints/api/user.py b/config_app/config_endpoints/api/user.py
index 85008c87e..9ab787a47 100644
--- a/config_app/config_endpoints/api/user.py
+++ b/config_app/config_endpoints/api/user.py
@@ -3,12 +3,12 @@ from config_app.config_endpoints.api import resource, ApiResource, nickname
from config_app.config_endpoints.api.superuser_models_interface import user_view
-@resource('/v1/user/')
+@resource("/v1/user/")
class User(ApiResource):
- """ Operations related to users. """
+ """ Operations related to users. """
- @nickname('getLoggedInUser')
- def get(self):
- """ Get user information for the authenticated user. """
- user = get_authenticated_user()
- return user_view(user)
+ @nickname("getLoggedInUser")
+ def get(self):
+ """ Get user information for the authenticated user. """
+ user = get_authenticated_user()
+ return user_view(user)
diff --git a/config_app/config_endpoints/common.py b/config_app/config_endpoints/common.py
index c277f3b35..1cf874d5d 100644
--- a/config_app/config_endpoints/common.py
+++ b/config_app/config_endpoints/common.py
@@ -14,60 +14,72 @@ from config_app.config_util.k8sconfig import get_k8s_namespace
def truthy_bool(param):
- return param not in {False, 'false', 'False', '0', 'FALSE', '', 'null'}
+ return param not in {False, "false", "False", "0", "FALSE", "", "null"}
-DEFAULT_JS_BUNDLE_NAME = 'configapp'
-PARAM_REGEX = re.compile(r'<([^:>]+:)*([\w]+)>')
+DEFAULT_JS_BUNDLE_NAME = "configapp"
+PARAM_REGEX = re.compile(r"<([^:>]+:)*([\w]+)>")
logger = logging.getLogger(__name__)
TYPE_CONVERTER = {
- truthy_bool: 'boolean',
- str: 'string',
- basestring: 'string',
- reqparse.text_type: 'string',
- int: 'integer',
+ truthy_bool: "boolean",
+ str: "string",
+ basestring: "string",
+ reqparse.text_type: "string",
+ int: "integer",
}
def _list_files(path, extension, contains=""):
- """ Returns a list of all the files with the given extension found under the given path. """
+ """ Returns a list of all the files with the given extension found under the given path. """
- def matches(f):
- return os.path.splitext(f)[1] == '.' + extension and contains in os.path.splitext(f)[0]
+ def matches(f):
+ return (
+ os.path.splitext(f)[1] == "." + extension
+ and contains in os.path.splitext(f)[0]
+ )
- def join_path(dp, f):
- # Remove the static/ prefix. It is added in the template.
- return os.path.join(dp, f)[len(ROOT_DIR) + 1 + len('config_app/static/'):]
+ def join_path(dp, f):
+ # Remove the static/ prefix. It is added in the template.
+ return os.path.join(dp, f)[len(ROOT_DIR) + 1 + len("config_app/static/") :]
- filepath = os.path.join(os.path.join(ROOT_DIR, 'config_app/static/'), path)
- return [join_path(dp, f) for dp, _, files in os.walk(filepath) for f in files if matches(f)]
+ filepath = os.path.join(os.path.join(ROOT_DIR, "config_app/static/"), path)
+ return [
+ join_path(dp, f)
+ for dp, _, files in os.walk(filepath)
+ for f in files
+ if matches(f)
+ ]
-FONT_AWESOME_4 = 'netdna.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.css'
+FONT_AWESOME_4 = "netdna.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.css"
-def render_page_template(name, route_data=None, js_bundle_name=DEFAULT_JS_BUNDLE_NAME, **kwargs):
- """ Renders the page template with the given name as the response and returns its contents. """
- main_scripts = _list_files('build', 'js', js_bundle_name)
+def render_page_template(
+ name, route_data=None, js_bundle_name=DEFAULT_JS_BUNDLE_NAME, **kwargs
+):
+ """ Renders the page template with the given name as the response and returns its contents. """
+ main_scripts = _list_files("build", "js", js_bundle_name)
- use_cdn = os.getenv('TESTING') == 'true'
+ use_cdn = os.getenv("TESTING") == "true"
- external_styles = get_external_css(local=not use_cdn, exclude=FONT_AWESOME_4)
- external_scripts = get_external_javascript(local=not use_cdn)
+ external_styles = get_external_css(local=not use_cdn, exclude=FONT_AWESOME_4)
+ external_scripts = get_external_javascript(local=not use_cdn)
- contents = render_template(name,
- route_data=route_data,
- main_scripts=main_scripts,
- external_styles=external_styles,
- external_scripts=external_scripts,
- config_set=frontend_visible_config(app.config),
- kubernetes_namespace=IS_KUBERNETES and get_k8s_namespace(),
- **kwargs)
+ contents = render_template(
+ name,
+ route_data=route_data,
+ main_scripts=main_scripts,
+ external_styles=external_styles,
+ external_scripts=external_scripts,
+ config_set=frontend_visible_config(app.config),
+ kubernetes_namespace=IS_KUBERNETES and get_k8s_namespace(),
+ **kwargs
+ )
- resp = make_response(contents)
- resp.headers['X-FRAME-OPTIONS'] = 'DENY'
- return resp
+ resp = make_response(contents)
+ resp.headers["X-FRAME-OPTIONS"] = "DENY"
+ return resp
def fully_qualified_name(method_view_class):
- return '%s.%s' % (method_view_class.__module__, method_view_class.__name__)
+ return "%s.%s" % (method_view_class.__module__, method_view_class.__name__)
diff --git a/config_app/config_endpoints/exception.py b/config_app/config_endpoints/exception.py
index 7f7f75a41..03e29fba9 100644
--- a/config_app/config_endpoints/exception.py
+++ b/config_app/config_endpoints/exception.py
@@ -5,11 +5,11 @@ from werkzeug.exceptions import HTTPException
class ApiErrorType(Enum):
- invalid_request = 'invalid_request'
+ invalid_request = "invalid_request"
class ApiException(HTTPException):
- """
+ """
Represents an error in the application/problem+json format.
See: https://tools.ietf.org/html/rfc7807
@@ -31,36 +31,42 @@ class ApiException(HTTPException):
information if dereferenced.
"""
- def __init__(self, error_type, status_code, error_description, payload=None):
- Exception.__init__(self)
- self.error_description = error_description
- self.code = status_code
- self.payload = payload
- self.error_type = error_type
- self.data = self.to_dict()
+ def __init__(self, error_type, status_code, error_description, payload=None):
+ Exception.__init__(self)
+ self.error_description = error_description
+ self.code = status_code
+ self.payload = payload
+ self.error_type = error_type
+ self.data = self.to_dict()
- super(ApiException, self).__init__(error_description, None)
+ super(ApiException, self).__init__(error_description, None)
- def to_dict(self):
- rv = dict(self.payload or ())
+ def to_dict(self):
+ rv = dict(self.payload or ())
- if self.error_description is not None:
- rv['detail'] = self.error_description
- rv['error_message'] = self.error_description # TODO: deprecate
+ if self.error_description is not None:
+ rv["detail"] = self.error_description
+ rv["error_message"] = self.error_description # TODO: deprecate
- rv['error_type'] = self.error_type.value # TODO: deprecate
- rv['title'] = self.error_type.value
- rv['type'] = url_for('api.error', error_type=self.error_type.value, _external=True)
- rv['status'] = self.code
+ rv["error_type"] = self.error_type.value # TODO: deprecate
+ rv["title"] = self.error_type.value
+ rv["type"] = url_for(
+ "api.error", error_type=self.error_type.value, _external=True
+ )
+ rv["status"] = self.code
- return rv
+ return rv
class InvalidRequest(ApiException):
- def __init__(self, error_description, payload=None):
- ApiException.__init__(self, ApiErrorType.invalid_request, 400, error_description, payload)
+ def __init__(self, error_description, payload=None):
+ ApiException.__init__(
+ self, ApiErrorType.invalid_request, 400, error_description, payload
+ )
class InvalidResponse(ApiException):
- def __init__(self, error_description, payload=None):
- ApiException.__init__(self, ApiErrorType.invalid_response, 400, error_description, payload)
+ def __init__(self, error_description, payload=None):
+ ApiException.__init__(
+ self, ApiErrorType.invalid_response, 400, error_description, payload
+ )
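
ApiException.to_dict() builds an application/problem+json body in the RFC 7807 shape, keeping the legacy error_message and error_type fields alongside detail and title. Illustratively, an InvalidRequest("Missing certificate file") would serialize to something like the dict below; the exact "type" URL comes from url_for("api.error", ...) and depends on the deployment, so it is shown as a placeholder.

{
    "detail": "Missing certificate file",
    "error_message": "Missing certificate file",  # legacy duplicate of detail (TODO: deprecate)
    "error_type": "invalid_request",  # legacy duplicate of title (TODO: deprecate)
    "title": "invalid_request",
    "type": "<url_for('api.error', error_type='invalid_request', _external=True)>",
    "status": 400,
}
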
diff --git a/config_app/config_endpoints/setup_web.py b/config_app/config_endpoints/setup_web.py
index 32dda15e2..8db100705 100644
--- a/config_app/config_endpoints/setup_web.py
+++ b/config_app/config_endpoints/setup_web.py
@@ -5,19 +5,21 @@ from config_app.config_endpoints.common import render_page_template
from config_app.config_endpoints.api.discovery import generate_route_data
from config_app.config_endpoints.api import no_cache
-setup_web = Blueprint('setup_web', __name__, template_folder='templates')
+setup_web = Blueprint("setup_web", __name__, template_folder="templates")
@lru_cache(maxsize=1)
def _get_route_data():
- return generate_route_data()
+ return generate_route_data()
def render_page_template_with_routedata(name, *args, **kwargs):
- return render_page_template(name, _get_route_data(), *args, **kwargs)
+ return render_page_template(name, _get_route_data(), *args, **kwargs)
@no_cache
-@setup_web.route('/', methods=['GET'], defaults={'path': ''})
+@setup_web.route("/", methods=["GET"], defaults={"path": ""})
def index(path, **kwargs):
- return render_page_template_with_routedata('index.html', js_bundle_name='configapp', **kwargs)
+ return render_page_template_with_routedata(
+ "index.html", js_bundle_name="configapp", **kwargs
+ )
diff --git a/config_app/config_test/test_api_usage.py b/config_app/config_test/test_api_usage.py
index aa34b3495..4816e0aa5 100644
--- a/config_app/config_test/test_api_usage.py
+++ b/config_app/config_test/test_api_usage.py
@@ -5,204 +5,242 @@ from data import database, model
from util.security.test.test_ssl_util import generate_test_cert
from config_app.c_app import app
-from config_app.config_test import ApiTestCase, all_queues, ADMIN_ACCESS_USER, ADMIN_ACCESS_EMAIL
+from config_app.config_test import (
+ ApiTestCase,
+ all_queues,
+ ADMIN_ACCESS_USER,
+ ADMIN_ACCESS_EMAIL,
+)
from config_app.config_endpoints.api import api_bp
-from config_app.config_endpoints.api.superuser import SuperUserCustomCertificate, SuperUserCustomCertificates
-from config_app.config_endpoints.api.suconfig import SuperUserConfig, SuperUserCreateInitialSuperUser, \
- SuperUserConfigFile, SuperUserRegistryStatus
+from config_app.config_endpoints.api.superuser import (
+ SuperUserCustomCertificate,
+ SuperUserCustomCertificates,
+)
+from config_app.config_endpoints.api.suconfig import (
+ SuperUserConfig,
+ SuperUserCreateInitialSuperUser,
+ SuperUserConfigFile,
+ SuperUserRegistryStatus,
+)
try:
- app.register_blueprint(api_bp, url_prefix='/api')
+ app.register_blueprint(api_bp, url_prefix="/api")
except ValueError:
- # This blueprint was already registered
- pass
+ # This blueprint was already registered
+ pass
class TestSuperUserCreateInitialSuperUser(ApiTestCase):
- def test_create_superuser(self):
- data = {
- 'username': 'newsuper',
- 'password': 'password',
- 'email': 'jschorr+fake@devtable.com',
- }
+ def test_create_superuser(self):
+ data = {
+ "username": "newsuper",
+ "password": "password",
+ "email": "jschorr+fake@devtable.com",
+ }
- # Add some fake config.
- fake_config = {
- 'AUTHENTICATION_TYPE': 'Database',
- 'SECRET_KEY': 'fakekey',
- }
+ # Add some fake config.
+ fake_config = {"AUTHENTICATION_TYPE": "Database", "SECRET_KEY": "fakekey"}
- self.putJsonResponse(SuperUserConfig, data=dict(config=fake_config, hostname='fakehost'))
+ self.putJsonResponse(
+ SuperUserConfig, data=dict(config=fake_config, hostname="fakehost")
+ )
- # Try to write with config. Should 403 since there are users in the DB.
- self.postResponse(SuperUserCreateInitialSuperUser, data=data, expected_code=403)
+ # Try to write with config. Should 403 since there are users in the DB.
+ self.postResponse(SuperUserCreateInitialSuperUser, data=data, expected_code=403)
- # Delete all users in the DB.
- for user in list(database.User.select()):
- model.user.delete_user(user, all_queues)
+ # Delete all users in the DB.
+ for user in list(database.User.select()):
+ model.user.delete_user(user, all_queues)
- # Create the superuser.
- self.postJsonResponse(SuperUserCreateInitialSuperUser, data=data)
+ # Create the superuser.
+ self.postJsonResponse(SuperUserCreateInitialSuperUser, data=data)
- # Ensure the user exists in the DB.
- self.assertIsNotNone(model.user.get_user('newsuper'))
+ # Ensure the user exists in the DB.
+ self.assertIsNotNone(model.user.get_user("newsuper"))
- # Ensure that the current user is a superuser in the config.
- json = self.getJsonResponse(SuperUserConfig)
- self.assertEquals(['newsuper'], json['config']['SUPER_USERS'])
+ # Ensure that the current user is a superuser in the config.
+ json = self.getJsonResponse(SuperUserConfig)
+ self.assertEquals(["newsuper"], json["config"]["SUPER_USERS"])
- # Ensure that the current user is a superuser in memory by trying to call an API
- # that will fail otherwise.
- self.getResponse(SuperUserConfigFile, params=dict(filename='ssl.cert'))
+ # Ensure that the current user is a superuser in memory by trying to call an API
+ # that will fail otherwise.
+ self.getResponse(SuperUserConfigFile, params=dict(filename="ssl.cert"))
class TestSuperUserConfig(ApiTestCase):
- def test_get_status_update_config(self):
- # With no config the status should be 'config-db'.
- json = self.getJsonResponse(SuperUserRegistryStatus)
- self.assertEquals('config-db', json['status'])
+ def test_get_status_update_config(self):
+ # With no config the status should be 'config-db'.
+ json = self.getJsonResponse(SuperUserRegistryStatus)
+ self.assertEquals("config-db", json["status"])
- # Add some fake config.
- fake_config = {
- 'AUTHENTICATION_TYPE': 'Database',
- 'SECRET_KEY': 'fakekey',
- }
+ # Add some fake config.
+ fake_config = {"AUTHENTICATION_TYPE": "Database", "SECRET_KEY": "fakekey"}
- json = self.putJsonResponse(SuperUserConfig, data=dict(config=fake_config,
- hostname='fakehost'))
- self.assertEquals('fakekey', json['config']['SECRET_KEY'])
- self.assertEquals('fakehost', json['config']['SERVER_HOSTNAME'])
- self.assertEquals('Database', json['config']['AUTHENTICATION_TYPE'])
+ json = self.putJsonResponse(
+ SuperUserConfig, data=dict(config=fake_config, hostname="fakehost")
+ )
+ self.assertEquals("fakekey", json["config"]["SECRET_KEY"])
+ self.assertEquals("fakehost", json["config"]["SERVER_HOSTNAME"])
+ self.assertEquals("Database", json["config"]["AUTHENTICATION_TYPE"])
- # With config the status should be 'setup-db'.
- # TODO: fix this test
- # json = self.getJsonResponse(SuperUserRegistryStatus)
- # self.assertEquals('setup-db', json['status'])
+ # With config the status should be 'setup-db'.
+ # TODO: fix this test
+ # json = self.getJsonResponse(SuperUserRegistryStatus)
+ # self.assertEquals('setup-db', json['status'])
- def test_config_file(self):
- # Try for an invalid file. Should 404.
- self.getResponse(SuperUserConfigFile, params=dict(filename='foobar'), expected_code=404)
+ def test_config_file(self):
+ # Try for an invalid file. Should 404.
+ self.getResponse(
+ SuperUserConfigFile, params=dict(filename="foobar"), expected_code=404
+ )
- # Try for a valid filename. Should not exist.
- json = self.getJsonResponse(SuperUserConfigFile, params=dict(filename='ssl.cert'))
- self.assertFalse(json['exists'])
+ # Try for a valid filename. Should not exist.
+ json = self.getJsonResponse(
+ SuperUserConfigFile, params=dict(filename="ssl.cert")
+ )
+ self.assertFalse(json["exists"])
- # Add the file.
- self.postResponse(SuperUserConfigFile, params=dict(filename='ssl.cert'),
- file=(StringIO('my file contents'), 'ssl.cert'))
+ # Add the file.
+ self.postResponse(
+ SuperUserConfigFile,
+ params=dict(filename="ssl.cert"),
+ file=(StringIO("my file contents"), "ssl.cert"),
+ )
- # Should now exist.
- json = self.getJsonResponse(SuperUserConfigFile, params=dict(filename='ssl.cert'))
- self.assertTrue(json['exists'])
+ # Should now exist.
+ json = self.getJsonResponse(
+ SuperUserConfigFile, params=dict(filename="ssl.cert")
+ )
+ self.assertTrue(json["exists"])
- def test_update_with_external_auth(self):
- # Run a mock LDAP.
- mockldap = MockLdap({
- 'dc=quay,dc=io': {
- 'dc': ['quay', 'io']
- },
- 'ou=employees,dc=quay,dc=io': {
- 'dc': ['quay', 'io'],
- 'ou': 'employees'
- },
- 'uid=' + ADMIN_ACCESS_USER + ',ou=employees,dc=quay,dc=io': {
- 'dc': ['quay', 'io'],
- 'ou': 'employees',
- 'uid': [ADMIN_ACCESS_USER],
- 'userPassword': ['password'],
- 'mail': [ADMIN_ACCESS_EMAIL],
- },
- })
+ def test_update_with_external_auth(self):
+ # Run a mock LDAP.
+ mockldap = MockLdap(
+ {
+ "dc=quay,dc=io": {"dc": ["quay", "io"]},
+ "ou=employees,dc=quay,dc=io": {"dc": ["quay", "io"], "ou": "employees"},
+ "uid="
+ + ADMIN_ACCESS_USER
+ + ",ou=employees,dc=quay,dc=io": {
+ "dc": ["quay", "io"],
+ "ou": "employees",
+ "uid": [ADMIN_ACCESS_USER],
+ "userPassword": ["password"],
+ "mail": [ADMIN_ACCESS_EMAIL],
+ },
+ }
+ )
- config = {
- 'AUTHENTICATION_TYPE': 'LDAP',
- 'LDAP_BASE_DN': ['dc=quay', 'dc=io'],
- 'LDAP_ADMIN_DN': 'uid=devtable,ou=employees,dc=quay,dc=io',
- 'LDAP_ADMIN_PASSWD': 'password',
- 'LDAP_USER_RDN': ['ou=employees'],
- 'LDAP_UID_ATTR': 'uid',
- 'LDAP_EMAIL_ATTR': 'mail',
- }
+ config = {
+ "AUTHENTICATION_TYPE": "LDAP",
+ "LDAP_BASE_DN": ["dc=quay", "dc=io"],
+ "LDAP_ADMIN_DN": "uid=devtable,ou=employees,dc=quay,dc=io",
+ "LDAP_ADMIN_PASSWD": "password",
+ "LDAP_USER_RDN": ["ou=employees"],
+ "LDAP_UID_ATTR": "uid",
+ "LDAP_EMAIL_ATTR": "mail",
+ }
- mockldap.start()
- try:
- # Write the config with the valid password.
- self.putResponse(SuperUserConfig,
- data={'config': config,
- 'password': 'password',
- 'hostname': 'foo'}, expected_code=200)
+ mockldap.start()
+ try:
+ # Write the config with the valid password.
+ self.putResponse(
+ SuperUserConfig,
+ data={"config": config, "password": "password", "hostname": "foo"},
+ expected_code=200,
+ )
+
+ # Ensure that the user row has been linked.
+ # TODO: fix this test
+ # self.assertEquals(ADMIN_ACCESS_USER,
+ # model.user.verify_federated_login('ldap', ADMIN_ACCESS_USER).username)
+ finally:
+ mockldap.stop()
- # Ensure that the user row has been linked.
- # TODO: fix this test
- # self.assertEquals(ADMIN_ACCESS_USER,
- # model.user.verify_federated_login('ldap', ADMIN_ACCESS_USER).username)
- finally:
- mockldap.stop()
class TestSuperUserCustomCertificates(ApiTestCase):
- def test_custom_certificates(self):
+ def test_custom_certificates(self):
- # Upload a certificate.
- cert_contents, _ = generate_test_cert(hostname='somecoolhost', san_list=['DNS:bar', 'DNS:baz'])
- self.postResponse(SuperUserCustomCertificate, params=dict(certpath='testcert.crt'),
- file=(StringIO(cert_contents), 'testcert.crt'), expected_code=204)
+ # Upload a certificate.
+ cert_contents, _ = generate_test_cert(
+ hostname="somecoolhost", san_list=["DNS:bar", "DNS:baz"]
+ )
+ self.postResponse(
+ SuperUserCustomCertificate,
+ params=dict(certpath="testcert.crt"),
+ file=(StringIO(cert_contents), "testcert.crt"),
+ expected_code=204,
+ )
- # Make sure it is present.
- json = self.getJsonResponse(SuperUserCustomCertificates)
- self.assertEquals(1, len(json['certs']))
+ # Make sure it is present.
+ json = self.getJsonResponse(SuperUserCustomCertificates)
+ self.assertEquals(1, len(json["certs"]))
- cert_info = json['certs'][0]
- self.assertEquals('testcert.crt', cert_info['path'])
+ cert_info = json["certs"][0]
+ self.assertEquals("testcert.crt", cert_info["path"])
- self.assertEquals(set(['somecoolhost', 'bar', 'baz']), set(cert_info['names']))
- self.assertFalse(cert_info['expired'])
+ self.assertEquals(set(["somecoolhost", "bar", "baz"]), set(cert_info["names"]))
+ self.assertFalse(cert_info["expired"])
- # Remove the certificate.
- self.deleteResponse(SuperUserCustomCertificate, params=dict(certpath='testcert.crt'))
+ # Remove the certificate.
+ self.deleteResponse(
+ SuperUserCustomCertificate, params=dict(certpath="testcert.crt")
+ )
- # Make sure it is gone.
- json = self.getJsonResponse(SuperUserCustomCertificates)
- self.assertEquals(0, len(json['certs']))
+ # Make sure it is gone.
+ json = self.getJsonResponse(SuperUserCustomCertificates)
+ self.assertEquals(0, len(json["certs"]))
- def test_expired_custom_certificate(self):
- # Upload a certificate.
- cert_contents, _ = generate_test_cert(hostname='somecoolhost', expires=-10)
- self.postResponse(SuperUserCustomCertificate, params=dict(certpath='testcert.crt'),
- file=(StringIO(cert_contents), 'testcert.crt'), expected_code=204)
+ def test_expired_custom_certificate(self):
+ # Upload a certificate.
+ cert_contents, _ = generate_test_cert(hostname="somecoolhost", expires=-10)
+ self.postResponse(
+ SuperUserCustomCertificate,
+ params=dict(certpath="testcert.crt"),
+ file=(StringIO(cert_contents), "testcert.crt"),
+ expected_code=204,
+ )
- # Make sure it is present.
- json = self.getJsonResponse(SuperUserCustomCertificates)
- self.assertEquals(1, len(json['certs']))
+ # Make sure it is present.
+ json = self.getJsonResponse(SuperUserCustomCertificates)
+ self.assertEquals(1, len(json["certs"]))
- cert_info = json['certs'][0]
- self.assertEquals('testcert.crt', cert_info['path'])
+ cert_info = json["certs"][0]
+ self.assertEquals("testcert.crt", cert_info["path"])
- self.assertEquals(set(['somecoolhost']), set(cert_info['names']))
- self.assertTrue(cert_info['expired'])
+ self.assertEquals(set(["somecoolhost"]), set(cert_info["names"]))
+ self.assertTrue(cert_info["expired"])
- def test_invalid_custom_certificate(self):
- # Upload an invalid certificate.
- self.postResponse(SuperUserCustomCertificate, params=dict(certpath='testcert.crt'),
- file=(StringIO('some contents'), 'testcert.crt'), expected_code=204)
+ def test_invalid_custom_certificate(self):
+ # Upload an invalid certificate.
+ self.postResponse(
+ SuperUserCustomCertificate,
+ params=dict(certpath="testcert.crt"),
+ file=(StringIO("some contents"), "testcert.crt"),
+ expected_code=204,
+ )
- # Make sure it is present but invalid.
- json = self.getJsonResponse(SuperUserCustomCertificates)
- self.assertEquals(1, len(json['certs']))
+ # Make sure it is present but invalid.
+ json = self.getJsonResponse(SuperUserCustomCertificates)
+ self.assertEquals(1, len(json["certs"]))
- cert_info = json['certs'][0]
- self.assertEquals('testcert.crt', cert_info['path'])
- self.assertEquals('no start line', cert_info['error'])
+ cert_info = json["certs"][0]
+ self.assertEquals("testcert.crt", cert_info["path"])
+ self.assertEquals("no start line", cert_info["error"])
- def test_path_sanitization(self):
- # Upload a certificate.
- cert_contents, _ = generate_test_cert(hostname='somecoolhost', expires=-10)
- self.postResponse(SuperUserCustomCertificate, params=dict(certpath='testcert/../foobar.crt'),
- file=(StringIO(cert_contents), 'testcert/../foobar.crt'), expected_code=204)
+ def test_path_sanitization(self):
+ # Upload a certificate.
+ cert_contents, _ = generate_test_cert(hostname="somecoolhost", expires=-10)
+ self.postResponse(
+ SuperUserCustomCertificate,
+ params=dict(certpath="testcert/../foobar.crt"),
+ file=(StringIO(cert_contents), "testcert/../foobar.crt"),
+ expected_code=204,
+ )
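+ # The "../" path traversal should be sanitized away, leaving only the basename "foobar.crt".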
- # Make sure it is present.
- json = self.getJsonResponse(SuperUserCustomCertificates)
- self.assertEquals(1, len(json['certs']))
-
- cert_info = json['certs'][0]
- self.assertEquals('foobar.crt', cert_info['path'])
+ # Make sure it is present.
+ json = self.getJsonResponse(SuperUserCustomCertificates)
+ self.assertEquals(1, len(json["certs"]))
+ cert_info = json["certs"][0]
+ self.assertEquals("foobar.crt", cert_info["path"])
diff --git a/config_app/config_test/test_suconfig_api.py b/config_app/config_test/test_suconfig_api.py
index 408b96a8b..a805e6421 100644
--- a/config_app/config_test/test_suconfig_api.py
+++ b/config_app/config_test/test_suconfig_api.py
@@ -4,176 +4,235 @@ import mock
from data.database import User
from data import model
-from config_app.config_endpoints.api.suconfig import SuperUserConfig, SuperUserConfigValidate, SuperUserConfigFile, \
- SuperUserRegistryStatus, SuperUserCreateInitialSuperUser
+from config_app.config_endpoints.api.suconfig import (
+ SuperUserConfig,
+ SuperUserConfigValidate,
+ SuperUserConfigFile,
+ SuperUserRegistryStatus,
+ SuperUserCreateInitialSuperUser,
+)
from config_app.config_endpoints.api import api_bp
from config_app.config_test import ApiTestCase, READ_ACCESS_USER, ADMIN_ACCESS_USER
from config_app.c_app import app, config_provider
try:
- app.register_blueprint(api_bp, url_prefix='/api')
+ app.register_blueprint(api_bp, url_prefix="/api")
except ValueError:
- # This blueprint was already registered
- pass
+ # This blueprint was already registered
+ pass
# OVERRIDES FROM PORTING FROM OLD APP:
-all_queues = [] # the config app doesn't have any queues
+all_queues = [] # the config app doesn't have any queues
+
class FreshConfigProvider(object):
- def __enter__(self):
- config_provider.reset_for_test()
- return config_provider
+ def __enter__(self):
+ config_provider.reset_for_test()
+ return config_provider
- def __exit__(self, type, value, traceback):
- config_provider.reset_for_test()
+ def __exit__(self, type, value, traceback):
+ config_provider.reset_for_test()
class TestSuperUserRegistryStatus(ApiTestCase):
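+ # Status progression exercised below: "config-db" (no config) -> "setup-db" (config but invalid DB)
+ # -> "create-superuser" (valid DB, no users) -> "config" (valid DB with users / setup complete).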
- def test_registry_status_no_config(self):
- with FreshConfigProvider():
- json = self.getJsonResponse(SuperUserRegistryStatus)
- self.assertEquals('config-db', json['status'])
+ def test_registry_status_no_config(self):
+ with FreshConfigProvider():
+ json = self.getJsonResponse(SuperUserRegistryStatus)
+ self.assertEquals("config-db", json["status"])
- @mock.patch("config_app.config_endpoints.api.suconfig.database_is_valid", mock.Mock(return_value=False))
- def test_registry_status_no_database(self):
- with FreshConfigProvider():
- config_provider.save_config({'key': 'value'})
- json = self.getJsonResponse(SuperUserRegistryStatus)
- self.assertEquals('setup-db', json['status'])
+ @mock.patch(
+ "config_app.config_endpoints.api.suconfig.database_is_valid",
+ mock.Mock(return_value=False),
+ )
+ def test_registry_status_no_database(self):
+ with FreshConfigProvider():
+ config_provider.save_config({"key": "value"})
+ json = self.getJsonResponse(SuperUserRegistryStatus)
+ self.assertEquals("setup-db", json["status"])
- @mock.patch("config_app.config_endpoints.api.suconfig.database_is_valid", mock.Mock(return_value=True))
- def test_registry_status_db_has_superuser(self):
- with FreshConfigProvider():
- config_provider.save_config({'key': 'value'})
- json = self.getJsonResponse(SuperUserRegistryStatus)
- self.assertEquals('config', json['status'])
+ @mock.patch(
+ "config_app.config_endpoints.api.suconfig.database_is_valid",
+ mock.Mock(return_value=True),
+ )
+ def test_registry_status_db_has_superuser(self):
+ with FreshConfigProvider():
+ config_provider.save_config({"key": "value"})
+ json = self.getJsonResponse(SuperUserRegistryStatus)
+ self.assertEquals("config", json["status"])
- @mock.patch("config_app.config_endpoints.api.suconfig.database_is_valid", mock.Mock(return_value=True))
- @mock.patch("config_app.config_endpoints.api.suconfig.database_has_users", mock.Mock(return_value=False))
- def test_registry_status_db_no_superuser(self):
- with FreshConfigProvider():
- config_provider.save_config({'key': 'value'})
- json = self.getJsonResponse(SuperUserRegistryStatus)
- self.assertEquals('create-superuser', json['status'])
+ @mock.patch(
+ "config_app.config_endpoints.api.suconfig.database_is_valid",
+ mock.Mock(return_value=True),
+ )
+ @mock.patch(
+ "config_app.config_endpoints.api.suconfig.database_has_users",
+ mock.Mock(return_value=False),
+ )
+ def test_registry_status_db_no_superuser(self):
+ with FreshConfigProvider():
+ config_provider.save_config({"key": "value"})
+ json = self.getJsonResponse(SuperUserRegistryStatus)
+ self.assertEquals("create-superuser", json["status"])
+
+ @mock.patch(
+ "config_app.config_endpoints.api.suconfig.database_is_valid",
+ mock.Mock(return_value=True),
+ )
+ @mock.patch(
+ "config_app.config_endpoints.api.suconfig.database_has_users",
+ mock.Mock(return_value=True),
+ )
+ def test_registry_status_setup_complete(self):
+ with FreshConfigProvider():
+ config_provider.save_config({"key": "value", "SETUP_COMPLETE": True})
+ json = self.getJsonResponse(SuperUserRegistryStatus)
+ self.assertEquals("config", json["status"])
- @mock.patch("config_app.config_endpoints.api.suconfig.database_is_valid", mock.Mock(return_value=True))
- @mock.patch("config_app.config_endpoints.api.suconfig.database_has_users", mock.Mock(return_value=True))
- def test_registry_status_setup_complete(self):
- with FreshConfigProvider():
- config_provider.save_config({'key': 'value', 'SETUP_COMPLETE': True})
- json = self.getJsonResponse(SuperUserRegistryStatus)
- self.assertEquals('config', json['status'])
class TestSuperUserConfigFile(ApiTestCase):
- def test_get_superuser_invalid_filename(self):
- with FreshConfigProvider():
- self.getResponse(SuperUserConfigFile, params=dict(filename='somefile'), expected_code=404)
+ def test_get_superuser_invalid_filename(self):
+ with FreshConfigProvider():
+ self.getResponse(
+ SuperUserConfigFile, params=dict(filename="somefile"), expected_code=404
+ )
- def test_get_superuser(self):
- with FreshConfigProvider():
- result = self.getJsonResponse(SuperUserConfigFile, params=dict(filename='ssl.cert'))
- self.assertFalse(result['exists'])
+ def test_get_superuser(self):
+ with FreshConfigProvider():
+ result = self.getJsonResponse(
+ SuperUserConfigFile, params=dict(filename="ssl.cert")
+ )
+ self.assertFalse(result["exists"])
- def test_post_no_file(self):
- with FreshConfigProvider():
- # No file
- self.postResponse(SuperUserConfigFile, params=dict(filename='ssl.cert'), expected_code=400)
+ def test_post_no_file(self):
+ with FreshConfigProvider():
+ # No file
+ self.postResponse(
+ SuperUserConfigFile, params=dict(filename="ssl.cert"), expected_code=400
+ )
- def test_post_superuser_invalid_filename(self):
- with FreshConfigProvider():
- self.postResponse(SuperUserConfigFile, params=dict(filename='somefile'), expected_code=404)
+ def test_post_superuser_invalid_filename(self):
+ with FreshConfigProvider():
+ self.postResponse(
+ SuperUserConfigFile, params=dict(filename="somefile"), expected_code=404
+ )
- def test_post_superuser(self):
- with FreshConfigProvider():
- self.postResponse(SuperUserConfigFile, params=dict(filename='ssl.cert'), expected_code=400)
+ def test_post_superuser(self):
+ with FreshConfigProvider():
+ self.postResponse(
+ SuperUserConfigFile, params=dict(filename="ssl.cert"), expected_code=400
+ )
class TestSuperUserCreateInitialSuperUser(ApiTestCase):
- def test_no_config_file(self):
- with FreshConfigProvider():
- # If there is no config.yaml, then this method should security fail.
- data = dict(username='cooluser', password='password', email='fake@example.com')
- self.postResponse(SuperUserCreateInitialSuperUser, data=data, expected_code=403)
+ def test_no_config_file(self):
+ with FreshConfigProvider():
+ # If there is no config.yaml, this method should fail for security reasons.
+ data = dict(
+ username="cooluser", password="password", email="fake@example.com"
+ )
+ self.postResponse(
+ SuperUserCreateInitialSuperUser, data=data, expected_code=403
+ )
- def test_config_file_with_db_users(self):
- with FreshConfigProvider():
- # Write some config.
- self.putJsonResponse(SuperUserConfig, data=dict(config={}, hostname='foobar'))
+ def test_config_file_with_db_users(self):
+ with FreshConfigProvider():
+ # Write some config.
+ self.putJsonResponse(
+ SuperUserConfig, data=dict(config={}, hostname="foobar")
+ )
- # If there is a config.yaml, but existing DB users exist, then this method should security
- # fail.
- data = dict(username='cooluser', password='password', email='fake@example.com')
- self.postResponse(SuperUserCreateInitialSuperUser, data=data, expected_code=403)
+ # If there is a config.yaml but existing DB users are present, this method should
+ # fail for security reasons.
+ data = dict(
+ username="cooluser", password="password", email="fake@example.com"
+ )
+ self.postResponse(
+ SuperUserCreateInitialSuperUser, data=data, expected_code=403
+ )
- def test_config_file_with_no_db_users(self):
- with FreshConfigProvider():
- # Write some config.
- self.putJsonResponse(SuperUserConfig, data=dict(config={}, hostname='foobar'))
+ def test_config_file_with_no_db_users(self):
+ with FreshConfigProvider():
+ # Write some config.
+ self.putJsonResponse(
+ SuperUserConfig, data=dict(config={}, hostname="foobar")
+ )
- # Delete all the users in the DB.
- for user in list(User.select()):
- model.user.delete_user(user, all_queues)
+ # Delete all the users in the DB.
+ for user in list(User.select()):
+ model.user.delete_user(user, all_queues)
- # This method should now succeed.
- data = dict(username='cooluser', password='password', email='fake@example.com')
- result = self.postJsonResponse(SuperUserCreateInitialSuperUser, data=data)
- self.assertTrue(result['status'])
+ # This method should now succeed.
+ data = dict(
+ username="cooluser", password="password", email="fake@example.com"
+ )
+ result = self.postJsonResponse(SuperUserCreateInitialSuperUser, data=data)
+ self.assertTrue(result["status"])
- # Verify the superuser was created.
- User.get(User.username == 'cooluser')
+ # Verify the superuser was created.
+ User.get(User.username == "cooluser")
- # Verify the superuser was placed into the config.
- result = self.getJsonResponse(SuperUserConfig)
- self.assertEquals(['cooluser'], result['config']['SUPER_USERS'])
+ # Verify the superuser was placed into the config.
+ result = self.getJsonResponse(SuperUserConfig)
+ self.assertEquals(["cooluser"], result["config"]["SUPER_USERS"])
class TestSuperUserConfigValidate(ApiTestCase):
- def test_nonsuperuser_noconfig(self):
- with FreshConfigProvider():
- result = self.postJsonResponse(SuperUserConfigValidate, params=dict(service='someservice'),
- data=dict(config={}))
+ def test_nonsuperuser_noconfig(self):
+ with FreshConfigProvider():
+ result = self.postJsonResponse(
+ SuperUserConfigValidate,
+ params=dict(service="someservice"),
+ data=dict(config={}),
+ )
- self.assertFalse(result['status'])
+ self.assertFalse(result["status"])
+ def test_nonsuperuser_config(self):
+ with FreshConfigProvider():
+ # The validate config call works if there is no config.yaml OR the user is a superuser.
+ # Add a config, and verify it breaks when unauthenticated.
+ json = self.putJsonResponse(
+ SuperUserConfig, data=dict(config={}, hostname="foobar")
+ )
+ self.assertTrue(json["exists"])
- def test_nonsuperuser_config(self):
- with FreshConfigProvider():
- # The validate config call works if there is no config.yaml OR the user is a superuser.
- # Add a config, and verify it breaks when unauthenticated.
- json = self.putJsonResponse(SuperUserConfig, data=dict(config={}, hostname='foobar'))
- self.assertTrue(json['exists'])
+ result = self.postJsonResponse(
+ SuperUserConfigValidate,
+ params=dict(service="someservice"),
+ data=dict(config={}),
+ )
-
- result = self.postJsonResponse(SuperUserConfigValidate, params=dict(service='someservice'),
- data=dict(config={}))
-
- self.assertFalse(result['status'])
+ self.assertFalse(result["status"])
class TestSuperUserConfig(ApiTestCase):
- def test_get_superuser(self):
- with FreshConfigProvider():
- json = self.getJsonResponse(SuperUserConfig)
+ def test_get_superuser(self):
+ with FreshConfigProvider():
+ json = self.getJsonResponse(SuperUserConfig)
- # Note: We expect the config to be none because a config.yaml should never be checked into
- # the directory.
- self.assertIsNone(json['config'])
+ # Note: We expect the config to be None because a config.yaml should never be checked into
+ # the directory.
+ self.assertIsNone(json["config"])
- def test_put(self):
- with FreshConfigProvider() as config:
- json = self.putJsonResponse(SuperUserConfig, data=dict(config={}, hostname='foobar'))
- self.assertTrue(json['exists'])
+ def test_put(self):
+ with FreshConfigProvider() as config:
+ json = self.putJsonResponse(
+ SuperUserConfig, data=dict(config={}, hostname="foobar")
+ )
+ self.assertTrue(json["exists"])
- # Verify the config file exists.
- self.assertTrue(config.config_exists())
+ # Verify the config file exists.
+ self.assertTrue(config.config_exists())
- # This should succeed.
- json = self.putJsonResponse(SuperUserConfig, data=dict(config={}, hostname='barbaz'))
- self.assertTrue(json['exists'])
+ # This should succeed.
+ json = self.putJsonResponse(
+ SuperUserConfig, data=dict(config={}, hostname="barbaz")
+ )
+ self.assertTrue(json["exists"])
- json = self.getJsonResponse(SuperUserConfig)
- self.assertIsNotNone(json['config'])
+ json = self.getJsonResponse(SuperUserConfig)
+ self.assertIsNotNone(json["config"])
-if __name__ == '__main__':
- unittest.main()
+if __name__ == "__main__":
+ unittest.main()
diff --git a/config_app/config_util/config/TransientDirectoryProvider.py b/config_app/config_util/config/TransientDirectoryProvider.py
index 5ac685592..a8be2710d 100644
--- a/config_app/config_util/config/TransientDirectoryProvider.py
+++ b/config_app/config_util/config/TransientDirectoryProvider.py
@@ -5,58 +5,63 @@ from backports.tempfile import TemporaryDirectory
from config_app.config_util.config.fileprovider import FileConfigProvider
-OLD_CONFIG_SUBDIR = 'old/'
+OLD_CONFIG_SUBDIR = "old/"
+
class TransientDirectoryProvider(FileConfigProvider):
- """ Implementation of the config provider that reads and writes the data
+ """ Implementation of the config provider that reads and writes the data
from/to the file system, only using temporary directories,
deleting old dirs and creating new ones as requested.
"""
- def __init__(self, config_volume, yaml_filename, py_filename):
- # Create a temp directory that will be cleaned up when we change the config path
- # This should ensure we have no "pollution" of different configs:
- # no uploaded config should ever affect subsequent config modifications/creations
- temp_dir = TemporaryDirectory()
- self.temp_dir = temp_dir
- self.old_config_dir = None
- super(TransientDirectoryProvider, self).__init__(temp_dir.name, yaml_filename, py_filename)
+ def __init__(self, config_volume, yaml_filename, py_filename):
+ # Create a temp directory that will be cleaned up when we change the config path
+ # This should ensure we have no "pollution" of different configs:
+ # no uploaded config should ever affect subsequent config modifications/creations
+ temp_dir = TemporaryDirectory()
+ self.temp_dir = temp_dir
+ self.old_config_dir = None
+ super(TransientDirectoryProvider, self).__init__(
+ temp_dir.name, yaml_filename, py_filename
+ )
- @property
- def provider_id(self):
- return 'transient'
+ @property
+ def provider_id(self):
+ return "transient"
- def new_config_dir(self):
- """
+ def new_config_dir(self):
+ """
Update the path with a new temporary directory, deleting the old one in the process
"""
- self.temp_dir.cleanup()
- temp_dir = TemporaryDirectory()
+ self.temp_dir.cleanup()
+ temp_dir = TemporaryDirectory()
- self.config_volume = temp_dir.name
- self.temp_dir = temp_dir
- self.yaml_path = os.path.join(temp_dir.name, self.yaml_filename)
+ self.config_volume = temp_dir.name
+ self.temp_dir = temp_dir
+ self.yaml_path = os.path.join(temp_dir.name, self.yaml_filename)
- def create_copy_of_config_dir(self):
- """
+ def create_copy_of_config_dir(self):
+ """
Create a directory to store loaded/populated configuration (for rollback if necessary)
"""
- if self.old_config_dir is not None:
- self.old_config_dir.cleanup()
+ if self.old_config_dir is not None:
+ self.old_config_dir.cleanup()
- temp_dir = TemporaryDirectory()
- self.old_config_dir = temp_dir
+ temp_dir = TemporaryDirectory()
+ self.old_config_dir = temp_dir
- # Python 2.7's shutil.copy() doesn't allow for copying to existing directories,
- # so when copying/reading to the old saved config, we have to talk to a subdirectory,
- # and use the shutil.copytree() function
- copytree(self.config_volume, os.path.join(temp_dir.name, OLD_CONFIG_SUBDIR))
+ # Python 2.7's shutil.copy() doesn't allow for copying to existing directories,
+ # so when copying/reading to the old saved config, we have to talk to a subdirectory,
+ # and use the shutil.copytree() function
+ copytree(self.config_volume, os.path.join(temp_dir.name, OLD_CONFIG_SUBDIR))
- def get_config_dir_path(self):
- return self.config_volume
+ def get_config_dir_path(self):
+ return self.config_volume
- def get_old_config_dir(self):
- if self.old_config_dir is None:
- raise Exception('Cannot return a configuration that was no old configuration')
+ def get_old_config_dir(self):
+ if self.old_config_dir is None:
+ raise Exception(
+ "Cannot return an old configuration: no old configuration has been saved"
+ )
- return os.path.join(self.old_config_dir.name, OLD_CONFIG_SUBDIR)
+ return os.path.join(self.old_config_dir.name, OLD_CONFIG_SUBDIR)
diff --git a/config_app/config_util/config/__init__.py b/config_app/config_util/config/__init__.py
index d39d0ea1c..7429d17cf 100644
--- a/config_app/config_util/config/__init__.py
+++ b/config_app/config_util/config/__init__.py
@@ -3,37 +3,40 @@ import os
from config_app.config_util.config.fileprovider import FileConfigProvider
from config_app.config_util.config.testprovider import TestConfigProvider
-from config_app.config_util.config.TransientDirectoryProvider import TransientDirectoryProvider
+from config_app.config_util.config.TransientDirectoryProvider import (
+ TransientDirectoryProvider,
+)
from util.config.validator import EXTRA_CA_DIRECTORY, EXTRA_CA_DIRECTORY_PREFIX
def get_config_provider(config_volume, yaml_filename, py_filename, testing=False):
- """ Loads and returns the config provider for the current environment. """
+ """ Loads and returns the config provider for the current environment. """
- if testing:
- return TestConfigProvider()
+ if testing:
+ return TestConfigProvider()
- return TransientDirectoryProvider(config_volume, yaml_filename, py_filename)
+ return TransientDirectoryProvider(config_volume, yaml_filename, py_filename)
def get_config_as_kube_secret(config_path):
- data = {}
+ data = {}
- # Kubernetes secrets don't have sub-directories, so for the extra_ca_certs dir
- # we have to put the extra certs in with a prefix, and then one of our init scripts
- # (02_get_kube_certs.sh) will expand the prefixed certs into the equivalent directory
- # so that they'll be installed correctly on startup by the certs_install script
- certs_dir = os.path.join(config_path, EXTRA_CA_DIRECTORY)
- if os.path.exists(certs_dir):
- for extra_cert in os.listdir(certs_dir):
- with open(os.path.join(certs_dir, extra_cert)) as f:
- data[EXTRA_CA_DIRECTORY_PREFIX + extra_cert] = base64.b64encode(f.read())
+ # Kubernetes secrets don't have sub-directories, so for the extra_ca_certs dir
+ # we have to put the extra certs in with a prefix, and then one of our init scripts
+ # (02_get_kube_certs.sh) will expand the prefixed certs into the equivalent directory
+ # so that they'll be installed correctly on startup by the certs_install script
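+ # (e.g. extra_ca_certs/cert.crt is stored under the secret key "extra_ca_certs_cert.crt")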
+ certs_dir = os.path.join(config_path, EXTRA_CA_DIRECTORY)
+ if os.path.exists(certs_dir):
+ for extra_cert in os.listdir(certs_dir):
+ with open(os.path.join(certs_dir, extra_cert)) as f:
+ data[EXTRA_CA_DIRECTORY_PREFIX + extra_cert] = base64.b64encode(
+ f.read()
+ )
+ for name in os.listdir(config_path):
+ file_path = os.path.join(config_path, name)
+ if not os.path.isdir(file_path):
+ with open(file_path) as f:
+ data[name] = base64.b64encode(f.read())
- for name in os.listdir(config_path):
- file_path = os.path.join(config_path, name)
- if not os.path.isdir(file_path):
- with open(file_path) as f:
- data[name] = base64.b64encode(f.read())
-
- return data
+ return data
diff --git a/config_app/config_util/config/basefileprovider.py b/config_app/config_util/config/basefileprovider.py
index caf231321..ac78000d9 100644
--- a/config_app/config_util/config/basefileprovider.py
+++ b/config_app/config_util/config/basefileprovider.py
@@ -1,72 +1,76 @@
import os
import logging
-from config_app.config_util.config.baseprovider import (BaseProvider, import_yaml, export_yaml,
- CannotWriteConfigException)
+from config_app.config_util.config.baseprovider import (
+ BaseProvider,
+ import_yaml,
+ export_yaml,
+ CannotWriteConfigException,
+)
logger = logging.getLogger(__name__)
class BaseFileProvider(BaseProvider):
- """ Base implementation of the config provider that reads the data from the file system. """
+ """ Base implementation of the config provider that reads the data from the file system. """
- def __init__(self, config_volume, yaml_filename, py_filename):
- self.config_volume = config_volume
- self.yaml_filename = yaml_filename
- self.py_filename = py_filename
+ def __init__(self, config_volume, yaml_filename, py_filename):
+ self.config_volume = config_volume
+ self.yaml_filename = yaml_filename
+ self.py_filename = py_filename
- self.yaml_path = os.path.join(config_volume, yaml_filename)
- self.py_path = os.path.join(config_volume, py_filename)
+ self.yaml_path = os.path.join(config_volume, yaml_filename)
+ self.py_path = os.path.join(config_volume, py_filename)
- def update_app_config(self, app_config):
- if os.path.exists(self.py_path):
- logger.debug('Applying config file: %s', self.py_path)
- app_config.from_pyfile(self.py_path)
+ def update_app_config(self, app_config):
+ if os.path.exists(self.py_path):
+ logger.debug("Applying config file: %s", self.py_path)
+ app_config.from_pyfile(self.py_path)
- if os.path.exists(self.yaml_path):
- logger.debug('Applying config file: %s', self.yaml_path)
- import_yaml(app_config, self.yaml_path)
+ if os.path.exists(self.yaml_path):
+ logger.debug("Applying config file: %s", self.yaml_path)
+ import_yaml(app_config, self.yaml_path)
- def get_config(self):
- if not self.config_exists():
- return None
+ def get_config(self):
+ if not self.config_exists():
+ return None
- config_obj = {}
- import_yaml(config_obj, self.yaml_path)
- return config_obj
+ config_obj = {}
+ import_yaml(config_obj, self.yaml_path)
+ return config_obj
- def config_exists(self):
- return self.volume_file_exists(self.yaml_filename)
+ def config_exists(self):
+ return self.volume_file_exists(self.yaml_filename)
- def volume_exists(self):
- return os.path.exists(self.config_volume)
+ def volume_exists(self):
+ return os.path.exists(self.config_volume)
- def volume_file_exists(self, filename):
- return os.path.exists(os.path.join(self.config_volume, filename))
+ def volume_file_exists(self, filename):
+ return os.path.exists(os.path.join(self.config_volume, filename))
- def get_volume_file(self, filename, mode='r'):
- return open(os.path.join(self.config_volume, filename), mode=mode)
+ def get_volume_file(self, filename, mode="r"):
+ return open(os.path.join(self.config_volume, filename), mode=mode)
- def get_volume_path(self, directory, filename):
- return os.path.join(directory, filename)
+ def get_volume_path(self, directory, filename):
+ return os.path.join(directory, filename)
- def list_volume_directory(self, path):
- dirpath = os.path.join(self.config_volume, path)
- if not os.path.exists(dirpath):
- return None
+ def list_volume_directory(self, path):
+ dirpath = os.path.join(self.config_volume, path)
+ if not os.path.exists(dirpath):
+ return None
- if not os.path.isdir(dirpath):
- return None
+ if not os.path.isdir(dirpath):
+ return None
- return os.listdir(dirpath)
+ return os.listdir(dirpath)
- def requires_restart(self, app_config):
- file_config = self.get_config()
- if not file_config:
- return False
+ def requires_restart(self, app_config):
+ file_config = self.get_config()
+ if not file_config:
+ return False
- for key in file_config:
- if app_config.get(key) != file_config[key]:
- return True
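+ # A restart is required if any value on disk differs from the in-memory app config.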
+ for key in file_config:
+ if app_config.get(key) != file_config[key]:
+ return True
- return False
+ return False
diff --git a/config_app/config_util/config/baseprovider.py b/config_app/config_util/config/baseprovider.py
index 17ae7e86b..e6705809d 100644
--- a/config_app/config_util/config/baseprovider.py
+++ b/config_app/config_util/config/baseprovider.py
@@ -12,117 +12,119 @@ logger = logging.getLogger(__name__)
class CannotWriteConfigException(Exception):
- """ Exception raised when the config cannot be written. """
- pass
+ """ Exception raised when the config cannot be written. """
+
+ pass
class SetupIncompleteException(Exception):
- """ Exception raised when attempting to verify config that has not yet been setup. """
- pass
+ """ Exception raised when attempting to verify config that has not yet been setup. """
+
+ pass
def import_yaml(config_obj, config_file):
- with open(config_file) as f:
- c = yaml.safe_load(f)
- if not c:
- logger.debug('Empty YAML config file')
- return
+ with open(config_file) as f:
+ c = yaml.safe_load(f)
+ if not c:
+ logger.debug("Empty YAML config file")
+ return
- if isinstance(c, str):
- raise Exception('Invalid YAML config file: ' + str(c))
+ if isinstance(c, str):
+ raise Exception("Invalid YAML config file: " + str(c))
- for key in c.iterkeys():
- if key.isupper():
- config_obj[key] = c[key]
+ for key in c.iterkeys():
+ if key.isupper():
+ config_obj[key] = c[key]
- if config_obj.get('SETUP_COMPLETE', False):
- try:
- validate(config_obj, CONFIG_SCHEMA)
- except ValidationError:
- # TODO: Change this into a real error
- logger.exception('Could not validate config schema')
- else:
- logger.debug('Skipping config schema validation because setup is not complete')
+ if config_obj.get("SETUP_COMPLETE", False):
+ try:
+ validate(config_obj, CONFIG_SCHEMA)
+ except ValidationError:
+ # TODO: Change this into a real error
+ logger.exception("Could not validate config schema")
+ else:
+ logger.debug("Skipping config schema validation because setup is not complete")
- return config_obj
+ return config_obj
def get_yaml(config_obj):
- return yaml.safe_dump(config_obj, encoding='utf-8', allow_unicode=True)
+ return yaml.safe_dump(config_obj, encoding="utf-8", allow_unicode=True)
def export_yaml(config_obj, config_file):
- try:
- with open(config_file, 'w') as f:
- f.write(get_yaml(config_obj))
- except IOError as ioe:
- raise CannotWriteConfigException(str(ioe))
+ try:
+ with open(config_file, "w") as f:
+ f.write(get_yaml(config_obj))
+ except IOError as ioe:
+ raise CannotWriteConfigException(str(ioe))
@add_metaclass(ABCMeta)
class BaseProvider(object):
- """ A configuration provider helps to load, save, and handle config override in the application.
+ """ A configuration provider helps to load, save, and handle config override in the application.
"""
- @property
- def provider_id(self):
- raise NotImplementedError
+ @property
+ def provider_id(self):
+ raise NotImplementedError
- @abstractmethod
- def update_app_config(self, app_config):
- """ Updates the given application config object with the loaded override config. """
+ @abstractmethod
+ def update_app_config(self, app_config):
+ """ Updates the given application config object with the loaded override config. """
- @abstractmethod
- def get_config(self):
- """ Returns the contents of the config override file, or None if none. """
+ @abstractmethod
+ def get_config(self):
+ """ Returns the contents of the config override file, or None if none. """
- @abstractmethod
- def save_config(self, config_object):
- """ Updates the contents of the config override file to those given. """
+ @abstractmethod
+ def save_config(self, config_object):
+ """ Updates the contents of the config override file to those given. """
- @abstractmethod
- def config_exists(self):
- """ Returns true if a config override file exists in the config volume. """
+ @abstractmethod
+ def config_exists(self):
+ """ Returns true if a config override file exists in the config volume. """
- @abstractmethod
- def volume_exists(self):
- """ Returns whether the config override volume exists. """
+ @abstractmethod
+ def volume_exists(self):
+ """ Returns whether the config override volume exists. """
- @abstractmethod
- def volume_file_exists(self, filename):
- """ Returns whether the file with the given name exists under the config override volume. """
+ @abstractmethod
+ def volume_file_exists(self, filename):
+ """ Returns whether the file with the given name exists under the config override volume. """
- @abstractmethod
- def get_volume_file(self, filename, mode='r'):
- """ Returns a Python file referring to the given name under the config override volume. """
+ @abstractmethod
+ def get_volume_file(self, filename, mode="r"):
+ """ Returns a Python file referring to the given name under the config override volume. """
- @abstractmethod
- def write_volume_file(self, filename, contents):
- """ Writes the given contents to the config override volumne, with the given filename. """
+ @abstractmethod
+ def write_volume_file(self, filename, contents):
+ """ Writes the given contents to the config override volumne, with the given filename. """
- @abstractmethod
- def remove_volume_file(self, filename):
- """ Removes the config override volume file with the given filename. """
+ @abstractmethod
+ def remove_volume_file(self, filename):
+ """ Removes the config override volume file with the given filename. """
- @abstractmethod
- def list_volume_directory(self, path):
- """ Returns a list of strings representing the names of the files found in the config override
+ @abstractmethod
+ def list_volume_directory(self, path):
+ """ Returns a list of strings representing the names of the files found in the config override
directory under the given path. If the path doesn't exist, returns None.
"""
- @abstractmethod
- def save_volume_file(self, filename, flask_file):
- """ Saves the given flask file to the config override volume, with the given
+ @abstractmethod
+ def save_volume_file(self, filename, flask_file):
+ """ Saves the given flask file to the config override volume, with the given
filename.
"""
- @abstractmethod
- def requires_restart(self, app_config):
- """ If true, the configuration loaded into memory for the app does not match that on disk,
+ @abstractmethod
+ def requires_restart(self, app_config):
+ """ If true, the configuration loaded into memory for the app does not match that on disk,
indicating that this container requires a restart.
"""
- @abstractmethod
- def get_volume_path(self, directory, filename):
- """ Helper for constructing file paths, which may differ between providers. For example,
+ @abstractmethod
+ def get_volume_path(self, directory, filename):
+ """ Helper for constructing file paths, which may differ between providers. For example,
kubernetes can't have subfolders in configmaps """
diff --git a/config_app/config_util/config/fileprovider.py b/config_app/config_util/config/fileprovider.py
index 74531e581..4f9d94ad0 100644
--- a/config_app/config_util/config/fileprovider.py
+++ b/config_app/config_util/config/fileprovider.py
@@ -1,60 +1,65 @@
import os
import logging
-from config_app.config_util.config.baseprovider import export_yaml, CannotWriteConfigException
+from config_app.config_util.config.baseprovider import (
+ export_yaml,
+ CannotWriteConfigException,
+)
from config_app.config_util.config.basefileprovider import BaseFileProvider
logger = logging.getLogger(__name__)
def _ensure_parent_dir(filepath):
- """ Ensures that the parent directory of the given file path exists. """
- try:
- parentpath = os.path.abspath(os.path.join(filepath, os.pardir))
- if not os.path.isdir(parentpath):
- os.makedirs(parentpath)
- except IOError as ioe:
- raise CannotWriteConfigException(str(ioe))
+ """ Ensures that the parent directory of the given file path exists. """
+ try:
+ parentpath = os.path.abspath(os.path.join(filepath, os.pardir))
+ if not os.path.isdir(parentpath):
+ os.makedirs(parentpath)
+ except IOError as ioe:
+ raise CannotWriteConfigException(str(ioe))
class FileConfigProvider(BaseFileProvider):
- """ Implementation of the config provider that reads and writes the data
+ """ Implementation of the config provider that reads and writes the data
from/to the file system. """
- def __init__(self, config_volume, yaml_filename, py_filename):
- super(FileConfigProvider, self).__init__(config_volume, yaml_filename, py_filename)
+ def __init__(self, config_volume, yaml_filename, py_filename):
+ super(FileConfigProvider, self).__init__(
+ config_volume, yaml_filename, py_filename
+ )
- @property
- def provider_id(self):
- return 'file'
+ @property
+ def provider_id(self):
+ return "file"
- def save_config(self, config_obj):
- export_yaml(config_obj, self.yaml_path)
+ def save_config(self, config_obj):
+ export_yaml(config_obj, self.yaml_path)
- def write_volume_file(self, filename, contents):
- filepath = os.path.join(self.config_volume, filename)
- _ensure_parent_dir(filepath)
+ def write_volume_file(self, filename, contents):
+ filepath = os.path.join(self.config_volume, filename)
+ _ensure_parent_dir(filepath)
- try:
- with open(filepath, mode='w') as f:
- f.write(contents)
- except IOError as ioe:
- raise CannotWriteConfigException(str(ioe))
+ try:
+ with open(filepath, mode="w") as f:
+ f.write(contents)
+ except IOError as ioe:
+ raise CannotWriteConfigException(str(ioe))
- return filepath
+ return filepath
- def remove_volume_file(self, filename):
- filepath = os.path.join(self.config_volume, filename)
- os.remove(filepath)
+ def remove_volume_file(self, filename):
+ filepath = os.path.join(self.config_volume, filename)
+ os.remove(filepath)
- def save_volume_file(self, filename, flask_file):
- filepath = os.path.join(self.config_volume, filename)
- _ensure_parent_dir(filepath)
+ def save_volume_file(self, filename, flask_file):
+ filepath = os.path.join(self.config_volume, filename)
+ _ensure_parent_dir(filepath)
- # Write the file.
- try:
- flask_file.save(filepath)
- except IOError as ioe:
- raise CannotWriteConfigException(str(ioe))
+ # Write the file.
+ try:
+ flask_file.save(filepath)
+ except IOError as ioe:
+ raise CannotWriteConfigException(str(ioe))
- return filepath
+ return filepath
diff --git a/config_app/config_util/config/test/test_helpers.py b/config_app/config_util/config/test/test_helpers.py
index ceeae51ff..f266bb65c 100644
--- a/config_app/config_util/config/test/test_helpers.py
+++ b/config_app/config_util/config/test/test_helpers.py
@@ -9,67 +9,73 @@ from util.config.validator import EXTRA_CA_DIRECTORY
def _create_temp_file_structure(file_structure):
- temp_dir = TemporaryDirectory()
+ temp_dir = TemporaryDirectory()
- for filename, data in file_structure.iteritems():
- if filename == EXTRA_CA_DIRECTORY:
- extra_ca_dir_path = os.path.join(temp_dir.name, EXTRA_CA_DIRECTORY)
- os.mkdir(extra_ca_dir_path)
+ for filename, data in file_structure.iteritems():
+ if filename == EXTRA_CA_DIRECTORY:
+ extra_ca_dir_path = os.path.join(temp_dir.name, EXTRA_CA_DIRECTORY)
+ os.mkdir(extra_ca_dir_path)
- for name, cert_value in data:
- with open(os.path.join(extra_ca_dir_path, name), 'w') as f:
- f.write(cert_value)
- else:
- with open(os.path.join(temp_dir.name, filename), 'w') as f:
- f.write(data)
+ for name, cert_value in data:
+ with open(os.path.join(extra_ca_dir_path, name), "w") as f:
+ f.write(cert_value)
+ else:
+ with open(os.path.join(temp_dir.name, filename), "w") as f:
+ f.write(data)
- return temp_dir
+ return temp_dir
-@pytest.mark.parametrize('file_structure, expected_secret', [
- pytest.param({
- 'config.yaml': 'test:true',
- },
- {
- 'config.yaml': 'dGVzdDp0cnVl',
- }, id='just a config value'),
- pytest.param({
- 'config.yaml': 'test:true',
- 'otherfile.ext': 'im a file'
- },
- {
- 'config.yaml': 'dGVzdDp0cnVl',
- 'otherfile.ext': base64.b64encode('im a file')
- }, id='config and another file'),
- pytest.param({
- 'config.yaml': 'test:true',
- 'extra_ca_certs': [
- ('cert.crt', 'im a cert!'),
- ]
- },
- {
- 'config.yaml': 'dGVzdDp0cnVl',
- 'extra_ca_certs_cert.crt': base64.b64encode('im a cert!'),
- }, id='config and an extra cert'),
- pytest.param({
- 'config.yaml': 'test:true',
- 'otherfile.ext': 'im a file',
- 'extra_ca_certs': [
- ('cert.crt', 'im a cert!'),
- ('another.crt', 'im a different cert!'),
- ]
- },
- {
- 'config.yaml': 'dGVzdDp0cnVl',
- 'otherfile.ext': base64.b64encode('im a file'),
- 'extra_ca_certs_cert.crt': base64.b64encode('im a cert!'),
- 'extra_ca_certs_another.crt': base64.b64encode('im a different cert!'),
- }, id='config, files, and extra certs!'),
-])
+@pytest.mark.parametrize(
+ "file_structure, expected_secret",
+ [
+ pytest.param(
+ {"config.yaml": "test:true"},
+ {"config.yaml": "dGVzdDp0cnVl"},
+ id="just a config value",
+ ),
+ pytest.param(
+ {"config.yaml": "test:true", "otherfile.ext": "im a file"},
+ {
+ "config.yaml": "dGVzdDp0cnVl",
+ "otherfile.ext": base64.b64encode("im a file"),
+ },
+ id="config and another file",
+ ),
+ pytest.param(
+ {
+ "config.yaml": "test:true",
+ "extra_ca_certs": [("cert.crt", "im a cert!")],
+ },
+ {
+ "config.yaml": "dGVzdDp0cnVl",
+ "extra_ca_certs_cert.crt": base64.b64encode("im a cert!"),
+ },
+ id="config and an extra cert",
+ ),
+ pytest.param(
+ {
+ "config.yaml": "test:true",
+ "otherfile.ext": "im a file",
+ "extra_ca_certs": [
+ ("cert.crt", "im a cert!"),
+ ("another.crt", "im a different cert!"),
+ ],
+ },
+ {
+ "config.yaml": "dGVzdDp0cnVl",
+ "otherfile.ext": base64.b64encode("im a file"),
+ "extra_ca_certs_cert.crt": base64.b64encode("im a cert!"),
+ "extra_ca_certs_another.crt": base64.b64encode("im a different cert!"),
+ },
+ id="config, files, and extra certs!",
+ ),
+ ],
+)
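+# Note: "dGVzdDp0cnVl" in the expectations above is base64.b64encode("test:true"); the other
+# expected values are computed inline with base64.b64encode().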
def test_get_config_as_kube_secret(file_structure, expected_secret):
- temp_dir = _create_temp_file_structure(file_structure)
+ temp_dir = _create_temp_file_structure(file_structure)
- secret = get_config_as_kube_secret(temp_dir.name)
- assert secret == expected_secret
+ secret = get_config_as_kube_secret(temp_dir.name)
+ assert secret == expected_secret
- temp_dir.cleanup()
+ temp_dir.cleanup()
diff --git a/config_app/config_util/config/test/test_transient_dir_provider.py b/config_app/config_util/config/test/test_transient_dir_provider.py
index 2d1f3f96c..2d53b5153 100644
--- a/config_app/config_util/config/test/test_transient_dir_provider.py
+++ b/config_app/config_util/config/test/test_transient_dir_provider.py
@@ -1,68 +1,71 @@
import pytest
import os
-from config_app.config_util.config.TransientDirectoryProvider import TransientDirectoryProvider
+from config_app.config_util.config.TransientDirectoryProvider import (
+ TransientDirectoryProvider,
+)
-@pytest.mark.parametrize('files_to_write, operations, expected_new_dir', [
- pytest.param({
- 'config.yaml': 'a config',
- }, ([], [], []), {
- 'config.yaml': 'a config',
- }, id='just a config'),
- pytest.param({
- 'config.yaml': 'a config',
- 'oldfile': 'hmmm'
- }, ([], [], ['oldfile']), {
- 'config.yaml': 'a config',
- }, id='delete a file'),
- pytest.param({
- 'config.yaml': 'a config',
- 'oldfile': 'hmmm'
- }, ([('newfile', 'asdf')], [], ['oldfile']), {
- 'config.yaml': 'a config',
- 'newfile': 'asdf'
- }, id='delete and add a file'),
- pytest.param({
- 'config.yaml': 'a config',
- 'somefile': 'before'
- }, ([('newfile', 'asdf')], [('somefile', 'after')], []), {
- 'config.yaml': 'a config',
- 'newfile': 'asdf',
- 'somefile': 'after',
- }, id='add new files and change files'),
-])
+@pytest.mark.parametrize(
+ "files_to_write, operations, expected_new_dir",
+ [
+ pytest.param(
+ {"config.yaml": "a config"},
+ ([], [], []),
+ {"config.yaml": "a config"},
+ id="just a config",
+ ),
+ pytest.param(
+ {"config.yaml": "a config", "oldfile": "hmmm"},
+ ([], [], ["oldfile"]),
+ {"config.yaml": "a config"},
+ id="delete a file",
+ ),
+ pytest.param(
+ {"config.yaml": "a config", "oldfile": "hmmm"},
+ ([("newfile", "asdf")], [], ["oldfile"]),
+ {"config.yaml": "a config", "newfile": "asdf"},
+ id="delete and add a file",
+ ),
+ pytest.param(
+ {"config.yaml": "a config", "somefile": "before"},
+ ([("newfile", "asdf")], [("somefile", "after")], []),
+ {"config.yaml": "a config", "newfile": "asdf", "somefile": "after"},
+ id="add new files and change files",
+ ),
+ ],
+)
def test_transient_dir_copy_config_dir(files_to_write, operations, expected_new_dir):
- config_provider = TransientDirectoryProvider('', '', '')
+ config_provider = TransientDirectoryProvider("", "", "")
- for name, data in files_to_write.iteritems():
- config_provider.write_volume_file(name, data)
+ for name, data in files_to_write.iteritems():
+ config_provider.write_volume_file(name, data)
- config_provider.create_copy_of_config_dir()
+ config_provider.create_copy_of_config_dir()
- for create in operations[0]:
- (name, data) = create
- config_provider.write_volume_file(name, data)
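+ # operations is a 3-tuple: (files to create, files to update, filenames to delete),
+ # where each create/update entry is a (name, data) pair.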
+ for create in operations[0]:
+ (name, data) = create
+ config_provider.write_volume_file(name, data)
- for update in operations[1]:
- (name, data) = update
- config_provider.write_volume_file(name, data)
+ for update in operations[1]:
+ (name, data) = update
+ config_provider.write_volume_file(name, data)
- for delete in operations[2]:
- config_provider.remove_volume_file(delete)
+ for delete in operations[2]:
+ config_provider.remove_volume_file(delete)
- # check that the new directory matches expected state
- for filename, data in expected_new_dir.iteritems():
- with open(os.path.join(config_provider.get_config_dir_path(), filename)) as f:
- new_data = f.read()
- assert new_data == data
+ # check that the new directory matches expected state
+ for filename, data in expected_new_dir.iteritems():
+ with open(os.path.join(config_provider.get_config_dir_path(), filename)) as f:
+ new_data = f.read()
+ assert new_data == data
- # Now check that the old dir matches the original state
- saved = config_provider.get_old_config_dir()
+ # Now check that the old dir matches the original state
+ saved = config_provider.get_old_config_dir()
- for filename, data in files_to_write.iteritems():
- with open(os.path.join(saved, filename)) as f:
- new_data = f.read()
- assert new_data == data
+ for filename, data in files_to_write.iteritems():
+ with open(os.path.join(saved, filename)) as f:
+ new_data = f.read()
+ assert new_data == data
- config_provider.temp_dir.cleanup()
+ config_provider.temp_dir.cleanup()
diff --git a/config_app/config_util/config/testprovider.py b/config_app/config_util/config/testprovider.py
index 63e563056..39070687e 100644
--- a/config_app/config_util/config/testprovider.py
+++ b/config_app/config_util/config/testprovider.py
@@ -4,80 +4,84 @@ import os
from config_app.config_util.config.baseprovider import BaseProvider
-REAL_FILES = ['test/data/signing-private.gpg', 'test/data/signing-public.gpg', 'test/data/test.pem']
+REAL_FILES = [
+ "test/data/signing-private.gpg",
+ "test/data/signing-public.gpg",
+ "test/data/test.pem",
+]
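+# Files listed above are read from the real filesystem; everything else is kept in self.files in memory.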
class TestConfigProvider(BaseProvider):
- """ Implementation of the config provider for testing. Everything is kept in-memory instead on
+ """ Implementation of the config provider for testing. Everything is kept in-memory instead on
the real file system. """
- def __init__(self):
- self.clear()
+ def __init__(self):
+ self.clear()
- def clear(self):
- self.files = {}
- self._config = {}
+ def clear(self):
+ self.files = {}
+ self._config = {}
- @property
- def provider_id(self):
- return 'test'
+ @property
+ def provider_id(self):
+ return "test"
- def update_app_config(self, app_config):
- self._config = app_config
+ def update_app_config(self, app_config):
+ self._config = app_config
- def get_config(self):
- if not 'config.yaml' in self.files:
- return None
+ def get_config(self):
+ if not "config.yaml" in self.files:
+ return None
- return json.loads(self.files.get('config.yaml', '{}'))
+ return json.loads(self.files.get("config.yaml", "{}"))
- def save_config(self, config_obj):
- self.files['config.yaml'] = json.dumps(config_obj)
+ def save_config(self, config_obj):
+ self.files["config.yaml"] = json.dumps(config_obj)
- def config_exists(self):
- return 'config.yaml' in self.files
+ def config_exists(self):
+ return "config.yaml" in self.files
- def volume_exists(self):
- return True
+ def volume_exists(self):
+ return True
- def volume_file_exists(self, filename):
- if filename in REAL_FILES:
- return True
+ def volume_file_exists(self, filename):
+ if filename in REAL_FILES:
+ return True
- return filename in self.files
+ return filename in self.files
- def save_volume_file(self, filename, flask_file):
- self.files[filename] = flask_file.read()
+ def save_volume_file(self, filename, flask_file):
+ self.files[filename] = flask_file.read()
- def write_volume_file(self, filename, contents):
- self.files[filename] = contents
+ def write_volume_file(self, filename, contents):
+ self.files[filename] = contents
- def get_volume_file(self, filename, mode='r'):
- if filename in REAL_FILES:
- return open(filename, mode=mode)
+ def get_volume_file(self, filename, mode="r"):
+ if filename in REAL_FILES:
+ return open(filename, mode=mode)
- return io.BytesIO(self.files[filename])
+ return io.BytesIO(self.files[filename])
- def remove_volume_file(self, filename):
- self.files.pop(filename, None)
+ def remove_volume_file(self, filename):
+ self.files.pop(filename, None)
- def list_volume_directory(self, path):
- paths = []
- for filename in self.files:
- if filename.startswith(path):
- paths.append(filename[len(path) + 1:])
+ def list_volume_directory(self, path):
+ paths = []
+ for filename in self.files:
+ if filename.startswith(path):
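+ # Strip the directory prefix plus the path separator so returned names are relative to path.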
+ paths.append(filename[len(path) + 1 :])
- return paths
+ return paths
- def requires_restart(self, app_config):
- return False
+ def requires_restart(self, app_config):
+ return False
- def reset_for_test(self):
- self._config['SUPER_USERS'] = ['devtable']
- self.files = {}
+ def reset_for_test(self):
+ self._config["SUPER_USERS"] = ["devtable"]
+ self.files = {}
- def get_volume_path(self, directory, filename):
- return os.path.join(directory, filename)
+ def get_volume_path(self, directory, filename):
+ return os.path.join(directory, filename)
- def get_config_dir_path(self):
- return ''
+ def get_config_dir_path(self):
+ return ""
diff --git a/config_app/config_util/k8saccessor.py b/config_app/config_util/k8saccessor.py
index dd115681b..e59f127f2 100644
--- a/config_app/config_util/k8saccessor.py
+++ b/config_app/config_util/k8saccessor.py
@@ -12,295 +12,374 @@ from config_app.config_util.k8sconfig import KubernetesConfig
logger = logging.getLogger(__name__)
-QE_DEPLOYMENT_LABEL = 'quay-enterprise-component'
-QE_CONTAINER_NAME = 'quay-enterprise-app'
+QE_DEPLOYMENT_LABEL = "quay-enterprise-component"
+QE_CONTAINER_NAME = "quay-enterprise-app"
# Tuple containing response of the deployment rollout status method.
# status is one of: 'failed' | 'progressing' | 'available'
# message is any string describing the state.
-DeploymentRolloutStatus = namedtuple('DeploymentRolloutStatus', ['status', 'message'])
+DeploymentRolloutStatus = namedtuple("DeploymentRolloutStatus", ["status", "message"])
+
class K8sApiException(Exception):
- pass
+ pass
def _deployment_rollout_status_message(deployment, deployment_name):
- """
+ """
Gets the friendly human readable message of the current state of the deployment rollout
:param deployment: python dict matching: https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.11/#deployment-v1-apps
:param deployment_name: string
:return: DeploymentRolloutStatus
"""
- # Logic for rollout status pulled from the `kubectl rollout status` command:
- # https://github.com/kubernetes/kubernetes/blob/d9ba19c751709c8608e09a0537eea98973f3a796/pkg/kubectl/rollout_status.go#L62
- if deployment['metadata']['generation'] <= deployment['status']['observedGeneration']:
- for cond in deployment['status']['conditions']:
- if cond['type'] == 'Progressing' and cond['reason'] == 'ProgressDeadlineExceeded':
+ # Logic for rollout status pulled from the `kubectl rollout status` command:
+ # https://github.com/kubernetes/kubernetes/blob/d9ba19c751709c8608e09a0537eea98973f3a796/pkg/kubectl/rollout_status.go#L62
+ if (
+ deployment["metadata"]["generation"]
+ <= deployment["status"]["observedGeneration"]
+ ):
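+ # The controller has observed the latest spec, so inspect the rollout conditions and replica counts.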
+ for cond in deployment["status"]["conditions"]:
+ if (
+ cond["type"] == "Progressing"
+ and cond["reason"] == "ProgressDeadlineExceeded"
+ ):
+ return DeploymentRolloutStatus(
+ status="failed",
+ message="Deployment %s's rollout failed. Please try again later."
+ % deployment_name,
+ )
+
+ desired_replicas = deployment["spec"]["replicas"]
+ current_replicas = deployment["status"].get("replicas", 0)
+ if current_replicas == 0:
+ return DeploymentRolloutStatus(
+ status="available",
+ message="Deployment %s updated (no replicas, so nothing to roll out)"
+ % deployment_name,
+ )
+
+ # Some fields are optional in the spec, so if they're omitted, replace with defaults that won't indicate a wrong status
+ available_replicas = deployment["status"].get("availableReplicas", 0)
+ updated_replicas = deployment["status"].get("updatedReplicas", 0)
+
+ if updated_replicas < desired_replicas:
+ return DeploymentRolloutStatus(
+ status="progressing",
+ message="Waiting for rollout to finish: %d out of %d new replicas have been updated..."
+ % (updated_replicas, desired_replicas),
+ )
+
+ if current_replicas > updated_replicas:
+ return DeploymentRolloutStatus(
+ status="progressing",
+ message="Waiting for rollout to finish: %d old replicas are pending termination..."
+ % (current_replicas - updated_replicas),
+ )
+
+ if available_replicas < updated_replicas:
+ return DeploymentRolloutStatus(
+ status="progressing",
+ message="Waiting for rollout to finish: %d of %d updated replicas are available..."
+ % (available_replicas, updated_replicas),
+ )
+
return DeploymentRolloutStatus(
- status='failed',
- message="Deployment %s's rollout failed. Please try again later." % deployment_name
+ status="available",
+ message="Deployment %s successfully rolled out." % deployment_name,
)
- desired_replicas = deployment['spec']['replicas']
- current_replicas = deployment['status'].get('replicas', 0)
- if current_replicas == 0:
- return DeploymentRolloutStatus(
- status='available',
- message='Deployment %s updated (no replicas, so nothing to roll out)' % deployment_name
- )
-
- # Some fields are optional in the spec, so if they're omitted, replace with defaults that won't indicate a wrong status
- available_replicas = deployment['status'].get('availableReplicas', 0)
- updated_replicas = deployment['status'].get('updatedReplicas', 0)
-
- if updated_replicas < desired_replicas:
- return DeploymentRolloutStatus(
- status='progressing',
- message='Waiting for rollout to finish: %d out of %d new replicas have been updated...' % (
- updated_replicas, desired_replicas)
- )
-
- if current_replicas > updated_replicas:
- return DeploymentRolloutStatus(
- status='progressing',
- message='Waiting for rollout to finish: %d old replicas are pending termination...' % (
- current_replicas - updated_replicas)
- )
-
- if available_replicas < updated_replicas:
- return DeploymentRolloutStatus(
- status='progressing',
- message='Waiting for rollout to finish: %d of %d updated replicas are available...' % (
- available_replicas, updated_replicas)
- )
-
return DeploymentRolloutStatus(
- status='available',
- message='Deployment %s successfully rolled out.' % deployment_name
+ status="progressing", message="Waiting for deployment spec to be updated..."
)
- return DeploymentRolloutStatus(
- status='progressing',
- message='Waiting for deployment spec to be updated...'
- )
-
class KubernetesAccessorSingleton(object):
- """ Singleton allowing access to kubernetes operations """
- _instance = None
+ """ Singleton allowing access to kubernetes operations """
- def __init__(self, kube_config=None):
- self.kube_config = kube_config
- if kube_config is None:
- self.kube_config = KubernetesConfig.from_env()
+ _instance = None
- KubernetesAccessorSingleton._instance = self
+ def __init__(self, kube_config=None):
+ self.kube_config = kube_config
+ if kube_config is None:
+ self.kube_config = KubernetesConfig.from_env()
- @classmethod
- def get_instance(cls, kube_config=None):
- """
+ KubernetesAccessorSingleton._instance = self
+
+ @classmethod
+ def get_instance(cls, kube_config=None):
+ """
Singleton getter implementation, returns the instance if one exists, otherwise creates the
instance and ties it to the class.
:return: KubernetesAccessorSingleton
"""
- if cls._instance is None:
- return cls(kube_config)
+ if cls._instance is None:
+ return cls(kube_config)
- return cls._instance
+ return cls._instance
- def save_secret_to_directory(self, dir_path):
- """
+ def save_secret_to_directory(self, dir_path):
+ """
Saves all files in the kubernetes secret to a local directory.
Assumes the directory is empty.
"""
- secret = self._lookup_secret()
+ secret = self._lookup_secret()
- secret_data = secret.get('data', {})
+ secret_data = secret.get("data", {})
- # Make the `extra_ca_certs` dir to ensure we can populate extra certs
- extra_ca_dir_path = os.path.join(dir_path, EXTRA_CA_DIRECTORY)
- os.mkdir(extra_ca_dir_path)
+ # Make the `extra_ca_certs` dir to ensure we can populate extra certs
+ extra_ca_dir_path = os.path.join(dir_path, EXTRA_CA_DIRECTORY)
+ os.mkdir(extra_ca_dir_path)
- for secret_filename, data in secret_data.iteritems():
- write_path = os.path.join(dir_path, secret_filename)
+ for secret_filename, data in secret_data.iteritems():
+ write_path = os.path.join(dir_path, secret_filename)
- if EXTRA_CA_DIRECTORY_PREFIX in secret_filename:
- write_path = os.path.join(extra_ca_dir_path, secret_filename.replace(EXTRA_CA_DIRECTORY_PREFIX, ''))
+ if EXTRA_CA_DIRECTORY_PREFIX in secret_filename:
+ write_path = os.path.join(
+ extra_ca_dir_path,
+ secret_filename.replace(EXTRA_CA_DIRECTORY_PREFIX, ""),
+ )
- with open(write_path, 'w') as f:
- f.write(base64.b64decode(data))
+ with open(write_path, "w") as f:
+ f.write(base64.b64decode(data))
- return 200
+ return 200
- def save_file_as_secret(self, name, file_pointer):
- value = file_pointer.read()
- self._update_secret_file(name, value)
+ def save_file_as_secret(self, name, file_pointer):
+ value = file_pointer.read()
+ self._update_secret_file(name, value)
- def replace_qe_secret(self, new_secret_data):
- """
+ def replace_qe_secret(self, new_secret_data):
+ """
Removes the old config and replaces it with the new_secret_data as one action
"""
- # Check first that the namespace for Red Hat Quay exists. If it does not, report that
- # as an error, as it seems to be a common issue.
- namespace_url = 'namespaces/%s' % (self.kube_config.qe_namespace)
- response = self._execute_k8s_api('GET', namespace_url)
- if response.status_code // 100 != 2:
- msg = 'A Kubernetes namespace with name `%s` must be created to save config' % self.kube_config.qe_namespace
- raise Exception(msg)
+ # Check first that the namespace for Red Hat Quay exists. If it does not, report that
+ # as an error, as it seems to be a common issue.
+ namespace_url = "namespaces/%s" % (self.kube_config.qe_namespace)
+ response = self._execute_k8s_api("GET", namespace_url)
+ if response.status_code // 100 != 2:
+ msg = (
+ "A Kubernetes namespace with name `%s` must be created to save config"
+ % self.kube_config.qe_namespace
+ )
+ raise Exception(msg)
- # Check if the secret exists. If not, then we create an empty secret and then update the file
- # inside.
- secret_url = 'namespaces/%s/secrets/%s' % (self.kube_config.qe_namespace, self.kube_config.qe_config_secret)
- secret = self._lookup_secret()
- if secret is None:
- self._assert_success(self._execute_k8s_api('POST', secret_url, {
- "kind": "Secret",
- "apiVersion": "v1",
- "metadata": {
- "name": self.kube_config.qe_config_secret
- },
- "data": {}
- }))
+ # Check if the secret exists. If not, then we create an empty secret and then update the file
+ # inside.
+ secret_url = "namespaces/%s/secrets/%s" % (
+ self.kube_config.qe_namespace,
+ self.kube_config.qe_config_secret,
+ )
+ secret = self._lookup_secret()
+ if secret is None:
+ self._assert_success(
+ self._execute_k8s_api(
+ "POST",
+ secret_url,
+ {
+ "kind": "Secret",
+ "apiVersion": "v1",
+ "metadata": {"name": self.kube_config.qe_config_secret},
+ "data": {},
+ },
+ )
+ )
- # Update the secret to reflect the file change.
- secret['data'] = new_secret_data
+ # Update the secret to reflect the file change.
+ secret["data"] = new_secret_data
- self._assert_success(self._execute_k8s_api('PUT', secret_url, secret))
+ self._assert_success(self._execute_k8s_api("PUT", secret_url, secret))
- def get_deployment_rollout_status(self, deployment_name):
- """"
+ def get_deployment_rollout_status(self, deployment_name):
+ """"
Returns the status of a rollout of a given deployment
    :return: DeploymentRolloutStatus
"""
- deployment_selector_url = 'namespaces/%s/deployments/%s' % (
- self.kube_config.qe_namespace, deployment_name
- )
+ deployment_selector_url = "namespaces/%s/deployments/%s" % (
+ self.kube_config.qe_namespace,
+ deployment_name,
+ )
- response = self._execute_k8s_api('GET', deployment_selector_url, api_prefix='apis/apps/v1')
- if response.status_code != 200:
- return DeploymentRolloutStatus('failed', 'Could not get deployment. Please check that the deployment exists')
+ response = self._execute_k8s_api(
+ "GET", deployment_selector_url, api_prefix="apis/apps/v1"
+ )
+ if response.status_code != 200:
+ return DeploymentRolloutStatus(
+ "failed",
+ "Could not get deployment. Please check that the deployment exists",
+ )
- deployment = json.loads(response.text)
+ deployment = json.loads(response.text)
- return _deployment_rollout_status_message(deployment, deployment_name)
+ return _deployment_rollout_status_message(deployment, deployment_name)
- def get_qe_deployments(self):
- """"
+ def get_qe_deployments(self):
+ """"
Returns all deployments matching the label selector provided in the KubeConfig
"""
- deployment_selector_url = 'namespaces/%s/deployments?labelSelector=%s%%3D%s' % (
- self.kube_config.qe_namespace, QE_DEPLOYMENT_LABEL, self.kube_config.qe_deployment_selector
- )
+ deployment_selector_url = "namespaces/%s/deployments?labelSelector=%s%%3D%s" % (
+ self.kube_config.qe_namespace,
+ QE_DEPLOYMENT_LABEL,
+ self.kube_config.qe_deployment_selector,
+ )
- response = self._execute_k8s_api('GET', deployment_selector_url, api_prefix='apis/extensions/v1beta1')
- if response.status_code != 200:
- return None
- return json.loads(response.text)
+ response = self._execute_k8s_api(
+ "GET", deployment_selector_url, api_prefix="apis/extensions/v1beta1"
+ )
+ if response.status_code != 200:
+ return None
+ return json.loads(response.text)
- def cycle_qe_deployments(self, deployment_names):
- """"
+ def cycle_qe_deployments(self, deployment_names):
+ """"
Triggers a rollout of all desired deployments in the qe namespace
"""
- for name in deployment_names:
- logger.debug('Cycling deployment %s', name)
- deployment_url = 'namespaces/%s/deployments/%s' % (self.kube_config.qe_namespace, name)
+ for name in deployment_names:
+ logger.debug("Cycling deployment %s", name)
+ deployment_url = "namespaces/%s/deployments/%s" % (
+ self.kube_config.qe_namespace,
+ name,
+ )
- # There is currently no command to simply rolling restart all the pods: https://github.com/kubernetes/kubernetes/issues/13488
- # Instead, we modify the template of the deployment with a dummy env variable to trigger a cycle of the pods
- # (based off this comment: https://github.com/kubernetes/kubernetes/issues/13488#issuecomment-240393845)
- self._assert_success(self._execute_k8s_api('PATCH', deployment_url, {
- 'spec': {
- 'template': {
- 'spec': {
- 'containers': [{
- # Note: this name MUST match the deployment template's pod template
- # (e.g. .spec.template.spec.containers[0] == 'quay-enterprise-app')
- 'name': QE_CONTAINER_NAME,
- 'env': [{
- 'name': 'RESTART_TIME',
- 'value': str(datetime.datetime.now())
- }],
- }]
- }
- }
- }
- }, api_prefix='apis/extensions/v1beta1', content_type='application/strategic-merge-patch+json'))
+            # There is currently no command to simply perform a rolling restart of all the pods: https://github.com/kubernetes/kubernetes/issues/13488
+ # Instead, we modify the template of the deployment with a dummy env variable to trigger a cycle of the pods
+ # (based off this comment: https://github.com/kubernetes/kubernetes/issues/13488#issuecomment-240393845)
+ self._assert_success(
+ self._execute_k8s_api(
+ "PATCH",
+ deployment_url,
+ {
+ "spec": {
+ "template": {
+ "spec": {
+ "containers": [
+ {
+ # Note: this name MUST match the deployment template's pod template
+ # (e.g. .spec.template.spec.containers[0] == 'quay-enterprise-app')
+ "name": QE_CONTAINER_NAME,
+ "env": [
+ {
+ "name": "RESTART_TIME",
+ "value": str(
+ datetime.datetime.now()
+ ),
+ }
+ ],
+ }
+ ]
+ }
+ }
+ }
+ },
+ api_prefix="apis/extensions/v1beta1",
+ content_type="application/strategic-merge-patch+json",
+ )
+ )
- def rollback_deployment(self, deployment_name):
- deployment_rollback_url = 'namespaces/%s/deployments/%s/rollback' % (
- self.kube_config.qe_namespace, deployment_name
- )
+ def rollback_deployment(self, deployment_name):
+ deployment_rollback_url = "namespaces/%s/deployments/%s/rollback" % (
+ self.kube_config.qe_namespace,
+ deployment_name,
+ )
- self._assert_success(self._execute_k8s_api('POST', deployment_rollback_url, {
- 'name': deployment_name,
- 'rollbackTo': {
- # revision=0 makes the deployment rollout to the previous revision
- 'revision': 0
- }
- }, api_prefix='apis/extensions/v1beta1'), 201)
+ self._assert_success(
+ self._execute_k8s_api(
+ "POST",
+ deployment_rollback_url,
+ {
+ "name": deployment_name,
+ "rollbackTo": {
+ # revision=0 makes the deployment rollout to the previous revision
+ "revision": 0
+ },
+ },
+ api_prefix="apis/extensions/v1beta1",
+ ),
+ 201,
+ )
- def _assert_success(self, response, expected_code=200):
- if response.status_code != expected_code:
- logger.error('Kubernetes API call failed with response: %s => %s', response.status_code,
- response.text)
- raise K8sApiException('Kubernetes API call failed: %s' % response.text)
+ def _assert_success(self, response, expected_code=200):
+ if response.status_code != expected_code:
+ logger.error(
+ "Kubernetes API call failed with response: %s => %s",
+ response.status_code,
+ response.text,
+ )
+ raise K8sApiException("Kubernetes API call failed: %s" % response.text)
- def _update_secret_file(self, relative_file_path, value=None):
- if '/' in relative_file_path:
- raise Exception('Expected path from get_volume_path, but found slashes')
+ def _update_secret_file(self, relative_file_path, value=None):
+ if "/" in relative_file_path:
+ raise Exception("Expected path from get_volume_path, but found slashes")
- # Check first that the namespace for Red Hat Quay exists. If it does not, report that
- # as an error, as it seems to be a common issue.
- namespace_url = 'namespaces/%s' % (self.kube_config.qe_namespace)
- response = self._execute_k8s_api('GET', namespace_url)
- if response.status_code // 100 != 2:
- msg = 'A Kubernetes namespace with name `%s` must be created to save config' % self.kube_config.qe_namespace
- raise Exception(msg)
+ # Check first that the namespace for Red Hat Quay exists. If it does not, report that
+ # as an error, as it seems to be a common issue.
+ namespace_url = "namespaces/%s" % (self.kube_config.qe_namespace)
+ response = self._execute_k8s_api("GET", namespace_url)
+ if response.status_code // 100 != 2:
+ msg = (
+ "A Kubernetes namespace with name `%s` must be created to save config"
+ % self.kube_config.qe_namespace
+ )
+ raise Exception(msg)
- # Check if the secret exists. If not, then we create an empty secret and then update the file
- # inside.
- secret_url = 'namespaces/%s/secrets/%s' % (self.kube_config.qe_namespace, self.kube_config.qe_config_secret)
- secret = self._lookup_secret()
- if secret is None:
- self._assert_success(self._execute_k8s_api('POST', secret_url, {
- "kind": "Secret",
- "apiVersion": "v1",
- "metadata": {
- "name": self.kube_config.qe_config_secret
- },
- "data": {}
- }))
+ # Check if the secret exists. If not, then we create an empty secret and then update the file
+ # inside.
+ secret_url = "namespaces/%s/secrets/%s" % (
+ self.kube_config.qe_namespace,
+ self.kube_config.qe_config_secret,
+ )
+ secret = self._lookup_secret()
+ if secret is None:
+ self._assert_success(
+ self._execute_k8s_api(
+ "POST",
+ secret_url,
+ {
+ "kind": "Secret",
+ "apiVersion": "v1",
+ "metadata": {"name": self.kube_config.qe_config_secret},
+ "data": {},
+ },
+ )
+ )
- # Update the secret to reflect the file change.
- secret['data'] = secret.get('data', {})
+ # Update the secret to reflect the file change.
+ secret["data"] = secret.get("data", {})
- if value is not None:
- secret['data'][relative_file_path] = base64.b64encode(value)
- else:
- secret['data'].pop(relative_file_path)
+ if value is not None:
+ secret["data"][relative_file_path] = base64.b64encode(value)
+ else:
+ secret["data"].pop(relative_file_path)
- self._assert_success(self._execute_k8s_api('PUT', secret_url, secret))
+ self._assert_success(self._execute_k8s_api("PUT", secret_url, secret))
- def _lookup_secret(self):
- secret_url = 'namespaces/%s/secrets/%s' % (self.kube_config.qe_namespace, self.kube_config.qe_config_secret)
- response = self._execute_k8s_api('GET', secret_url)
- if response.status_code != 200:
- return None
- return json.loads(response.text)
+ def _lookup_secret(self):
+ secret_url = "namespaces/%s/secrets/%s" % (
+ self.kube_config.qe_namespace,
+ self.kube_config.qe_config_secret,
+ )
+ response = self._execute_k8s_api("GET", secret_url)
+ if response.status_code != 200:
+ return None
+ return json.loads(response.text)
- def _execute_k8s_api(self, method, relative_url, data=None, api_prefix='api/v1', content_type='application/json'):
- headers = {
- 'Authorization': 'Bearer ' + self.kube_config.service_account_token
- }
+ def _execute_k8s_api(
+ self,
+ method,
+ relative_url,
+ data=None,
+ api_prefix="api/v1",
+ content_type="application/json",
+ ):
+ headers = {"Authorization": "Bearer " + self.kube_config.service_account_token}
- if data:
- headers['Content-Type'] = content_type
+ if data:
+ headers["Content-Type"] = content_type
- data = json.dumps(data) if data else None
- session = Session()
- url = 'https://%s/%s/%s' % (self.kube_config.api_host, api_prefix, relative_url)
+ data = json.dumps(data) if data else None
+ session = Session()
+ url = "https://%s/%s/%s" % (self.kube_config.api_host, api_prefix, relative_url)
- request = Request(method, url, data=data, headers=headers)
- return session.send(request.prepare(), verify=False, timeout=2)
+ request = Request(method, url, data=data, headers=headers)
+ return session.send(request.prepare(), verify=False, timeout=2)
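
For orientation, the accessor above is the config app's only Kubernetes client: it reads and rewrites the quay-enterprise config secret, and it cycles deployments by PATCHing a dummy RESTART_TIME environment variable because Kubernetes offered no rolling-restart command at the time. A minimal usage sketch, assuming the code runs in-cluster with a mounted service-account token; the printed format is illustrative only:

from config_app.config_util.k8saccessor import KubernetesAccessorSingleton
from config_app.config_util.k8sconfig import KubernetesConfig

# Build the config from the pod environment and grab the singleton accessor.
kube_config = KubernetesConfig.from_env()
accessor = KubernetesAccessorSingleton.get_instance(kube_config)

# List the Quay Enterprise deployments matching the configured label selector.
deployments = accessor.get_qe_deployments()
if deployments is not None:
    names = [d["metadata"]["name"] for d in deployments.get("items", [])]

    # Restart each deployment, then poll its rollout status.
    accessor.cycle_qe_deployments(names)
    for name in names:
        status = accessor.get_deployment_rollout_status(name)
        print("%s: %s (%s)" % (name, status.status, status.message))
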
diff --git a/config_app/config_util/k8sconfig.py b/config_app/config_util/k8sconfig.py
index c7e5ac3ed..7e4bd43b1 100644
--- a/config_app/config_util/k8sconfig.py
+++ b/config_app/config_util/k8sconfig.py
@@ -1,46 +1,59 @@
import os
-SERVICE_ACCOUNT_TOKEN_PATH = '/var/run/secrets/kubernetes.io/serviceaccount/token'
+SERVICE_ACCOUNT_TOKEN_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/token"
-DEFAULT_QE_NAMESPACE = 'quay-enterprise'
-DEFAULT_QE_CONFIG_SECRET = 'quay-enterprise-config-secret'
+DEFAULT_QE_NAMESPACE = "quay-enterprise"
+DEFAULT_QE_CONFIG_SECRET = "quay-enterprise-config-secret"
# The name of the quay enterprise deployment (not config app) that is used to query & rollout
-DEFAULT_QE_DEPLOYMENT_SELECTOR = 'app'
+DEFAULT_QE_DEPLOYMENT_SELECTOR = "app"
def get_k8s_namespace():
- return os.environ.get('QE_K8S_NAMESPACE', DEFAULT_QE_NAMESPACE)
+ return os.environ.get("QE_K8S_NAMESPACE", DEFAULT_QE_NAMESPACE)
class KubernetesConfig(object):
- def __init__(self, api_host='', service_account_token=SERVICE_ACCOUNT_TOKEN_PATH,
- qe_namespace=DEFAULT_QE_NAMESPACE,
- qe_config_secret=DEFAULT_QE_CONFIG_SECRET,
- qe_deployment_selector=DEFAULT_QE_DEPLOYMENT_SELECTOR):
- self.api_host = api_host
- self.qe_namespace = qe_namespace
- self.qe_config_secret = qe_config_secret
- self.qe_deployment_selector = qe_deployment_selector
- self.service_account_token = service_account_token
+ def __init__(
+ self,
+ api_host="",
+ service_account_token=SERVICE_ACCOUNT_TOKEN_PATH,
+ qe_namespace=DEFAULT_QE_NAMESPACE,
+ qe_config_secret=DEFAULT_QE_CONFIG_SECRET,
+ qe_deployment_selector=DEFAULT_QE_DEPLOYMENT_SELECTOR,
+ ):
+ self.api_host = api_host
+ self.qe_namespace = qe_namespace
+ self.qe_config_secret = qe_config_secret
+ self.qe_deployment_selector = qe_deployment_selector
+ self.service_account_token = service_account_token
- @classmethod
- def from_env(cls):
- # Load the service account token from the local store.
- if not os.path.exists(SERVICE_ACCOUNT_TOKEN_PATH):
- raise Exception('Cannot load Kubernetes service account token')
+ @classmethod
+ def from_env(cls):
+ # Load the service account token from the local store.
+ if not os.path.exists(SERVICE_ACCOUNT_TOKEN_PATH):
+ raise Exception("Cannot load Kubernetes service account token")
- with open(SERVICE_ACCOUNT_TOKEN_PATH, 'r') as f:
- service_token = f.read()
+ with open(SERVICE_ACCOUNT_TOKEN_PATH, "r") as f:
+ service_token = f.read()
- api_host = os.environ.get('KUBERNETES_SERVICE_HOST', '')
- port = os.environ.get('KUBERNETES_SERVICE_PORT')
- if port:
- api_host += ':' + port
+ api_host = os.environ.get("KUBERNETES_SERVICE_HOST", "")
+ port = os.environ.get("KUBERNETES_SERVICE_PORT")
+ if port:
+ api_host += ":" + port
- qe_namespace = get_k8s_namespace()
- qe_config_secret = os.environ.get('QE_K8S_CONFIG_SECRET', DEFAULT_QE_CONFIG_SECRET)
- qe_deployment_selector = os.environ.get('QE_DEPLOYMENT_SELECTOR', DEFAULT_QE_DEPLOYMENT_SELECTOR)
+ qe_namespace = get_k8s_namespace()
+ qe_config_secret = os.environ.get(
+ "QE_K8S_CONFIG_SECRET", DEFAULT_QE_CONFIG_SECRET
+ )
+ qe_deployment_selector = os.environ.get(
+ "QE_DEPLOYMENT_SELECTOR", DEFAULT_QE_DEPLOYMENT_SELECTOR
+ )
- return cls(api_host=api_host, service_account_token=service_token, qe_namespace=qe_namespace,
- qe_config_secret=qe_config_secret, qe_deployment_selector=qe_deployment_selector)
+ return cls(
+ api_host=api_host,
+ service_account_token=service_token,
+ qe_namespace=qe_namespace,
+ qe_config_secret=qe_config_secret,
+ qe_deployment_selector=qe_deployment_selector,
+ )
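
Outside a pod there is no service-account token file, so from_env() raises; a sketch of constructing the config directly for local testing, where the host and token values are placeholders:

from config_app.config_util.k8sconfig import KubernetesConfig, get_k8s_namespace

# Explicit construction bypasses the token-file check done by from_env().
config = KubernetesConfig(
    api_host="kubernetes.default.svc:443",  # placeholder host
    service_account_token="dummy-token",    # placeholder token contents
    qe_namespace=get_k8s_namespace(),       # honors QE_K8S_NAMESPACE if set
)
assert config.qe_config_secret == "quay-enterprise-config-secret"  # default
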
diff --git a/config_app/config_util/log.py b/config_app/config_util/log.py
index 783c9c2cd..65504debc 100644
--- a/config_app/config_util/log.py
+++ b/config_app/config_util/log.py
@@ -3,7 +3,7 @@ from config_app._init_config import CONF_DIR
def logfile_path(jsonfmt=False, debug=False):
- """
+ """
    Returns a logging conf file path, following these rules:
- conf/logging_debug_json.conf # jsonfmt=true, debug=true
- conf/logging_json.conf # jsonfmt=true, debug=false
@@ -11,20 +11,20 @@ def logfile_path(jsonfmt=False, debug=False):
- conf/logging.conf # jsonfmt=false, debug=false
Can be parametrized via envvars: JSONLOG=true, DEBUGLOG=true
"""
- _json = ""
- _debug = ""
+ _json = ""
+ _debug = ""
- if jsonfmt or os.getenv('JSONLOG', 'false').lower() == 'true':
- _json = "_json"
+ if jsonfmt or os.getenv("JSONLOG", "false").lower() == "true":
+ _json = "_json"
- if debug or os.getenv('DEBUGLOG', 'false').lower() == 'true':
- _debug = "_debug"
+ if debug or os.getenv("DEBUGLOG", "false").lower() == "true":
+ _debug = "_debug"
- return os.path.join(CONF_DIR, "logging%s%s.conf" % (_debug, _json))
+ return os.path.join(CONF_DIR, "logging%s%s.conf" % (_debug, _json))
def filter_logs(values, filtered_fields):
- """
+ """
Takes a dict and a list of keys to filter.
eg:
with filtered_fields:
@@ -34,14 +34,14 @@ def filter_logs(values, filtered_fields):
the returned dict is:
  {'k1': {'k2': 'filtered'}, 'k3': 'some-value'}
"""
- for field in filtered_fields:
- cdict = values
+ for field in filtered_fields:
+ cdict = values
- for key in field['key'][:-1]:
- if key in cdict:
- cdict = cdict[key]
+ for key in field["key"][:-1]:
+ if key in cdict:
+ cdict = cdict[key]
- last_key = field['key'][-1]
+ last_key = field["key"][-1]
- if last_key in cdict and cdict[last_key]:
- cdict[last_key] = field['fn'](cdict[last_key])
+ if last_key in cdict and cdict[last_key]:
+ cdict[last_key] = field["fn"](cdict[last_key])
diff --git a/config_app/config_util/ssl.py b/config_app/config_util/ssl.py
index e246bc937..98a412588 100644
--- a/config_app/config_util/ssl.py
+++ b/config_app/config_util/ssl.py
@@ -4,82 +4,86 @@ import OpenSSL
class CertInvalidException(Exception):
- """ Exception raised when a certificate could not be parsed/loaded. """
- pass
+ """ Exception raised when a certificate could not be parsed/loaded. """
+
+ pass
class KeyInvalidException(Exception):
- """ Exception raised when a key could not be parsed/loaded or successfully applied to a cert. """
- pass
+ """ Exception raised when a key could not be parsed/loaded or successfully applied to a cert. """
+
+ pass
def load_certificate(cert_contents):
- """ Loads the certificate from the given contents and returns it or raises a CertInvalidException
+ """ Loads the certificate from the given contents and returns it or raises a CertInvalidException
on failure.
"""
- try:
- cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_contents)
- return SSLCertificate(cert)
- except OpenSSL.crypto.Error as ex:
- raise CertInvalidException(ex.message[0][2])
+ try:
+ cert = OpenSSL.crypto.load_certificate(
+ OpenSSL.crypto.FILETYPE_PEM, cert_contents
+ )
+ return SSLCertificate(cert)
+ except OpenSSL.crypto.Error as ex:
+ raise CertInvalidException(ex.message[0][2])
-_SUBJECT_ALT_NAME = 'subjectAltName'
+_SUBJECT_ALT_NAME = "subjectAltName"
class SSLCertificate(object):
- """ Helper class for easier working with SSL certificates. """
+ """ Helper class for easier working with SSL certificates. """
- def __init__(self, openssl_cert):
- self.openssl_cert = openssl_cert
+ def __init__(self, openssl_cert):
+ self.openssl_cert = openssl_cert
- def validate_private_key(self, private_key_path):
- """ Validates that the private key found at the given file path applies to this certificate.
+ def validate_private_key(self, private_key_path):
+ """ Validates that the private key found at the given file path applies to this certificate.
Raises a KeyInvalidException on failure.
"""
- context = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
- context.use_certificate(self.openssl_cert)
+ context = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
+ context.use_certificate(self.openssl_cert)
- try:
- context.use_privatekey_file(private_key_path)
- context.check_privatekey()
- except OpenSSL.SSL.Error as ex:
- raise KeyInvalidException(ex.message[0][2])
+ try:
+ context.use_privatekey_file(private_key_path)
+ context.check_privatekey()
+ except OpenSSL.SSL.Error as ex:
+ raise KeyInvalidException(ex.message[0][2])
- def matches_name(self, check_name):
- """ Returns true if this SSL certificate matches the given DNS hostname. """
- for dns_name in self.names:
- if fnmatch(check_name, dns_name):
- return True
+ def matches_name(self, check_name):
+ """ Returns true if this SSL certificate matches the given DNS hostname. """
+ for dns_name in self.names:
+ if fnmatch(check_name, dns_name):
+ return True
- return False
+ return False
- @property
- def expired(self):
- """ Returns whether the SSL certificate has expired. """
- return self.openssl_cert.has_expired()
+ @property
+ def expired(self):
+ """ Returns whether the SSL certificate has expired. """
+ return self.openssl_cert.has_expired()
- @property
- def common_name(self):
- """ Returns the defined common name for the certificate, if any. """
- return self.openssl_cert.get_subject().commonName
+ @property
+ def common_name(self):
+ """ Returns the defined common name for the certificate, if any. """
+ return self.openssl_cert.get_subject().commonName
- @property
- def names(self):
- """ Returns all the DNS named to which the certificate applies. May be empty. """
- dns_names = set()
- common_name = self.common_name
- if common_name is not None:
- dns_names.add(common_name)
+ @property
+ def names(self):
+ """ Returns all the DNS named to which the certificate applies. May be empty. """
+ dns_names = set()
+ common_name = self.common_name
+ if common_name is not None:
+ dns_names.add(common_name)
- # Find the DNS extension, if any.
- for i in range(0, self.openssl_cert.get_extension_count()):
- ext = self.openssl_cert.get_extension(i)
- if ext.get_short_name() == _SUBJECT_ALT_NAME:
- value = str(ext)
- for san_name in value.split(','):
- san_name_trimmed = san_name.strip()
- if san_name_trimmed.startswith('DNS:'):
- dns_names.add(san_name_trimmed[4:])
+ # Find the DNS extension, if any.
+ for i in range(0, self.openssl_cert.get_extension_count()):
+ ext = self.openssl_cert.get_extension(i)
+ if ext.get_short_name() == _SUBJECT_ALT_NAME:
+ value = str(ext)
+ for san_name in value.split(","):
+ san_name_trimmed = san_name.strip()
+ if san_name_trimmed.startswith("DNS:"):
+ dns_names.add(san_name_trimmed[4:])
- return dns_names
+ return dns_names
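
A brief sketch of the helpers above on a PEM certificate read from disk; the path is a placeholder and error handling is reduced to a re-raise:

from config_app.config_util.ssl import CertInvalidException, load_certificate

with open("/path/to/server.cert") as f:  # placeholder path
    pem_contents = f.read()

try:
    cert = load_certificate(pem_contents)
except CertInvalidException as cie:
    raise Exception("Could not parse certificate: %s" % cie)

print(cert.common_name)  # CN from the subject, if any
print(cert.names)        # CN plus any DNS SANs
print(cert.expired)      # True once notAfter has passed
# matches_name() treats each certificate name as an fnmatch pattern, so a
# *.example.com SAN matches registry.example.com.
print(cert.matches_name("registry.example.com"))
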
diff --git a/config_app/config_util/tar.py b/config_app/config_util/tar.py
index c1dd2e608..bdff6143e 100644
--- a/config_app/config_util/tar.py
+++ b/config_app/config_util/tar.py
@@ -2,21 +2,21 @@ from util.config.validator import EXTRA_CA_DIRECTORY
def strip_absolute_path_and_add_trailing_dir(path):
- """
+ """
    Removes the leading / from the prefix path and adds a trailing /
"""
- return path[1:] + '/'
+ return path[1:] + "/"
def tarinfo_filter_partial(prefix):
- def tarinfo_filter(tarinfo):
- # remove leading directory info
- tarinfo.name = tarinfo.name.replace(prefix, '')
+ def tarinfo_filter(tarinfo):
+ # remove leading directory info
+ tarinfo.name = tarinfo.name.replace(prefix, "")
- # ignore any directory that isn't the specified extra ca one:
- if tarinfo.isdir() and not tarinfo.name == EXTRA_CA_DIRECTORY:
- return None
+ # ignore any directory that isn't the specified extra ca one:
+ if tarinfo.isdir() and not tarinfo.name == EXTRA_CA_DIRECTORY:
+ return None
- return tarinfo
+ return tarinfo
- return tarinfo_filter
+ return tarinfo_filter
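
These helpers are intended for the filter= hook of TarFile.add() when exporting a config bundle: member names are stripped of the prefix, and any directory other than extra_ca_certs is dropped. A sketch under that assumption, with illustrative paths:

import tarfile

from config_app.config_util.tar import (
    strip_absolute_path_and_add_trailing_dir,
    tarinfo_filter_partial,
)

config_dir = "/conf/stack"  # illustrative source directory
prefix = strip_absolute_path_and_add_trailing_dir(config_dir)  # "conf/stack/"

with tarfile.open("quay-config.tar.gz", "w:gz") as tar:
    # Each member name is rewritten relative to the prefix; directories other
    # than the extra_ca_certs directory are filtered out (filter returns None).
    tar.add(config_dir, filter=tarinfo_filter_partial(prefix))
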
diff --git a/config_app/config_util/test/test_k8saccessor.py b/config_app/config_util/test/test_k8saccessor.py
index 9cf817064..ee99245d6 100644
--- a/config_app/config_util/test/test_k8saccessor.py
+++ b/config_app/config_util/test/test_k8saccessor.py
@@ -2,115 +2,218 @@ import pytest
from httmock import urlmatch, HTTMock, response
-from config_app.config_util.k8saccessor import KubernetesAccessorSingleton, _deployment_rollout_status_message
+from config_app.config_util.k8saccessor import (
+ KubernetesAccessorSingleton,
+ _deployment_rollout_status_message,
+)
from config_app.config_util.k8sconfig import KubernetesConfig
-@pytest.mark.parametrize('deployment_object, expected_status, expected_message', [
- ({'metadata': {'generation': 1},
- 'status': {'observedGeneration': 0, 'conditions': []},
- 'spec': {'replicas': 0}},
- 'progressing',
- 'Waiting for deployment spec to be updated...'),
- ({'metadata': {'generation': 0},
- 'status': {'observedGeneration': 0, 'conditions': [{'type': 'Progressing', 'reason': 'ProgressDeadlineExceeded'}]},
- 'spec': {'replicas': 0}},
- 'failed',
- "Deployment my-deployment's rollout failed. Please try again later."),
- ({'metadata': {'generation': 0},
- 'status': {'observedGeneration': 0, 'conditions': []},
- 'spec': {'replicas': 0}},
- 'available',
- 'Deployment my-deployment updated (no replicas, so nothing to roll out)'),
- ({'metadata': {'generation': 0},
- 'status': {'observedGeneration': 0, 'conditions': [], 'replicas': 1},
- 'spec': {'replicas': 2}},
- 'progressing',
- 'Waiting for rollout to finish: 0 out of 2 new replicas have been updated...'),
- ({'metadata': {'generation': 0},
- 'status': {'observedGeneration': 0, 'conditions': [], 'replicas': 1, 'updatedReplicas': 1},
- 'spec': {'replicas': 2}},
- 'progressing',
- 'Waiting for rollout to finish: 1 out of 2 new replicas have been updated...'),
- ({'metadata': {'generation': 0},
- 'status': {'observedGeneration': 0, 'conditions': [], 'replicas': 2, 'updatedReplicas': 1},
- 'spec': {'replicas': 1}},
- 'progressing',
- 'Waiting for rollout to finish: 1 old replicas are pending termination...'),
- ({'metadata': {'generation': 0},
- 'status': {'observedGeneration': 0, 'conditions': [], 'replicas': 1, 'updatedReplicas': 2, 'availableReplicas': 0},
- 'spec': {'replicas': 0}},
- 'progressing',
- 'Waiting for rollout to finish: 0 of 2 updated replicas are available...'),
- ({'metadata': {'generation': 0},
- 'status': {'observedGeneration': 0, 'conditions': [], 'replicas': 1, 'updatedReplicas': 2, 'availableReplicas': 2},
- 'spec': {'replicas': 0}},
- 'available',
- 'Deployment my-deployment successfully rolled out.'),
-])
-def test_deployment_rollout_status_message(deployment_object, expected_status, expected_message):
- deployment_status = _deployment_rollout_status_message(deployment_object, 'my-deployment')
- assert deployment_status.status == expected_status
- assert deployment_status.message == expected_message
+@pytest.mark.parametrize(
+ "deployment_object, expected_status, expected_message",
+ [
+ (
+ {
+ "metadata": {"generation": 1},
+ "status": {"observedGeneration": 0, "conditions": []},
+ "spec": {"replicas": 0},
+ },
+ "progressing",
+ "Waiting for deployment spec to be updated...",
+ ),
+ (
+ {
+ "metadata": {"generation": 0},
+ "status": {
+ "observedGeneration": 0,
+ "conditions": [
+ {"type": "Progressing", "reason": "ProgressDeadlineExceeded"}
+ ],
+ },
+ "spec": {"replicas": 0},
+ },
+ "failed",
+ "Deployment my-deployment's rollout failed. Please try again later.",
+ ),
+ (
+ {
+ "metadata": {"generation": 0},
+ "status": {"observedGeneration": 0, "conditions": []},
+ "spec": {"replicas": 0},
+ },
+ "available",
+ "Deployment my-deployment updated (no replicas, so nothing to roll out)",
+ ),
+ (
+ {
+ "metadata": {"generation": 0},
+ "status": {"observedGeneration": 0, "conditions": [], "replicas": 1},
+ "spec": {"replicas": 2},
+ },
+ "progressing",
+ "Waiting for rollout to finish: 0 out of 2 new replicas have been updated...",
+ ),
+ (
+ {
+ "metadata": {"generation": 0},
+ "status": {
+ "observedGeneration": 0,
+ "conditions": [],
+ "replicas": 1,
+ "updatedReplicas": 1,
+ },
+ "spec": {"replicas": 2},
+ },
+ "progressing",
+ "Waiting for rollout to finish: 1 out of 2 new replicas have been updated...",
+ ),
+ (
+ {
+ "metadata": {"generation": 0},
+ "status": {
+ "observedGeneration": 0,
+ "conditions": [],
+ "replicas": 2,
+ "updatedReplicas": 1,
+ },
+ "spec": {"replicas": 1},
+ },
+ "progressing",
+ "Waiting for rollout to finish: 1 old replicas are pending termination...",
+ ),
+ (
+ {
+ "metadata": {"generation": 0},
+ "status": {
+ "observedGeneration": 0,
+ "conditions": [],
+ "replicas": 1,
+ "updatedReplicas": 2,
+ "availableReplicas": 0,
+ },
+ "spec": {"replicas": 0},
+ },
+ "progressing",
+ "Waiting for rollout to finish: 0 of 2 updated replicas are available...",
+ ),
+ (
+ {
+ "metadata": {"generation": 0},
+ "status": {
+ "observedGeneration": 0,
+ "conditions": [],
+ "replicas": 1,
+ "updatedReplicas": 2,
+ "availableReplicas": 2,
+ },
+ "spec": {"replicas": 0},
+ },
+ "available",
+ "Deployment my-deployment successfully rolled out.",
+ ),
+ ],
+)
+def test_deployment_rollout_status_message(
+ deployment_object, expected_status, expected_message
+):
+ deployment_status = _deployment_rollout_status_message(
+ deployment_object, "my-deployment"
+ )
+ assert deployment_status.status == expected_status
+ assert deployment_status.message == expected_message
-@pytest.mark.parametrize('kube_config, expected_api, expected_query', [
- ({'api_host': 'www.customhost.com'},
- '/apis/extensions/v1beta1/namespaces/quay-enterprise/deployments', 'labelSelector=quay-enterprise-component%3Dapp'),
-
- ({'api_host': 'www.customhost.com', 'qe_deployment_selector': 'custom-selector'},
- '/apis/extensions/v1beta1/namespaces/quay-enterprise/deployments',
- 'labelSelector=quay-enterprise-component%3Dcustom-selector'),
-
- ({'api_host': 'www.customhost.com', 'qe_namespace': 'custom-namespace'},
- '/apis/extensions/v1beta1/namespaces/custom-namespace/deployments', 'labelSelector=quay-enterprise-component%3Dapp'),
-
- ({'api_host': 'www.customhost.com', 'qe_namespace': 'custom-namespace', 'qe_deployment_selector': 'custom-selector'},
- '/apis/extensions/v1beta1/namespaces/custom-namespace/deployments',
- 'labelSelector=quay-enterprise-component%3Dcustom-selector'),
-])
+@pytest.mark.parametrize(
+ "kube_config, expected_api, expected_query",
+ [
+ (
+ {"api_host": "www.customhost.com"},
+ "/apis/extensions/v1beta1/namespaces/quay-enterprise/deployments",
+ "labelSelector=quay-enterprise-component%3Dapp",
+ ),
+ (
+ {
+ "api_host": "www.customhost.com",
+ "qe_deployment_selector": "custom-selector",
+ },
+ "/apis/extensions/v1beta1/namespaces/quay-enterprise/deployments",
+ "labelSelector=quay-enterprise-component%3Dcustom-selector",
+ ),
+ (
+ {"api_host": "www.customhost.com", "qe_namespace": "custom-namespace"},
+ "/apis/extensions/v1beta1/namespaces/custom-namespace/deployments",
+ "labelSelector=quay-enterprise-component%3Dapp",
+ ),
+ (
+ {
+ "api_host": "www.customhost.com",
+ "qe_namespace": "custom-namespace",
+ "qe_deployment_selector": "custom-selector",
+ },
+ "/apis/extensions/v1beta1/namespaces/custom-namespace/deployments",
+ "labelSelector=quay-enterprise-component%3Dcustom-selector",
+ ),
+ ],
+)
def test_get_qe_deployments(kube_config, expected_api, expected_query):
- config = KubernetesConfig(**kube_config)
- url_hit = [False]
+ config = KubernetesConfig(**kube_config)
+ url_hit = [False]
- @urlmatch(netloc=r'www.customhost.com')
- def handler(request, _):
- assert request.path == expected_api
- assert request.query == expected_query
- url_hit[0] = True
- return response(200, '{}')
+ @urlmatch(netloc=r"www.customhost.com")
+ def handler(request, _):
+ assert request.path == expected_api
+ assert request.query == expected_query
+ url_hit[0] = True
+ return response(200, "{}")
- with HTTMock(handler):
- KubernetesAccessorSingleton._instance = None
- assert KubernetesAccessorSingleton.get_instance(config).get_qe_deployments() is not None
+ with HTTMock(handler):
+ KubernetesAccessorSingleton._instance = None
+ assert (
+ KubernetesAccessorSingleton.get_instance(config).get_qe_deployments()
+ is not None
+ )
- assert url_hit[0]
+ assert url_hit[0]
-@pytest.mark.parametrize('kube_config, deployment_names, expected_api_hits', [
- ({'api_host': 'www.customhost.com'}, [], []),
- ({'api_host': 'www.customhost.com'}, ['myDeployment'],
- ['/apis/extensions/v1beta1/namespaces/quay-enterprise/deployments/myDeployment']),
- ({'api_host': 'www.customhost.com', 'qe_namespace': 'custom-namespace'},
- ['myDeployment', 'otherDeployment'],
- ['/apis/extensions/v1beta1/namespaces/custom-namespace/deployments/myDeployment',
- '/apis/extensions/v1beta1/namespaces/custom-namespace/deployments/otherDeployment']),
-])
+@pytest.mark.parametrize(
+ "kube_config, deployment_names, expected_api_hits",
+ [
+ ({"api_host": "www.customhost.com"}, [], []),
+ (
+ {"api_host": "www.customhost.com"},
+ ["myDeployment"],
+ [
+ "/apis/extensions/v1beta1/namespaces/quay-enterprise/deployments/myDeployment"
+ ],
+ ),
+ (
+ {"api_host": "www.customhost.com", "qe_namespace": "custom-namespace"},
+ ["myDeployment", "otherDeployment"],
+ [
+ "/apis/extensions/v1beta1/namespaces/custom-namespace/deployments/myDeployment",
+ "/apis/extensions/v1beta1/namespaces/custom-namespace/deployments/otherDeployment",
+ ],
+ ),
+ ],
+)
def test_cycle_qe_deployments(kube_config, deployment_names, expected_api_hits):
- KubernetesAccessorSingleton._instance = None
+ KubernetesAccessorSingleton._instance = None
- config = KubernetesConfig(**kube_config)
- url_hit = [False] * len(expected_api_hits)
- i = [0]
+ config = KubernetesConfig(**kube_config)
+ url_hit = [False] * len(expected_api_hits)
+ i = [0]
- @urlmatch(netloc=r'www.customhost.com', method='PATCH')
- def handler(request, _):
- assert request.path == expected_api_hits[i[0]]
- url_hit[i[0]] = True
- i[0] += 1
- return response(200, '{}')
+ @urlmatch(netloc=r"www.customhost.com", method="PATCH")
+ def handler(request, _):
+ assert request.path == expected_api_hits[i[0]]
+ url_hit[i[0]] = True
+ i[0] += 1
+ return response(200, "{}")
- with HTTMock(handler):
- KubernetesAccessorSingleton.get_instance(config).cycle_qe_deployments(deployment_names)
+ with HTTMock(handler):
+ KubernetesAccessorSingleton.get_instance(config).cycle_qe_deployments(
+ deployment_names
+ )
- assert all(url_hit)
+ assert all(url_hit)
diff --git a/config_app/config_util/test/test_tar.py b/config_app/config_util/test/test_tar.py
index b5d2a5621..235a36929 100644
--- a/config_app/config_util/test/test_tar.py
+++ b/config_app/config_util/test/test_tar.py
@@ -8,25 +8,39 @@ from test.fixtures import *
class MockTarInfo:
- def __init__(self, name, isdir):
- self.name = name
- self.isdir = lambda: isdir
+ def __init__(self, name, isdir):
+ self.name = name
+ self.isdir = lambda: isdir
- def __eq__(self, other):
- return other is not None and self.name == other.name
+ def __eq__(self, other):
+ return other is not None and self.name == other.name
-@pytest.mark.parametrize('prefix,tarinfo,expected', [
- # It should handle simple files
- ('Users/sam/', MockTarInfo('Users/sam/config.yaml', False), MockTarInfo('config.yaml', False)),
- # It should allow the extra CA dir
- ('Users/sam/', MockTarInfo('Users/sam/%s' % EXTRA_CA_DIRECTORY, True), MockTarInfo('%s' % EXTRA_CA_DIRECTORY, True)),
- # it should allow a file in that extra dir
- ('Users/sam/', MockTarInfo('Users/sam/%s/cert.crt' % EXTRA_CA_DIRECTORY, False),
- MockTarInfo('%s/cert.crt' % EXTRA_CA_DIRECTORY, False)),
- # it should not allow a directory that isn't the CA dir
- ('Users/sam/', MockTarInfo('Users/sam/dirignore', True), None),
-])
+@pytest.mark.parametrize(
+ "prefix,tarinfo,expected",
+ [
+ # It should handle simple files
+ (
+ "Users/sam/",
+ MockTarInfo("Users/sam/config.yaml", False),
+ MockTarInfo("config.yaml", False),
+ ),
+ # It should allow the extra CA dir
+ (
+ "Users/sam/",
+ MockTarInfo("Users/sam/%s" % EXTRA_CA_DIRECTORY, True),
+ MockTarInfo("%s" % EXTRA_CA_DIRECTORY, True),
+ ),
+ # it should allow a file in that extra dir
+ (
+ "Users/sam/",
+ MockTarInfo("Users/sam/%s/cert.crt" % EXTRA_CA_DIRECTORY, False),
+ MockTarInfo("%s/cert.crt" % EXTRA_CA_DIRECTORY, False),
+ ),
+ # it should not allow a directory that isn't the CA dir
+ ("Users/sam/", MockTarInfo("Users/sam/dirignore", True), None),
+ ],
+)
def test_tarinfo_filter(prefix, tarinfo, expected):
- partial = tarinfo_filter_partial(prefix)
- assert partial(tarinfo) == expected
+ partial = tarinfo_filter_partial(prefix)
+ assert partial(tarinfo) == expected
diff --git a/config_app/config_web.py b/config_app/config_web.py
index bb283c3cf..7e1b21529 100644
--- a/config_app/config_web.py
+++ b/config_app/config_web.py
@@ -3,4 +3,4 @@ from config_app.config_endpoints.api import api_bp
from config_app.config_endpoints.setup_web import setup_web
application.register_blueprint(setup_web)
-application.register_blueprint(api_bp, url_prefix='/api')
+application.register_blueprint(api_bp, url_prefix="/api")
diff --git a/data/appr_model/__init__.py b/data/appr_model/__init__.py
index 7c9620864..d7cabc012 100644
--- a/data/appr_model/__init__.py
+++ b/data/appr_model/__init__.py
@@ -1,9 +1,9 @@
from data.appr_model import (
- blob,
- channel,
- manifest,
- manifest_list,
- package,
- release,
- tag,
+ blob,
+ channel,
+ manifest,
+ manifest_list,
+ package,
+ release,
+ tag,
)
diff --git a/data/appr_model/blob.py b/data/appr_model/blob.py
index d340a7491..9ee118994 100644
--- a/data/appr_model/blob.py
+++ b/data/appr_model/blob.py
@@ -6,71 +6,79 @@ from data.model import db_transaction
logger = logging.getLogger(__name__)
+
def _ensure_sha256_header(digest):
- if digest.startswith('sha256:'):
- return digest
- return 'sha256:' + digest
+ if digest.startswith("sha256:"):
+ return digest
+ return "sha256:" + digest
def get_blob(digest, models_ref):
- """ Find a blob by its digest. """
- Blob = models_ref.Blob
- return Blob.select().where(Blob.digest == _ensure_sha256_header(digest)).get()
+ """ Find a blob by its digest. """
+ Blob = models_ref.Blob
+ return Blob.select().where(Blob.digest == _ensure_sha256_header(digest)).get()
def get_or_create_blob(digest, size, media_type_name, locations, models_ref):
- """ Try to find a blob by its digest or create it. """
- Blob = models_ref.Blob
- BlobPlacement = models_ref.BlobPlacement
+ """ Try to find a blob by its digest or create it. """
+ Blob = models_ref.Blob
+ BlobPlacement = models_ref.BlobPlacement
- # Get or create the blog entry for the digest.
- try:
- blob = get_blob(digest, models_ref)
- logger.debug('Retrieved blob with digest %s', digest)
- except Blob.DoesNotExist:
- blob = Blob.create(digest=_ensure_sha256_header(digest),
- media_type_id=Blob.media_type.get_id(media_type_name),
- size=size)
- logger.debug('Created blob with digest %s', digest)
-
- # Add the locations to the blob.
- for location_name in locations:
- location_id = BlobPlacement.location.get_id(location_name)
+    # Get or create the blob entry for the digest.
try:
- BlobPlacement.create(blob=blob, location=location_id)
- except IntegrityError:
- logger.debug('Location %s already existing for blob %s', location_name, blob.id)
+ blob = get_blob(digest, models_ref)
+ logger.debug("Retrieved blob with digest %s", digest)
+ except Blob.DoesNotExist:
+ blob = Blob.create(
+ digest=_ensure_sha256_header(digest),
+ media_type_id=Blob.media_type.get_id(media_type_name),
+ size=size,
+ )
+ logger.debug("Created blob with digest %s", digest)
- return blob
+ # Add the locations to the blob.
+ for location_name in locations:
+ location_id = BlobPlacement.location.get_id(location_name)
+ try:
+ BlobPlacement.create(blob=blob, location=location_id)
+ except IntegrityError:
+ logger.debug(
+ "Location %s already existing for blob %s", location_name, blob.id
+ )
+
+ return blob
def get_blob_locations(digest, models_ref):
- """ Find all locations names for a blob. """
- Blob = models_ref.Blob
- BlobPlacement = models_ref.BlobPlacement
- BlobPlacementLocation = models_ref.BlobPlacementLocation
+ """ Find all locations names for a blob. """
+ Blob = models_ref.Blob
+ BlobPlacement = models_ref.BlobPlacement
+ BlobPlacementLocation = models_ref.BlobPlacementLocation
- return [x.name for x in
- BlobPlacementLocation
- .select()
- .join(BlobPlacement)
- .join(Blob)
- .where(Blob.digest == _ensure_sha256_header(digest))]
+ return [
+ x.name
+ for x in BlobPlacementLocation.select()
+ .join(BlobPlacement)
+ .join(Blob)
+ .where(Blob.digest == _ensure_sha256_header(digest))
+ ]
def ensure_blob_locations(models_ref, *names):
- BlobPlacementLocation = models_ref.BlobPlacementLocation
+ BlobPlacementLocation = models_ref.BlobPlacementLocation
- with db_transaction():
- locations = BlobPlacementLocation.select().where(BlobPlacementLocation.name << names)
+ with db_transaction():
+ locations = BlobPlacementLocation.select().where(
+ BlobPlacementLocation.name << names
+ )
- insert_names = list(names)
+ insert_names = list(names)
- for location in locations:
- insert_names.remove(location.name)
+ for location in locations:
+ insert_names.remove(location.name)
- if not insert_names:
- return
+ if not insert_names:
+ return
- data = [{'name': name} for name in insert_names]
- BlobPlacementLocation.insert_many(data).execute()
+ data = [{"name": name} for name in insert_names]
+ BlobPlacementLocation.insert_many(data).execute()
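
The blob helpers accept digests with or without the sha256: prefix and normalize them before querying; a standalone illustration of that convention, with the database-backed call shown only as a hedged comment (the location name and models_ref value are assumptions):

import hashlib

def ensure_sha256_header(digest):
    # Mirrors _ensure_sha256_header above.
    return digest if digest.startswith("sha256:") else "sha256:" + digest

raw = hashlib.sha256(b"some blob bytes").hexdigest()
assert ensure_sha256_header(raw) == ensure_sha256_header("sha256:" + raw)

# With an initialized database, the lookup-or-create call would then be roughly:
#   blob = get_or_create_blob(raw, size=1024, media_type_name="application/gzip",
#                             locations=["local_us"], models_ref=NEW_MODELS)
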
diff --git a/data/appr_model/channel.py b/data/appr_model/channel.py
index 3631d97a5..297cf144e 100644
--- a/data/appr_model/channel.py
+++ b/data/appr_model/channel.py
@@ -2,63 +2,68 @@ from data.appr_model import tag as tag_model
def get_channel_releases(repo, channel, models_ref):
- """ Return all previously linked tags.
+ """ Return all previously linked tags.
This works based upon Tag lifetimes.
"""
- Channel = models_ref.Channel
- Tag = models_ref.Tag
+ Channel = models_ref.Channel
+ Tag = models_ref.Tag
- tag_kind_id = Channel.tag_kind.get_id('channel')
- channel_name = channel.name
- return (Tag
- .select(Tag, Channel)
- .join(Channel, on=(Tag.id == Channel.linked_tag))
- .where(Channel.repository == repo,
- Channel.name == channel_name,
- Channel.tag_kind == tag_kind_id, Channel.lifetime_end != None)
- .order_by(Tag.lifetime_end))
+ tag_kind_id = Channel.tag_kind.get_id("channel")
+ channel_name = channel.name
+ return (
+ Tag.select(Tag, Channel)
+ .join(Channel, on=(Tag.id == Channel.linked_tag))
+ .where(
+ Channel.repository == repo,
+ Channel.name == channel_name,
+ Channel.tag_kind == tag_kind_id,
+ Channel.lifetime_end != None,
+ )
+ .order_by(Tag.lifetime_end)
+ )
def get_channel(repo, channel_name, models_ref):
- """ Find a Channel by name. """
- channel = tag_model.get_tag(repo, channel_name, models_ref, "channel")
- return channel
+ """ Find a Channel by name. """
+ channel = tag_model.get_tag(repo, channel_name, models_ref, "channel")
+ return channel
def get_tag_channels(repo, tag_name, models_ref, active=True):
- """ Find the Channels associated with a Tag. """
- Tag = models_ref.Tag
+ """ Find the Channels associated with a Tag. """
+ Tag = models_ref.Tag
- tag = tag_model.get_tag(repo, tag_name, models_ref, "release")
- query = tag.tag_parents
+ tag = tag_model.get_tag(repo, tag_name, models_ref, "release")
+ query = tag.tag_parents
- if active:
- query = tag_model.tag_is_alive(query, Tag)
+ if active:
+ query = tag_model.tag_is_alive(query, Tag)
- return query
+ return query
def delete_channel(repo, channel_name, models_ref):
- """ Delete a channel by name. """
- return tag_model.delete_tag(repo, channel_name, models_ref, "channel")
+ """ Delete a channel by name. """
+ return tag_model.delete_tag(repo, channel_name, models_ref, "channel")
def create_or_update_channel(repo, channel_name, tag_name, models_ref):
- """ Creates or updates a channel to include a particular tag. """
- tag = tag_model.get_tag(repo, tag_name, models_ref, 'release')
- return tag_model.create_or_update_tag(repo, channel_name, models_ref, linked_tag=tag,
- tag_kind="channel")
+ """ Creates or updates a channel to include a particular tag. """
+ tag = tag_model.get_tag(repo, tag_name, models_ref, "release")
+ return tag_model.create_or_update_tag(
+ repo, channel_name, models_ref, linked_tag=tag, tag_kind="channel"
+ )
def get_repo_channels(repo, models_ref):
- """ Creates or updates a channel to include a particular tag. """
- Channel = models_ref.Channel
- Tag = models_ref.Tag
+ """ Creates or updates a channel to include a particular tag. """
+ Channel = models_ref.Channel
+ Tag = models_ref.Tag
- tag_kind_id = Channel.tag_kind.get_id('channel')
- query = (Channel
- .select(Channel, Tag)
- .join(Tag, on=(Tag.id == Channel.linked_tag))
- .where(Channel.repository == repo,
- Channel.tag_kind == tag_kind_id))
- return tag_model.tag_is_alive(query, Channel)
+ tag_kind_id = Channel.tag_kind.get_id("channel")
+ query = (
+ Channel.select(Channel, Tag)
+ .join(Tag, on=(Tag.id == Channel.linked_tag))
+ .where(Channel.repository == repo, Channel.tag_kind == tag_kind_id)
+ )
+ return tag_model.tag_is_alive(query, Channel)
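
A sketch of how these functions compose when promoting a release to a channel; promote_release is a hypothetical helper, repo is assumed to be an application Repository row looked up elsewhere, and a live database connection is required:

from data.appr_model import channel as channel_model
from data.appr_model.models import NEW_MODELS

def promote_release(repo, channel_name, release_name):
    # Point the channel at the release tag, then return its prior releases.
    channel_model.create_or_update_channel(repo, channel_name, release_name, NEW_MODELS)
    channel = channel_model.get_channel(repo, channel_name, NEW_MODELS)
    return list(channel_model.get_channel_releases(repo, channel, NEW_MODELS))
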
diff --git a/data/appr_model/manifest.py b/data/appr_model/manifest.py
index f08be8d9b..36a75d848 100644
--- a/data/appr_model/manifest.py
+++ b/data/appr_model/manifest.py
@@ -12,56 +12,65 @@ logger = logging.getLogger(__name__)
def _ensure_sha256_header(digest):
- if digest.startswith('sha256:'):
- return digest
- return 'sha256:' + digest
+ if digest.startswith("sha256:"):
+ return digest
+ return "sha256:" + digest
def _digest(manifestjson):
- return _ensure_sha256_header(hashlib.sha256(json.dumps(manifestjson, sort_keys=True)).hexdigest())
+ return _ensure_sha256_header(
+ hashlib.sha256(json.dumps(manifestjson, sort_keys=True)).hexdigest()
+ )
def get_manifest_query(digest, media_type, models_ref):
- Manifest = models_ref.Manifest
- return Manifest.select().where(Manifest.digest == _ensure_sha256_header(digest),
- Manifest.media_type == Manifest.media_type.get_id(media_type))
+ Manifest = models_ref.Manifest
+ return Manifest.select().where(
+ Manifest.digest == _ensure_sha256_header(digest),
+ Manifest.media_type == Manifest.media_type.get_id(media_type),
+ )
def get_manifest_with_blob(digest, media_type, models_ref):
- Blob = models_ref.Blob
- query = get_manifest_query(digest, media_type, models_ref)
- return query.join(Blob).get()
+ Blob = models_ref.Blob
+ query = get_manifest_query(digest, media_type, models_ref)
+ return query.join(Blob).get()
def get_or_create_manifest(manifest_json, media_type_name, models_ref):
- Manifest = models_ref.Manifest
- digest = _digest(manifest_json)
- try:
- manifest = get_manifest_query(digest, media_type_name, models_ref).get()
- except Manifest.DoesNotExist:
- with db_transaction():
- manifest = Manifest.create(digest=digest,
- manifest_json=manifest_json,
- media_type=Manifest.media_type.get_id(media_type_name))
- return manifest
+ Manifest = models_ref.Manifest
+ digest = _digest(manifest_json)
+ try:
+ manifest = get_manifest_query(digest, media_type_name, models_ref).get()
+ except Manifest.DoesNotExist:
+ with db_transaction():
+ manifest = Manifest.create(
+ digest=digest,
+ manifest_json=manifest_json,
+ media_type=Manifest.media_type.get_id(media_type_name),
+ )
+ return manifest
+
def get_manifest_types(repo, models_ref, release=None):
- """ Returns an array of MediaTypes.name for a repo, can filter by tag """
- Tag = models_ref.Tag
- ManifestListManifest = models_ref.ManifestListManifest
+ """ Returns an array of MediaTypes.name for a repo, can filter by tag """
+ Tag = models_ref.Tag
+ ManifestListManifest = models_ref.ManifestListManifest
- query = tag_model.tag_is_alive(Tag
- .select(MediaType.name)
- .join(ManifestListManifest,
- on=(ManifestListManifest.manifest_list == Tag.manifest_list))
- .join(MediaType,
- on=(ManifestListManifest.media_type == MediaType.id))
- .where(Tag.repository == repo,
- Tag.tag_kind == Tag.tag_kind.get_id('release')), Tag)
- if release:
- query = query.where(Tag.name == release)
+ query = tag_model.tag_is_alive(
+ Tag.select(MediaType.name)
+ .join(
+ ManifestListManifest,
+ on=(ManifestListManifest.manifest_list == Tag.manifest_list),
+ )
+ .join(MediaType, on=(ManifestListManifest.media_type == MediaType.id))
+ .where(Tag.repository == repo, Tag.tag_kind == Tag.tag_kind.get_id("release")),
+ Tag,
+ )
+ if release:
+ query = query.where(Tag.name == release)
- manifests = set()
- for m in query.distinct().tuples():
- manifests.add(get_media_type(m[0]))
- return manifests
+ manifests = set()
+ for m in query.distinct().tuples():
+ manifests.add(get_media_type(m[0]))
+ return manifests
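
The manifest digest above is taken over JSON serialized with sort_keys=True, so two logically identical manifests hash to the same value regardless of key order; a small standalone check of that property (the encode() call is only there to keep the snippet portable):

import hashlib
import json

def digest(manifest_json):
    # Mirrors _digest above: canonical key-sorted JSON, then sha256 with prefix.
    canonical = json.dumps(manifest_json, sort_keys=True).encode("utf-8")
    return "sha256:" + hashlib.sha256(canonical).hexdigest()

a = {"mediaType": "application/vnd.cnr.manifest.v0+json", "size": 3}
b = {"size": 3, "mediaType": "application/vnd.cnr.manifest.v0+json"}
assert digest(a) == digest(b)  # key order does not affect the digest
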
diff --git a/data/appr_model/manifest_list.py b/data/appr_model/manifest_list.py
index 92b10be6e..87010ccc3 100644
--- a/data/appr_model/manifest_list.py
+++ b/data/appr_model/manifest_list.py
@@ -9,59 +9,80 @@ logger = logging.getLogger(__name__)
def _ensure_sha256_header(digest):
- if digest.startswith('sha256:'):
- return digest
- return 'sha256:' + digest
+ if digest.startswith("sha256:"):
+ return digest
+ return "sha256:" + digest
def _digest(manifestjson):
- return _ensure_sha256_header(hashlib.sha256(json.dumps(manifestjson, sort_keys=True)).hexdigest())
+ return _ensure_sha256_header(
+ hashlib.sha256(json.dumps(manifestjson, sort_keys=True)).hexdigest()
+ )
def get_manifest_list(digest, models_ref):
- ManifestList = models_ref.ManifestList
- return ManifestList.select().where(ManifestList.digest == _ensure_sha256_header(digest)).get()
+ ManifestList = models_ref.ManifestList
+ return (
+ ManifestList.select()
+ .where(ManifestList.digest == _ensure_sha256_header(digest))
+ .get()
+ )
-def get_or_create_manifest_list(manifest_list_json, media_type_name, schema_version, models_ref):
- ManifestList = models_ref.ManifestList
+def get_or_create_manifest_list(
+ manifest_list_json, media_type_name, schema_version, models_ref
+):
+ ManifestList = models_ref.ManifestList
- digest = _digest(manifest_list_json)
- media_type_id = ManifestList.media_type.get_id(media_type_name)
+ digest = _digest(manifest_list_json)
+ media_type_id = ManifestList.media_type.get_id(media_type_name)
- try:
- return get_manifest_list(digest, models_ref)
- except ManifestList.DoesNotExist:
- with db_transaction():
- manifestlist = ManifestList.create(digest=digest, manifest_list_json=manifest_list_json,
- schema_version=schema_version, media_type=media_type_id)
- return manifestlist
+ try:
+ return get_manifest_list(digest, models_ref)
+ except ManifestList.DoesNotExist:
+ with db_transaction():
+ manifestlist = ManifestList.create(
+ digest=digest,
+ manifest_list_json=manifest_list_json,
+ schema_version=schema_version,
+ media_type=media_type_id,
+ )
+ return manifestlist
-def create_manifestlistmanifest(manifestlist, manifest_ids, manifest_list_json, models_ref):
- """ From a manifestlist, manifests, and the manifest list blob,
+def create_manifestlistmanifest(
+ manifestlist, manifest_ids, manifest_list_json, models_ref
+):
+ """ From a manifestlist, manifests, and the manifest list blob,
    create the manifestlistmanifest for each manifest if it does not exist """
- for pos in xrange(len(manifest_ids)):
- manifest_id = manifest_ids[pos]
- manifest_json = manifest_list_json[pos]
- get_or_create_manifestlistmanifest(manifest=manifest_id,
- manifestlist=manifestlist,
- media_type_name=manifest_json['mediaType'],
- models_ref=models_ref)
+ for pos in xrange(len(manifest_ids)):
+ manifest_id = manifest_ids[pos]
+ manifest_json = manifest_list_json[pos]
+ get_or_create_manifestlistmanifest(
+ manifest=manifest_id,
+ manifestlist=manifestlist,
+ media_type_name=manifest_json["mediaType"],
+ models_ref=models_ref,
+ )
-def get_or_create_manifestlistmanifest(manifest, manifestlist, media_type_name, models_ref):
- ManifestListManifest = models_ref.ManifestListManifest
+def get_or_create_manifestlistmanifest(
+ manifest, manifestlist, media_type_name, models_ref
+):
+ ManifestListManifest = models_ref.ManifestListManifest
- media_type_id = ManifestListManifest.media_type.get_id(media_type_name)
- try:
- ml = (ManifestListManifest
- .select()
- .where(ManifestListManifest.manifest == manifest,
- ManifestListManifest.media_type == media_type_id,
- ManifestListManifest.manifest_list == manifestlist)).get()
+ media_type_id = ManifestListManifest.media_type.get_id(media_type_name)
+ try:
+ ml = (
+ ManifestListManifest.select().where(
+ ManifestListManifest.manifest == manifest,
+ ManifestListManifest.media_type == media_type_id,
+ ManifestListManifest.manifest_list == manifestlist,
+ )
+ ).get()
- except ManifestListManifest.DoesNotExist:
- ml = ManifestListManifest.create(manifest_list=manifestlist, media_type=media_type_id,
- manifest=manifest)
- return ml
+ except ManifestListManifest.DoesNotExist:
+ ml = ManifestListManifest.create(
+ manifest_list=manifestlist, media_type=media_type_id, manifest=manifest
+ )
+ return ml
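
For reference, the deduplication in get_or_create_manifest_list above hinges on _digest: json.dumps(..., sort_keys=True) gives a canonical serialization, so two manifest lists with identical content hash to the same sha256 digest regardless of dict key order. A minimal standalone sketch of that property (plain Python, no database models; the explicit .encode() is needed on Python 3, whereas the Python 2 code above hashes the str directly, and the field names in the sample dicts are illustrative only):

import hashlib
import json

def canonical_digest(manifest_json):
    # sort_keys=True makes the serialization independent of dict key order.
    serialized = json.dumps(manifest_json, sort_keys=True)
    return "sha256:" + hashlib.sha256(serialized.encode("utf-8")).hexdigest()

# Same content, different insertion order.
a = {"schemaVersion": "v0", "manifests": [{"mediaType": "application/vnd.cnr.manifest.v0.json"}]}
b = {"manifests": [{"mediaType": "application/vnd.cnr.manifest.v0.json"}], "schemaVersion": "v0"}

# Identical digests mean a second create attempt is answered by the
# existing ManifestList row instead of inserting a duplicate.
assert canonical_digest(a) == canonical_digest(b)
print(canonical_digest(a))
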
diff --git a/data/appr_model/models.py b/data/appr_model/models.py
index 0fde7d83c..aefdc9f94 100644
--- a/data/appr_model/models.py
+++ b/data/appr_model/models.py
@@ -1,15 +1,47 @@
from collections import namedtuple
-from data.database import (ApprTag, ApprTagKind, ApprBlobPlacementLocation, ApprManifestList,
- ApprManifestBlob, ApprBlob, ApprManifestListManifest, ApprManifest,
- ApprBlobPlacement, ApprChannel)
+from data.database import (
+ ApprTag,
+ ApprTagKind,
+ ApprBlobPlacementLocation,
+ ApprManifestList,
+ ApprManifestBlob,
+ ApprBlob,
+ ApprManifestListManifest,
+ ApprManifest,
+ ApprBlobPlacement,
+ ApprChannel,
+)
-ModelsRef = namedtuple('ModelsRef', ['Tag', 'TagKind', 'BlobPlacementLocation', 'ManifestList',
- 'ManifestBlob', 'Blob', 'ManifestListManifest', 'Manifest',
- 'BlobPlacement', 'Channel', 'manifestlistmanifest_set_name',
- 'tag_set_prefetch_name'])
+ModelsRef = namedtuple(
+ "ModelsRef",
+ [
+ "Tag",
+ "TagKind",
+ "BlobPlacementLocation",
+ "ManifestList",
+ "ManifestBlob",
+ "Blob",
+ "ManifestListManifest",
+ "Manifest",
+ "BlobPlacement",
+ "Channel",
+ "manifestlistmanifest_set_name",
+ "tag_set_prefetch_name",
+ ],
+)
-NEW_MODELS = ModelsRef(ApprTag, ApprTagKind, ApprBlobPlacementLocation, ApprManifestList,
- ApprManifestBlob, ApprBlob, ApprManifestListManifest, ApprManifest,
- ApprBlobPlacement, ApprChannel, 'apprmanifestlistmanifest_set',
- 'apprtag_set')
+NEW_MODELS = ModelsRef(
+ ApprTag,
+ ApprTagKind,
+ ApprBlobPlacementLocation,
+ ApprManifestList,
+ ApprManifestBlob,
+ ApprBlob,
+ ApprManifestListManifest,
+ ApprManifest,
+ ApprBlobPlacement,
+ ApprChannel,
+ "apprmanifestlistmanifest_set",
+ "apprtag_set",
+)
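
The ModelsRef namedtuple above is the indirection that lets every function in data/appr_model accept a models_ref argument instead of importing table classes directly: pointing the same tag/manifest code at a different set of models is just a matter of passing a different tuple. A simplified, hypothetical analog of that pattern (the stub classes below stand in for the real peewee models and only two of the twelve fields are kept):

from collections import namedtuple

DemoModelsRef = namedtuple("DemoModelsRef", ["Tag", "Blob"])

class ApprTagStub(object):
    table_name = "apprtag"

class ApprBlobStub(object):
    table_name = "apprblob"

DEMO_MODELS = DemoModelsRef(ApprTagStub, ApprBlobStub)

def describe_tag_table(models_ref):
    # Functions stay generic: they look classes up on the ref they are handed.
    return models_ref.Tag.table_name

print(describe_tag_table(DEMO_MODELS))  # -> "apprtag"
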
diff --git a/data/appr_model/package.py b/data/appr_model/package.py
index 97ea9f791..abb15a6ea 100644
--- a/data/appr_model/package.py
+++ b/data/appr_model/package.py
@@ -7,61 +7,70 @@ from data.database import Repository, Namespace
from data.appr_model import tag as tag_model
-def list_packages_query(models_ref, namespace=None, media_type=None, search_query=None,
- username=None):
- """ List and filter repository by search query. """
- Tag = models_ref.Tag
+def list_packages_query(
+ models_ref, namespace=None, media_type=None, search_query=None, username=None
+):
+ """ List and filter repository by search query. """
+ Tag = models_ref.Tag
- if username and not search_query:
- repositories = model.repository.get_visible_repositories(username,
- kind_filter='application',
- include_public=True,
- namespace=namespace,
- limit=50)
- if not repositories:
- return []
+ if username and not search_query:
+ repositories = model.repository.get_visible_repositories(
+ username,
+ kind_filter="application",
+ include_public=True,
+ namespace=namespace,
+ limit=50,
+ )
+ if not repositories:
+ return []
- repo_query = (Repository
- .select(Repository, Namespace.username)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .where(Repository.id << [repo.rid for repo in repositories]))
+ repo_query = (
+ Repository.select(Repository, Namespace.username)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .where(Repository.id << [repo.rid for repo in repositories])
+ )
- if namespace:
- repo_query = (repo_query
- .where(Namespace.username == namespace))
- else:
- if search_query is not None:
- fields = [model.repository.SEARCH_FIELDS.name.name]
- repositories = model.repository.get_app_search(search_query,
- username=username,
- search_fields=fields,
- limit=50)
- if not repositories:
- return []
-
- repo_query = (Repository
- .select(Repository, Namespace.username)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .where(Repository.id << [repo.id for repo in repositories]))
+ if namespace:
+ repo_query = repo_query.where(Namespace.username == namespace)
else:
- repo_query = (Repository
- .select(Repository, Namespace.username)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .where(Repository.visibility == model.repository.get_public_repo_visibility(),
- Repository.kind == Repository.kind.get_id('application')))
+ if search_query is not None:
+ fields = [model.repository.SEARCH_FIELDS.name.name]
+ repositories = model.repository.get_app_search(
+ search_query, username=username, search_fields=fields, limit=50
+ )
+ if not repositories:
+ return []
- if namespace:
- repo_query = (repo_query
- .where(Namespace.username == namespace))
+ repo_query = (
+ Repository.select(Repository, Namespace.username)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .where(Repository.id << [repo.id for repo in repositories])
+ )
+ else:
+ repo_query = (
+ Repository.select(Repository, Namespace.username)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .where(
+ Repository.visibility
+ == model.repository.get_public_repo_visibility(),
+ Repository.kind == Repository.kind.get_id("application"),
+ )
+ )
- tag_query = (Tag
- .select()
- .where(Tag.tag_kind == Tag.tag_kind.get_id('release'))
- .order_by(Tag.lifetime_start))
+ if namespace:
+ repo_query = repo_query.where(Namespace.username == namespace)
- if media_type:
- tag_query = tag_model.filter_tags_by_media_type(tag_query, media_type, models_ref)
+ tag_query = (
+ Tag.select()
+ .where(Tag.tag_kind == Tag.tag_kind.get_id("release"))
+ .order_by(Tag.lifetime_start)
+ )
- tag_query = tag_model.tag_is_alive(tag_query, Tag)
- query = prefetch(repo_query, tag_query)
- return query
+ if media_type:
+ tag_query = tag_model.filter_tags_by_media_type(
+ tag_query, media_type, models_ref
+ )
+
+ tag_query = tag_model.tag_is_alive(tag_query, Tag)
+ query = prefetch(repo_query, tag_query)
+ return query
diff --git a/data/appr_model/release.py b/data/appr_model/release.py
index dcfa455d0..b5640a5cd 100644
--- a/data/appr_model/release.py
+++ b/data/appr_model/release.py
@@ -4,149 +4,186 @@ from cnr.exception import PackageAlreadyExists
from cnr.models.package_base import manifest_media_type
from data.database import db_transaction, get_epoch_timestamp
-from data.appr_model import (blob as blob_model, manifest as manifest_model,
- manifest_list as manifest_list_model,
- tag as tag_model)
+from data.appr_model import (
+ blob as blob_model,
+ manifest as manifest_model,
+ manifest_list as manifest_list_model,
+ tag as tag_model,
+)
-LIST_MEDIA_TYPE = 'application/vnd.cnr.manifest.list.v0.json'
-SCHEMA_VERSION = 'v0'
+LIST_MEDIA_TYPE = "application/vnd.cnr.manifest.list.v0.json"
+SCHEMA_VERSION = "v0"
def _ensure_sha256_header(digest):
- if digest.startswith('sha256:'):
- return digest
- return 'sha256:' + digest
+ if digest.startswith("sha256:"):
+ return digest
+ return "sha256:" + digest
def get_app_release(repo, tag_name, media_type, models_ref):
- """ Returns (tag, manifest, blob) given a repo object, tag_name, and media_type). """
- ManifestListManifest = models_ref.ManifestListManifest
- Manifest = models_ref.Manifest
- Blob = models_ref.Blob
- ManifestBlob = models_ref.ManifestBlob
- manifestlistmanifest_set_name = models_ref.manifestlistmanifest_set_name
+ """ Returns (tag, manifest, blob) given a repo object, tag_name, and media_type). """
+ ManifestListManifest = models_ref.ManifestListManifest
+ Manifest = models_ref.Manifest
+ Blob = models_ref.Blob
+ ManifestBlob = models_ref.ManifestBlob
+ manifestlistmanifest_set_name = models_ref.manifestlistmanifest_set_name
- tag = tag_model.get_tag(repo, tag_name, models_ref, tag_kind='release')
- media_type_id = ManifestListManifest.media_type.get_id(manifest_media_type(media_type))
- manifestlistmanifest = (getattr(tag.manifest_list, manifestlistmanifest_set_name)
- .join(Manifest)
- .where(ManifestListManifest.media_type == media_type_id).get())
- manifest = manifestlistmanifest.manifest
- blob = Blob.select().join(ManifestBlob).where(ManifestBlob.manifest == manifest).get()
- return (tag, manifest, blob)
+ tag = tag_model.get_tag(repo, tag_name, models_ref, tag_kind="release")
+ media_type_id = ManifestListManifest.media_type.get_id(
+ manifest_media_type(media_type)
+ )
+ manifestlistmanifest = (
+ getattr(tag.manifest_list, manifestlistmanifest_set_name)
+ .join(Manifest)
+ .where(ManifestListManifest.media_type == media_type_id)
+ .get()
+ )
+ manifest = manifestlistmanifest.manifest
+ blob = (
+ Blob.select().join(ManifestBlob).where(ManifestBlob.manifest == manifest).get()
+ )
+ return (tag, manifest, blob)
def delete_app_release(repo, tag_name, media_type, models_ref):
- """ Terminate a Tag/media-type couple
+ """ Terminate a Tag/media-type couple
It find the corresponding tag/manifest and remove from the manifestlistmanifest the manifest
1. it terminates the current tag (in all-cases)
2. if the new manifestlist is not empty, it creates a new tag for it
"""
- ManifestListManifest = models_ref.ManifestListManifest
- manifestlistmanifest_set_name = models_ref.manifestlistmanifest_set_name
+ ManifestListManifest = models_ref.ManifestListManifest
+ manifestlistmanifest_set_name = models_ref.manifestlistmanifest_set_name
- media_type_id = ManifestListManifest.media_type.get_id(manifest_media_type(media_type))
+ media_type_id = ManifestListManifest.media_type.get_id(
+ manifest_media_type(media_type)
+ )
- with db_transaction():
- tag = tag_model.get_tag(repo, tag_name, models_ref)
- manifest_list = tag.manifest_list
- list_json = manifest_list.manifest_list_json
- mlm_query = (ManifestListManifest
- .select()
- .where(ManifestListManifest.manifest_list == tag.manifest_list))
- list_manifest_ids = sorted([mlm.manifest_id for mlm in mlm_query])
- manifestlistmanifest = (getattr(tag.manifest_list, manifestlistmanifest_set_name)
- .where(ManifestListManifest.media_type == media_type_id).get())
- index = list_manifest_ids.index(manifestlistmanifest.manifest_id)
- list_manifest_ids.pop(index)
- list_json.pop(index)
+ with db_transaction():
+ tag = tag_model.get_tag(repo, tag_name, models_ref)
+ manifest_list = tag.manifest_list
+ list_json = manifest_list.manifest_list_json
+ mlm_query = ManifestListManifest.select().where(
+ ManifestListManifest.manifest_list == tag.manifest_list
+ )
+ list_manifest_ids = sorted([mlm.manifest_id for mlm in mlm_query])
+ manifestlistmanifest = (
+ getattr(tag.manifest_list, manifestlistmanifest_set_name)
+ .where(ManifestListManifest.media_type == media_type_id)
+ .get()
+ )
+ index = list_manifest_ids.index(manifestlistmanifest.manifest_id)
+ list_manifest_ids.pop(index)
+ list_json.pop(index)
- if not list_json:
- tag.lifetime_end = get_epoch_timestamp()
- tag.save()
- else:
- manifestlist = manifest_list_model.get_or_create_manifest_list(list_json, LIST_MEDIA_TYPE,
- SCHEMA_VERSION, models_ref)
- manifest_list_model.create_manifestlistmanifest(manifestlist, list_manifest_ids,
- list_json, models_ref)
- tag = tag_model.create_or_update_tag(repo, tag_name, models_ref, manifest_list=manifestlist,
- tag_kind="release")
- return tag
+ if not list_json:
+ tag.lifetime_end = get_epoch_timestamp()
+ tag.save()
+ else:
+ manifestlist = manifest_list_model.get_or_create_manifest_list(
+ list_json, LIST_MEDIA_TYPE, SCHEMA_VERSION, models_ref
+ )
+ manifest_list_model.create_manifestlistmanifest(
+ manifestlist, list_manifest_ids, list_json, models_ref
+ )
+ tag = tag_model.create_or_update_tag(
+ repo,
+ tag_name,
+ models_ref,
+ manifest_list=manifestlist,
+ tag_kind="release",
+ )
+ return tag
def create_app_release(repo, tag_name, manifest_data, digest, models_ref, force=False):
- """ Create a new application release, it includes creating a new Tag, ManifestList,
+ """ Create a new application release, it includes creating a new Tag, ManifestList,
ManifestListManifests, Manifest, ManifestBlob.
To deduplicate the ManifestList, the manifestlist_json is kept ordered by the manifest.id.
To find the insert point in the ManifestList it uses bisect on the manifest-ids list.
"""
- ManifestList = models_ref.ManifestList
- ManifestListManifest = models_ref.ManifestListManifest
- Blob = models_ref.Blob
- ManifestBlob = models_ref.ManifestBlob
+ ManifestList = models_ref.ManifestList
+ ManifestListManifest = models_ref.ManifestListManifest
+ Blob = models_ref.Blob
+ ManifestBlob = models_ref.ManifestBlob
- with db_transaction():
- # Create/get the package manifest
- manifest = manifest_model.get_or_create_manifest(manifest_data, manifest_data['mediaType'],
- models_ref)
- # get the tag
- tag = tag_model.get_or_initialize_tag(repo, tag_name, models_ref)
+ with db_transaction():
+ # Create/get the package manifest
+ manifest = manifest_model.get_or_create_manifest(
+ manifest_data, manifest_data["mediaType"], models_ref
+ )
+ # get the tag
+ tag = tag_model.get_or_initialize_tag(repo, tag_name, models_ref)
- if tag.manifest_list is None:
- tag.manifest_list = ManifestList(media_type=ManifestList.media_type.get_id(LIST_MEDIA_TYPE),
- schema_version=SCHEMA_VERSION,
- manifest_list_json=[], )
+ if tag.manifest_list is None:
+ tag.manifest_list = ManifestList(
+ media_type=ManifestList.media_type.get_id(LIST_MEDIA_TYPE),
+ schema_version=SCHEMA_VERSION,
+ manifest_list_json=[],
+ )
- elif tag_model.tag_media_type_exists(tag, manifest.media_type, models_ref):
- if force:
- delete_app_release(repo, tag_name, manifest.media_type.name, models_ref)
- return create_app_release(repo, tag_name, manifest_data, digest, models_ref, force=False)
- else:
- raise PackageAlreadyExists("package exists already")
+ elif tag_model.tag_media_type_exists(tag, manifest.media_type, models_ref):
+ if force:
+ delete_app_release(repo, tag_name, manifest.media_type.name, models_ref)
+ return create_app_release(
+ repo, tag_name, manifest_data, digest, models_ref, force=False
+ )
+ else:
+ raise PackageAlreadyExists("package exists already")
- list_json = tag.manifest_list.manifest_list_json
- mlm_query = (ManifestListManifest
- .select()
- .where(ManifestListManifest.manifest_list == tag.manifest_list))
- list_manifest_ids = sorted([mlm.manifest_id for mlm in mlm_query])
- insert_point = bisect.bisect_left(list_manifest_ids, manifest.id)
- list_json.insert(insert_point, manifest.manifest_json)
- list_manifest_ids.insert(insert_point, manifest.id)
- manifestlist = manifest_list_model.get_or_create_manifest_list(list_json, LIST_MEDIA_TYPE,
- SCHEMA_VERSION, models_ref)
- manifest_list_model.create_manifestlistmanifest(manifestlist, list_manifest_ids, list_json,
- models_ref)
+ list_json = tag.manifest_list.manifest_list_json
+ mlm_query = ManifestListManifest.select().where(
+ ManifestListManifest.manifest_list == tag.manifest_list
+ )
+ list_manifest_ids = sorted([mlm.manifest_id for mlm in mlm_query])
+ insert_point = bisect.bisect_left(list_manifest_ids, manifest.id)
+ list_json.insert(insert_point, manifest.manifest_json)
+ list_manifest_ids.insert(insert_point, manifest.id)
+ manifestlist = manifest_list_model.get_or_create_manifest_list(
+ list_json, LIST_MEDIA_TYPE, SCHEMA_VERSION, models_ref
+ )
+ manifest_list_model.create_manifestlistmanifest(
+ manifestlist, list_manifest_ids, list_json, models_ref
+ )
- tag = tag_model.create_or_update_tag(repo, tag_name, models_ref, manifest_list=manifestlist,
- tag_kind="release")
- blob_digest = digest
+ tag = tag_model.create_or_update_tag(
+ repo, tag_name, models_ref, manifest_list=manifestlist, tag_kind="release"
+ )
+ blob_digest = digest
+
+ try:
+ (
+ ManifestBlob.select()
+ .join(Blob)
+ .where(
+ ManifestBlob.manifest == manifest,
+ Blob.digest == _ensure_sha256_header(blob_digest),
+ )
+ .get()
+ )
+ except ManifestBlob.DoesNotExist:
+ blob = blob_model.get_blob(blob_digest, models_ref)
+ ManifestBlob.create(manifest=manifest, blob=blob)
+ return tag
- try:
- (ManifestBlob
- .select()
- .join(Blob)
- .where(ManifestBlob.manifest == manifest,
- Blob.digest == _ensure_sha256_header(blob_digest)).get())
- except ManifestBlob.DoesNotExist:
- blob = blob_model.get_blob(blob_digest, models_ref)
- ManifestBlob.create(manifest=manifest, blob=blob)
- return tag
def get_release_objs(repo, models_ref, media_type=None):
- """ Returns an array of Tag for a repo, with optional filtering by media_type. """
- Tag = models_ref.Tag
+ """ Returns an array of Tag for a repo, with optional filtering by media_type. """
+ Tag = models_ref.Tag
- release_query = (Tag
- .select()
- .where(Tag.repository == repo,
- Tag.tag_kind == Tag.tag_kind.get_id("release")))
- if media_type:
- release_query = tag_model.filter_tags_by_media_type(release_query, media_type, models_ref)
+ release_query = Tag.select().where(
+ Tag.repository == repo, Tag.tag_kind == Tag.tag_kind.get_id("release")
+ )
+ if media_type:
+ release_query = tag_model.filter_tags_by_media_type(
+ release_query, media_type, models_ref
+ )
+
+ return tag_model.tag_is_alive(release_query, Tag)
- return tag_model.tag_is_alive(release_query, Tag)
def get_releases(repo, model_refs, media_type=None):
- """ Returns an array of Tag.name for a repo, can filter by media_type. """
- return [t.name for t in get_release_objs(repo, model_refs, media_type)]
+ """ Returns an array of Tag.name for a repo, can filter by media_type. """
+ return [t.name for t in get_release_objs(repo, model_refs, media_type)]
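
As the create_app_release docstring notes, the manifest-list JSON is kept ordered by manifest id so that a given set of manifests always serializes to the same blob and therefore deduplicates to the same digest. A standalone sketch of the bisect insertion step, with plain lists standing in for the peewee rows:

import bisect

# Existing manifests in the list, already ordered by id (entries are illustrative).
list_manifest_ids = [3, 7, 12]
list_json = [{"id": 3}, {"id": 7}, {"id": 12}]

# A newly created manifest to insert.
new_manifest_id = 9
new_manifest_json = {"id": 9}

# bisect_left finds the slot that keeps list_manifest_ids sorted; inserting
# into both lists at the same index keeps ids and JSON entries aligned.
insert_point = bisect.bisect_left(list_manifest_ids, new_manifest_id)
list_manifest_ids.insert(insert_point, new_manifest_id)
list_json.insert(insert_point, new_manifest_json)

assert list_manifest_ids == [3, 7, 9, 12]
print(insert_point, list_manifest_ids)
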
diff --git a/data/appr_model/tag.py b/data/appr_model/tag.py
index 4903a4572..de353bd52 100644
--- a/data/appr_model/tag.py
+++ b/data/appr_model/tag.py
@@ -3,7 +3,7 @@ import logging
from cnr.models.package_base import manifest_media_type
from peewee import IntegrityError
-from data.model import (db_transaction, TagAlreadyCreatedException)
+from data.model import db_transaction, TagAlreadyCreatedException
from data.database import get_epoch_timestamp_ms, db_for_update
@@ -11,89 +11,120 @@ logger = logging.getLogger(__name__)
def tag_is_alive(query, cls, now_ts=None):
- return query.where((cls.lifetime_end >> None) |
- (cls.lifetime_end > now_ts))
+ return query.where((cls.lifetime_end >> None) | (cls.lifetime_end > now_ts))
def tag_media_type_exists(tag, media_type, models_ref):
- ManifestListManifest = models_ref.ManifestListManifest
- manifestlistmanifest_set_name = models_ref.manifestlistmanifest_set_name
- return (getattr(tag.manifest_list, manifestlistmanifest_set_name)
- .where(ManifestListManifest.media_type == media_type).count() > 0)
+ ManifestListManifest = models_ref.ManifestListManifest
+ manifestlistmanifest_set_name = models_ref.manifestlistmanifest_set_name
+ return (
+ getattr(tag.manifest_list, manifestlistmanifest_set_name)
+ .where(ManifestListManifest.media_type == media_type)
+ .count()
+ > 0
+ )
-def create_or_update_tag(repo, tag_name, models_ref, manifest_list=None, linked_tag=None,
- tag_kind="release"):
- Tag = models_ref.Tag
+def create_or_update_tag(
+ repo, tag_name, models_ref, manifest_list=None, linked_tag=None, tag_kind="release"
+):
+ Tag = models_ref.Tag
- now_ts = get_epoch_timestamp_ms()
- tag_kind_id = Tag.tag_kind.get_id(tag_kind)
- with db_transaction():
- try:
- tag = db_for_update(tag_is_alive(Tag
- .select()
- .where(Tag.repository == repo,
- Tag.name == tag_name,
- Tag.tag_kind == tag_kind_id), Tag, now_ts)).get()
- if tag.manifest_list == manifest_list and tag.linked_tag == linked_tag:
- return tag
- tag.lifetime_end = now_ts
- tag.save()
- except Tag.DoesNotExist:
- pass
+ now_ts = get_epoch_timestamp_ms()
+ tag_kind_id = Tag.tag_kind.get_id(tag_kind)
+ with db_transaction():
+ try:
+ tag = db_for_update(
+ tag_is_alive(
+ Tag.select().where(
+ Tag.repository == repo,
+ Tag.name == tag_name,
+ Tag.tag_kind == tag_kind_id,
+ ),
+ Tag,
+ now_ts,
+ )
+ ).get()
+ if tag.manifest_list == manifest_list and tag.linked_tag == linked_tag:
+ return tag
+ tag.lifetime_end = now_ts
+ tag.save()
+ except Tag.DoesNotExist:
+ pass
- try:
- return Tag.create(repository=repo, manifest_list=manifest_list, linked_tag=linked_tag,
- name=tag_name, lifetime_start=now_ts, lifetime_end=None,
- tag_kind=tag_kind_id)
- except IntegrityError:
- msg = 'Tag with name %s and lifetime start %s under repository %s/%s already exists'
- raise TagAlreadyCreatedException(msg % (tag_name, now_ts, repo.namespace_user, repo.name))
+ try:
+ return Tag.create(
+ repository=repo,
+ manifest_list=manifest_list,
+ linked_tag=linked_tag,
+ name=tag_name,
+ lifetime_start=now_ts,
+ lifetime_end=None,
+ tag_kind=tag_kind_id,
+ )
+ except IntegrityError:
+ msg = "Tag with name %s and lifetime start %s under repository %s/%s already exists"
+ raise TagAlreadyCreatedException(
+ msg % (tag_name, now_ts, repo.namespace_user, repo.name)
+ )
def get_or_initialize_tag(repo, tag_name, models_ref, tag_kind="release"):
- Tag = models_ref.Tag
-
- try:
- return tag_is_alive(Tag.select().where(Tag.repository == repo, Tag.name == tag_name), Tag).get()
- except Tag.DoesNotExist:
- return Tag(repo=repo, name=tag_name, tag_kind=Tag.tag_kind.get_id(tag_kind))
+ Tag = models_ref.Tag
+
+ try:
+ return tag_is_alive(
+ Tag.select().where(Tag.repository == repo, Tag.name == tag_name), Tag
+ ).get()
+ except Tag.DoesNotExist:
+ return Tag(repo=repo, name=tag_name, tag_kind=Tag.tag_kind.get_id(tag_kind))
def get_tag(repo, tag_name, models_ref, tag_kind="release"):
- Tag = models_ref.Tag
- return tag_is_alive(Tag.select()
- .where(Tag.repository == repo,
- Tag.name == tag_name,
- Tag.tag_kind == Tag.tag_kind.get_id(tag_kind)), Tag).get()
+ Tag = models_ref.Tag
+ return tag_is_alive(
+ Tag.select().where(
+ Tag.repository == repo,
+ Tag.name == tag_name,
+ Tag.tag_kind == Tag.tag_kind.get_id(tag_kind),
+ ),
+ Tag,
+ ).get()
def delete_tag(repo, tag_name, models_ref, tag_kind="release"):
- Tag = models_ref.Tag
- tag_kind_id = Tag.tag_kind.get_id(tag_kind)
- tag = tag_is_alive(Tag.select()
- .where(Tag.repository == repo,
- Tag.name == tag_name, Tag.tag_kind == tag_kind_id), Tag).get()
- tag.lifetime_end = get_epoch_timestamp_ms()
- tag.save()
- return tag
+ Tag = models_ref.Tag
+ tag_kind_id = Tag.tag_kind.get_id(tag_kind)
+ tag = tag_is_alive(
+ Tag.select().where(
+ Tag.repository == repo, Tag.name == tag_name, Tag.tag_kind == tag_kind_id
+ ),
+ Tag,
+ ).get()
+ tag.lifetime_end = get_epoch_timestamp_ms()
+ tag.save()
+ return tag
def tag_exists(repo, tag_name, models_ref, tag_kind="release"):
- Tag = models_ref.Tag
- try:
- get_tag(repo, tag_name, models_ref, tag_kind)
- return True
- except Tag.DoesNotExist:
- return False
+ Tag = models_ref.Tag
+ try:
+ get_tag(repo, tag_name, models_ref, tag_kind)
+ return True
+ except Tag.DoesNotExist:
+ return False
def filter_tags_by_media_type(tag_query, media_type, models_ref):
- """ Return only available tag for a media_type. """
- ManifestListManifest = models_ref.ManifestListManifest
- Tag = models_ref.Tag
- media_type = manifest_media_type(media_type)
- t = (tag_query
- .join(ManifestListManifest, on=(ManifestListManifest.manifest_list == Tag.manifest_list))
- .where(ManifestListManifest.media_type == ManifestListManifest.media_type.get_id(media_type)))
- return t
+ """ Return only available tag for a media_type. """
+ ManifestListManifest = models_ref.ManifestListManifest
+ Tag = models_ref.Tag
+ media_type = manifest_media_type(media_type)
+ t = tag_query.join(
+ ManifestListManifest,
+ on=(ManifestListManifest.manifest_list == Tag.manifest_list),
+ ).where(
+ ManifestListManifest.media_type
+ == ManifestListManifest.media_type.get_id(media_type)
+ )
+ return t
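
None of the tag functions above delete rows: create_or_update_tag and delete_tag retire a tag by stamping lifetime_end, and tag_is_alive keeps only rows whose lifetime_end is NULL or still in the future. A simplified in-memory analog of that liveness check, with dicts standing in for Tag rows rather than the real peewee query:

import time

def epoch_ms():
    return int(time.time() * 1000)

def is_alive(tag, now_ts=None):
    # Alive if never terminated, or if termination lies after now_ts.
    if tag["lifetime_end"] is None:
        return True
    return now_ts is not None and tag["lifetime_end"] > now_ts

now = epoch_ms()
tags = [
    {"name": "1.0.0", "lifetime_end": None},        # current release
    {"name": "0.9.0", "lifetime_end": now - 1000},  # retired a second ago
]

print([t["name"] for t in tags if is_alive(t, now)])  # -> ['1.0.0']
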
diff --git a/data/archivedlogs.py b/data/archivedlogs.py
index 0172c74c8..d2b3fbabb 100644
--- a/data/archivedlogs.py
+++ b/data/archivedlogs.py
@@ -6,32 +6,33 @@ from flask import send_file, abort
from data.userfiles import DelegateUserfiles, UserfilesHandlers
-JSON_MIMETYPE = 'application/json'
+JSON_MIMETYPE = "application/json"
logger = logging.getLogger(__name__)
class LogArchive(object):
- def __init__(self, app=None, distributed_storage=None):
- self.app = app
- if app is not None:
- self.state = self.init_app(app, distributed_storage)
- else:
- self.state = None
+ def __init__(self, app=None, distributed_storage=None):
+ self.app = app
+ if app is not None:
+ self.state = self.init_app(app, distributed_storage)
+ else:
+ self.state = None
- def init_app(self, app, distributed_storage):
- location = app.config.get('LOG_ARCHIVE_LOCATION')
- path = app.config.get('LOG_ARCHIVE_PATH', None)
+ def init_app(self, app, distributed_storage):
+ location = app.config.get("LOG_ARCHIVE_LOCATION")
+ path = app.config.get("LOG_ARCHIVE_PATH", None)
- handler_name = 'web.logarchive'
+ handler_name = "web.logarchive"
- log_archive = DelegateUserfiles(app, distributed_storage, location, path,
- handler_name=handler_name)
- # register extension with app
- app.extensions = getattr(app, 'extensions', {})
- app.extensions['log_archive'] = log_archive
- return log_archive
+ log_archive = DelegateUserfiles(
+ app, distributed_storage, location, path, handler_name=handler_name
+ )
+ # register extension with app
+ app.extensions = getattr(app, "extensions", {})
+ app.extensions["log_archive"] = log_archive
+ return log_archive
- def __getattr__(self, name):
- return getattr(self.state, name, None)
+ def __getattr__(self, name):
+ return getattr(self.state, name, None)
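
LogArchive follows the same Flask-extension shape used by Billing and BuildLogs below: __init__ optionally calls init_app, init_app builds the real object and registers it under app.extensions, and __getattr__ forwards everything else to that object. A minimal, hypothetical stand-in showing just the delegation pattern (FakeApp and FakeBackend are not real Quay classes):

class DemoExtension(object):
    def __init__(self, app=None, backend=None):
        self.app = app
        self.state = self.init_app(app, backend) if app is not None else None

    def init_app(self, app, backend):
        # Register on the app, mirroring app.extensions["log_archive"] = ...
        app.extensions = getattr(app, "extensions", {})
        app.extensions["demo"] = backend
        return backend

    def __getattr__(self, name):
        # Unknown attribute access falls through to the configured backend.
        return getattr(self.state, name, None)

class FakeApp(object):
    pass

class FakeBackend(object):
    def ping(self):
        return "pong"

ext = DemoExtension(FakeApp(), FakeBackend())
print(ext.ping())  # -> "pong"
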
diff --git a/data/billing.py b/data/billing.py
index aa2420c01..a18000ea2 100644
--- a/data/billing.py
+++ b/data/billing.py
@@ -6,448 +6,456 @@ from calendar import timegm
from util.morecollections import AttrDict
PLANS = [
- # Deprecated Plans (2013-2014)
- {
- 'title': 'Micro',
- 'price': 700,
- 'privateRepos': 5,
- 'stripeId': 'micro',
- 'audience': 'For smaller teams',
- 'bus_features': False,
- 'deprecated': True,
- 'free_trial_days': 14,
- 'superseded_by': 'personal-30',
- 'plans_page_hidden': False,
- },
- {
- 'title': 'Basic',
- 'price': 1200,
- 'privateRepos': 10,
- 'stripeId': 'small',
- 'audience': 'For your basic team',
- 'bus_features': False,
- 'deprecated': True,
- 'free_trial_days': 14,
- 'superseded_by': 'bus-micro-30',
- 'plans_page_hidden': False,
- },
- {
- 'title': 'Yacht',
- 'price': 5000,
- 'privateRepos': 20,
- 'stripeId': 'bus-coreos-trial',
- 'audience': 'For small businesses',
- 'bus_features': True,
- 'deprecated': True,
- 'free_trial_days': 180,
- 'superseded_by': 'bus-small-30',
- 'plans_page_hidden': False,
- },
- {
- 'title': 'Personal',
- 'price': 1200,
- 'privateRepos': 5,
- 'stripeId': 'personal',
- 'audience': 'Individuals',
- 'bus_features': False,
- 'deprecated': True,
- 'free_trial_days': 14,
- 'superseded_by': 'personal-30',
- 'plans_page_hidden': False,
- },
- {
- 'title': 'Skiff',
- 'price': 2500,
- 'privateRepos': 10,
- 'stripeId': 'bus-micro',
- 'audience': 'For startups',
- 'bus_features': True,
- 'deprecated': True,
- 'free_trial_days': 14,
- 'superseded_by': 'bus-micro-30',
- 'plans_page_hidden': False,
- },
- {
- 'title': 'Yacht',
- 'price': 5000,
- 'privateRepos': 20,
- 'stripeId': 'bus-small',
- 'audience': 'For small businesses',
- 'bus_features': True,
- 'deprecated': True,
- 'free_trial_days': 14,
- 'superseded_by': 'bus-small-30',
- 'plans_page_hidden': False,
- },
- {
- 'title': 'Freighter',
- 'price': 10000,
- 'privateRepos': 50,
- 'stripeId': 'bus-medium',
- 'audience': 'For normal businesses',
- 'bus_features': True,
- 'deprecated': True,
- 'free_trial_days': 14,
- 'superseded_by': 'bus-medium-30',
- 'plans_page_hidden': False,
- },
- {
- 'title': 'Tanker',
- 'price': 20000,
- 'privateRepos': 125,
- 'stripeId': 'bus-large',
- 'audience': 'For large businesses',
- 'bus_features': True,
- 'deprecated': True,
- 'free_trial_days': 14,
- 'superseded_by': 'bus-large-30',
- 'plans_page_hidden': False,
- },
-
- # Deprecated plans (2014-2017)
- {
- 'title': 'Personal',
- 'price': 1200,
- 'privateRepos': 5,
- 'stripeId': 'personal-30',
- 'audience': 'Individuals',
- 'bus_features': False,
- 'deprecated': True,
- 'free_trial_days': 30,
- 'superseded_by': 'personal-2018',
- 'plans_page_hidden': False,
- },
- {
- 'title': 'Skiff',
- 'price': 2500,
- 'privateRepos': 10,
- 'stripeId': 'bus-micro-30',
- 'audience': 'For startups',
- 'bus_features': True,
- 'deprecated': True,
- 'free_trial_days': 30,
- 'superseded_by': 'bus-micro-2018',
- 'plans_page_hidden': False,
- },
- {
- 'title': 'Yacht',
- 'price': 5000,
- 'privateRepos': 20,
- 'stripeId': 'bus-small-30',
- 'audience': 'For small businesses',
- 'bus_features': True,
- 'deprecated': True,
- 'free_trial_days': 30,
- 'superseded_by': 'bus-small-2018',
- 'plans_page_hidden': False,
- },
- {
- 'title': 'Freighter',
- 'price': 10000,
- 'privateRepos': 50,
- 'stripeId': 'bus-medium-30',
- 'audience': 'For normal businesses',
- 'bus_features': True,
- 'deprecated': True,
- 'free_trial_days': 30,
- 'superseded_by': 'bus-medium-2018',
- 'plans_page_hidden': False,
- },
- {
- 'title': 'Tanker',
- 'price': 20000,
- 'privateRepos': 125,
- 'stripeId': 'bus-large-30',
- 'audience': 'For large businesses',
- 'bus_features': True,
- 'deprecated': True,
- 'free_trial_days': 30,
- 'superseded_by': 'bus-large-2018',
- 'plans_page_hidden': False,
- },
- {
- 'title': 'Carrier',
- 'price': 35000,
- 'privateRepos': 250,
- 'stripeId': 'bus-xlarge-30',
- 'audience': 'For extra large businesses',
- 'bus_features': True,
- 'deprecated': True,
- 'free_trial_days': 30,
- 'superseded_by': 'bus-xlarge-2018',
- 'plans_page_hidden': False,
- },
- {
- 'title': 'Huge',
- 'price': 65000,
- 'privateRepos': 500,
- 'stripeId': 'bus-500-30',
- 'audience': 'For huge business',
- 'bus_features': True,
- 'deprecated': True,
- 'free_trial_days': 30,
- 'superseded_by': 'bus-500-2018',
- 'plans_page_hidden': False,
- },
- {
- 'title': 'Huuge',
- 'price': 120000,
- 'privateRepos': 1000,
- 'stripeId': 'bus-1000-30',
- 'audience': 'For the SaaS savvy enterprise',
- 'bus_features': True,
- 'deprecated': True,
- 'free_trial_days': 30,
- 'superseded_by': 'bus-1000-2018',
- 'plans_page_hidden': False,
- },
-
- # Active plans (as of Dec 2017)
- {
- 'title': 'Open Source',
- 'price': 0,
- 'privateRepos': 0,
- 'stripeId': 'free',
- 'audience': 'Committment to FOSS',
- 'bus_features': False,
- 'deprecated': False,
- 'free_trial_days': 30,
- 'superseded_by': None,
- 'plans_page_hidden': False,
- },
- {
- 'title': 'Developer',
- 'price': 1500,
- 'privateRepos': 5,
- 'stripeId': 'personal-2018',
- 'audience': 'Individuals',
- 'bus_features': False,
- 'deprecated': False,
- 'free_trial_days': 30,
- 'superseded_by': None,
- 'plans_page_hidden': False,
- },
- {
- 'title': 'Micro',
- 'price': 3000,
- 'privateRepos': 10,
- 'stripeId': 'bus-micro-2018',
- 'audience': 'For startups',
- 'bus_features': True,
- 'deprecated': False,
- 'free_trial_days': 30,
- 'superseded_by': None,
- 'plans_page_hidden': False,
- },
- {
- 'title': 'Small',
- 'price': 6000,
- 'privateRepos': 20,
- 'stripeId': 'bus-small-2018',
- 'audience': 'For small businesses',
- 'bus_features': True,
- 'deprecated': False,
- 'free_trial_days': 30,
- 'superseded_by': None,
- 'plans_page_hidden': False,
- },
- {
- 'title': 'Medium',
- 'price': 12500,
- 'privateRepos': 50,
- 'stripeId': 'bus-medium-2018',
- 'audience': 'For normal businesses',
- 'bus_features': True,
- 'deprecated': False,
- 'free_trial_days': 30,
- 'superseded_by': None,
- 'plans_page_hidden': False,
- },
- {
- 'title': 'Large',
- 'price': 25000,
- 'privateRepos': 125,
- 'stripeId': 'bus-large-2018',
- 'audience': 'For large businesses',
- 'bus_features': True,
- 'deprecated': False,
- 'free_trial_days': 30,
- 'superseded_by': None,
- 'plans_page_hidden': False,
- },
- {
- 'title': 'Extra Large',
- 'price': 45000,
- 'privateRepos': 250,
- 'stripeId': 'bus-xlarge-2018',
- 'audience': 'For extra large businesses',
- 'bus_features': True,
- 'deprecated': False,
- 'free_trial_days': 30,
- 'superseded_by': None,
- 'plans_page_hidden': False,
- },
- {
- 'title': 'XXL',
- 'price': 85000,
- 'privateRepos': 500,
- 'stripeId': 'bus-500-2018',
- 'audience': 'For huge business',
- 'bus_features': True,
- 'deprecated': False,
- 'free_trial_days': 30,
- 'superseded_by': None,
- 'plans_page_hidden': False,
- },
- {
- 'title': 'XXXL',
- 'price': 160000,
- 'privateRepos': 1000,
- 'stripeId': 'bus-1000-2018',
- 'audience': 'For the SaaS savvy enterprise',
- 'bus_features': True,
- 'deprecated': False,
- 'free_trial_days': 30,
- 'superseded_by': None,
- 'plans_page_hidden': False,
- },
- {
- 'title': 'XXXXL',
- 'price': 310000,
- 'privateRepos': 2000,
- 'stripeId': 'bus-2000-2018',
- 'audience': 'For the SaaS savvy big enterprise',
- 'bus_features': True,
- 'deprecated': False,
- 'free_trial_days': 30,
- 'superseded_by': None,
- 'plans_page_hidden': False,
- },
+ # Deprecated Plans (2013-2014)
+ {
+ "title": "Micro",
+ "price": 700,
+ "privateRepos": 5,
+ "stripeId": "micro",
+ "audience": "For smaller teams",
+ "bus_features": False,
+ "deprecated": True,
+ "free_trial_days": 14,
+ "superseded_by": "personal-30",
+ "plans_page_hidden": False,
+ },
+ {
+ "title": "Basic",
+ "price": 1200,
+ "privateRepos": 10,
+ "stripeId": "small",
+ "audience": "For your basic team",
+ "bus_features": False,
+ "deprecated": True,
+ "free_trial_days": 14,
+ "superseded_by": "bus-micro-30",
+ "plans_page_hidden": False,
+ },
+ {
+ "title": "Yacht",
+ "price": 5000,
+ "privateRepos": 20,
+ "stripeId": "bus-coreos-trial",
+ "audience": "For small businesses",
+ "bus_features": True,
+ "deprecated": True,
+ "free_trial_days": 180,
+ "superseded_by": "bus-small-30",
+ "plans_page_hidden": False,
+ },
+ {
+ "title": "Personal",
+ "price": 1200,
+ "privateRepos": 5,
+ "stripeId": "personal",
+ "audience": "Individuals",
+ "bus_features": False,
+ "deprecated": True,
+ "free_trial_days": 14,
+ "superseded_by": "personal-30",
+ "plans_page_hidden": False,
+ },
+ {
+ "title": "Skiff",
+ "price": 2500,
+ "privateRepos": 10,
+ "stripeId": "bus-micro",
+ "audience": "For startups",
+ "bus_features": True,
+ "deprecated": True,
+ "free_trial_days": 14,
+ "superseded_by": "bus-micro-30",
+ "plans_page_hidden": False,
+ },
+ {
+ "title": "Yacht",
+ "price": 5000,
+ "privateRepos": 20,
+ "stripeId": "bus-small",
+ "audience": "For small businesses",
+ "bus_features": True,
+ "deprecated": True,
+ "free_trial_days": 14,
+ "superseded_by": "bus-small-30",
+ "plans_page_hidden": False,
+ },
+ {
+ "title": "Freighter",
+ "price": 10000,
+ "privateRepos": 50,
+ "stripeId": "bus-medium",
+ "audience": "For normal businesses",
+ "bus_features": True,
+ "deprecated": True,
+ "free_trial_days": 14,
+ "superseded_by": "bus-medium-30",
+ "plans_page_hidden": False,
+ },
+ {
+ "title": "Tanker",
+ "price": 20000,
+ "privateRepos": 125,
+ "stripeId": "bus-large",
+ "audience": "For large businesses",
+ "bus_features": True,
+ "deprecated": True,
+ "free_trial_days": 14,
+ "superseded_by": "bus-large-30",
+ "plans_page_hidden": False,
+ },
+ # Deprecated plans (2014-2017)
+ {
+ "title": "Personal",
+ "price": 1200,
+ "privateRepos": 5,
+ "stripeId": "personal-30",
+ "audience": "Individuals",
+ "bus_features": False,
+ "deprecated": True,
+ "free_trial_days": 30,
+ "superseded_by": "personal-2018",
+ "plans_page_hidden": False,
+ },
+ {
+ "title": "Skiff",
+ "price": 2500,
+ "privateRepos": 10,
+ "stripeId": "bus-micro-30",
+ "audience": "For startups",
+ "bus_features": True,
+ "deprecated": True,
+ "free_trial_days": 30,
+ "superseded_by": "bus-micro-2018",
+ "plans_page_hidden": False,
+ },
+ {
+ "title": "Yacht",
+ "price": 5000,
+ "privateRepos": 20,
+ "stripeId": "bus-small-30",
+ "audience": "For small businesses",
+ "bus_features": True,
+ "deprecated": True,
+ "free_trial_days": 30,
+ "superseded_by": "bus-small-2018",
+ "plans_page_hidden": False,
+ },
+ {
+ "title": "Freighter",
+ "price": 10000,
+ "privateRepos": 50,
+ "stripeId": "bus-medium-30",
+ "audience": "For normal businesses",
+ "bus_features": True,
+ "deprecated": True,
+ "free_trial_days": 30,
+ "superseded_by": "bus-medium-2018",
+ "plans_page_hidden": False,
+ },
+ {
+ "title": "Tanker",
+ "price": 20000,
+ "privateRepos": 125,
+ "stripeId": "bus-large-30",
+ "audience": "For large businesses",
+ "bus_features": True,
+ "deprecated": True,
+ "free_trial_days": 30,
+ "superseded_by": "bus-large-2018",
+ "plans_page_hidden": False,
+ },
+ {
+ "title": "Carrier",
+ "price": 35000,
+ "privateRepos": 250,
+ "stripeId": "bus-xlarge-30",
+ "audience": "For extra large businesses",
+ "bus_features": True,
+ "deprecated": True,
+ "free_trial_days": 30,
+ "superseded_by": "bus-xlarge-2018",
+ "plans_page_hidden": False,
+ },
+ {
+ "title": "Huge",
+ "price": 65000,
+ "privateRepos": 500,
+ "stripeId": "bus-500-30",
+ "audience": "For huge business",
+ "bus_features": True,
+ "deprecated": True,
+ "free_trial_days": 30,
+ "superseded_by": "bus-500-2018",
+ "plans_page_hidden": False,
+ },
+ {
+ "title": "Huuge",
+ "price": 120000,
+ "privateRepos": 1000,
+ "stripeId": "bus-1000-30",
+ "audience": "For the SaaS savvy enterprise",
+ "bus_features": True,
+ "deprecated": True,
+ "free_trial_days": 30,
+ "superseded_by": "bus-1000-2018",
+ "plans_page_hidden": False,
+ },
+ # Active plans (as of Dec 2017)
+ {
+ "title": "Open Source",
+ "price": 0,
+ "privateRepos": 0,
+ "stripeId": "free",
+ "audience": "Committment to FOSS",
+ "bus_features": False,
+ "deprecated": False,
+ "free_trial_days": 30,
+ "superseded_by": None,
+ "plans_page_hidden": False,
+ },
+ {
+ "title": "Developer",
+ "price": 1500,
+ "privateRepos": 5,
+ "stripeId": "personal-2018",
+ "audience": "Individuals",
+ "bus_features": False,
+ "deprecated": False,
+ "free_trial_days": 30,
+ "superseded_by": None,
+ "plans_page_hidden": False,
+ },
+ {
+ "title": "Micro",
+ "price": 3000,
+ "privateRepos": 10,
+ "stripeId": "bus-micro-2018",
+ "audience": "For startups",
+ "bus_features": True,
+ "deprecated": False,
+ "free_trial_days": 30,
+ "superseded_by": None,
+ "plans_page_hidden": False,
+ },
+ {
+ "title": "Small",
+ "price": 6000,
+ "privateRepos": 20,
+ "stripeId": "bus-small-2018",
+ "audience": "For small businesses",
+ "bus_features": True,
+ "deprecated": False,
+ "free_trial_days": 30,
+ "superseded_by": None,
+ "plans_page_hidden": False,
+ },
+ {
+ "title": "Medium",
+ "price": 12500,
+ "privateRepos": 50,
+ "stripeId": "bus-medium-2018",
+ "audience": "For normal businesses",
+ "bus_features": True,
+ "deprecated": False,
+ "free_trial_days": 30,
+ "superseded_by": None,
+ "plans_page_hidden": False,
+ },
+ {
+ "title": "Large",
+ "price": 25000,
+ "privateRepos": 125,
+ "stripeId": "bus-large-2018",
+ "audience": "For large businesses",
+ "bus_features": True,
+ "deprecated": False,
+ "free_trial_days": 30,
+ "superseded_by": None,
+ "plans_page_hidden": False,
+ },
+ {
+ "title": "Extra Large",
+ "price": 45000,
+ "privateRepos": 250,
+ "stripeId": "bus-xlarge-2018",
+ "audience": "For extra large businesses",
+ "bus_features": True,
+ "deprecated": False,
+ "free_trial_days": 30,
+ "superseded_by": None,
+ "plans_page_hidden": False,
+ },
+ {
+ "title": "XXL",
+ "price": 85000,
+ "privateRepos": 500,
+ "stripeId": "bus-500-2018",
+ "audience": "For huge business",
+ "bus_features": True,
+ "deprecated": False,
+ "free_trial_days": 30,
+ "superseded_by": None,
+ "plans_page_hidden": False,
+ },
+ {
+ "title": "XXXL",
+ "price": 160000,
+ "privateRepos": 1000,
+ "stripeId": "bus-1000-2018",
+ "audience": "For the SaaS savvy enterprise",
+ "bus_features": True,
+ "deprecated": False,
+ "free_trial_days": 30,
+ "superseded_by": None,
+ "plans_page_hidden": False,
+ },
+ {
+ "title": "XXXXL",
+ "price": 310000,
+ "privateRepos": 2000,
+ "stripeId": "bus-2000-2018",
+ "audience": "For the SaaS savvy big enterprise",
+ "bus_features": True,
+ "deprecated": False,
+ "free_trial_days": 30,
+ "superseded_by": None,
+ "plans_page_hidden": False,
+ },
]
def get_plan(plan_id):
- """ Returns the plan with the given ID or None if none. """
- for plan in PLANS:
- if plan['stripeId'] == plan_id:
- return plan
+ """ Returns the plan with the given ID or None if none. """
+ for plan in PLANS:
+ if plan["stripeId"] == plan_id:
+ return plan
- return None
+ return None
class FakeSubscription(AttrDict):
- @classmethod
- def build(cls, data, customer):
- data = AttrDict.deep_copy(data)
- data['customer'] = customer
- return cls(data)
+ @classmethod
+ def build(cls, data, customer):
+ data = AttrDict.deep_copy(data)
+ data["customer"] = customer
+ return cls(data)
- def delete(self):
- self.customer.subscription = None
+ def delete(self):
+ self.customer.subscription = None
class FakeStripe(object):
- class Customer(AttrDict):
- FAKE_PLAN = AttrDict({
- 'id': 'bus-small',
- })
+ class Customer(AttrDict):
+ FAKE_PLAN = AttrDict({"id": "bus-small"})
- FAKE_SUBSCRIPTION = AttrDict({
- 'plan': FAKE_PLAN,
- 'current_period_start': timegm(datetime.utcnow().utctimetuple()),
- 'current_period_end': timegm((datetime.utcnow() + timedelta(days=30)).utctimetuple()),
- 'trial_start': timegm(datetime.utcnow().utctimetuple()),
- 'trial_end': timegm((datetime.utcnow() + timedelta(days=30)).utctimetuple()),
- })
+ FAKE_SUBSCRIPTION = AttrDict(
+ {
+ "plan": FAKE_PLAN,
+ "current_period_start": timegm(datetime.utcnow().utctimetuple()),
+ "current_period_end": timegm(
+ (datetime.utcnow() + timedelta(days=30)).utctimetuple()
+ ),
+ "trial_start": timegm(datetime.utcnow().utctimetuple()),
+ "trial_end": timegm(
+ (datetime.utcnow() + timedelta(days=30)).utctimetuple()
+ ),
+ }
+ )
- FAKE_CARD = AttrDict({
- 'id': 'card123',
- 'name': 'Joe User',
- 'type': 'Visa',
- 'last4': '4242',
- 'exp_month': 5,
- 'exp_year': 2016,
- })
+ FAKE_CARD = AttrDict(
+ {
+ "id": "card123",
+ "name": "Joe User",
+ "type": "Visa",
+ "last4": "4242",
+ "exp_month": 5,
+ "exp_year": 2016,
+ }
+ )
- FAKE_CARD_LIST = AttrDict({
- 'data': [FAKE_CARD],
- })
+ FAKE_CARD_LIST = AttrDict({"data": [FAKE_CARD]})
- ACTIVE_CUSTOMERS = {}
+ ACTIVE_CUSTOMERS = {}
- @property
- def card(self):
- return self.get('new_card', None)
+ @property
+ def card(self):
+ return self.get("new_card", None)
- @card.setter
- def card(self, card_token):
- self['new_card'] = card_token
+ @card.setter
+ def card(self, card_token):
+ self["new_card"] = card_token
- @property
- def plan(self):
- return self.get('new_plan', None)
+ @property
+ def plan(self):
+ return self.get("new_plan", None)
- @plan.setter
- def plan(self, plan_name):
- self['new_plan'] = plan_name
+ @plan.setter
+ def plan(self, plan_name):
+ self["new_plan"] = plan_name
- def save(self):
- if self.get('new_card', None) is not None:
- raise stripe.error.CardError('Test raising exception on set card.', self.get('new_card'), 402)
- if self.get('new_plan', None) is not None:
- if self.subscription is None:
- self.subscription = FakeSubscription.build(self.FAKE_SUBSCRIPTION, self)
- self.subscription.plan.id = self.get('new_plan')
+ def save(self):
+ if self.get("new_card", None) is not None:
+ raise stripe.error.CardError(
+ "Test raising exception on set card.", self.get("new_card"), 402
+ )
+ if self.get("new_plan", None) is not None:
+ if self.subscription is None:
+ self.subscription = FakeSubscription.build(
+ self.FAKE_SUBSCRIPTION, self
+ )
+ self.subscription.plan.id = self.get("new_plan")
- @classmethod
- def retrieve(cls, stripe_customer_id):
- if stripe_customer_id in cls.ACTIVE_CUSTOMERS:
- cls.ACTIVE_CUSTOMERS[stripe_customer_id].pop('new_card', None)
- cls.ACTIVE_CUSTOMERS[stripe_customer_id].pop('new_plan', None)
- return cls.ACTIVE_CUSTOMERS[stripe_customer_id]
- else:
- new_customer = cls({
- 'default_card': 'card123',
- 'cards': AttrDict.deep_copy(cls.FAKE_CARD_LIST),
- 'id': stripe_customer_id,
- })
- new_customer.subscription = FakeSubscription.build(cls.FAKE_SUBSCRIPTION, new_customer)
- cls.ACTIVE_CUSTOMERS[stripe_customer_id] = new_customer
- return new_customer
+ @classmethod
+ def retrieve(cls, stripe_customer_id):
+ if stripe_customer_id in cls.ACTIVE_CUSTOMERS:
+ cls.ACTIVE_CUSTOMERS[stripe_customer_id].pop("new_card", None)
+ cls.ACTIVE_CUSTOMERS[stripe_customer_id].pop("new_plan", None)
+ return cls.ACTIVE_CUSTOMERS[stripe_customer_id]
+ else:
+ new_customer = cls(
+ {
+ "default_card": "card123",
+ "cards": AttrDict.deep_copy(cls.FAKE_CARD_LIST),
+ "id": stripe_customer_id,
+ }
+ )
+ new_customer.subscription = FakeSubscription.build(
+ cls.FAKE_SUBSCRIPTION, new_customer
+ )
+ cls.ACTIVE_CUSTOMERS[stripe_customer_id] = new_customer
+ return new_customer
- class Invoice(AttrDict):
- @staticmethod
- def list(customer, count):
- return AttrDict({
- 'data': [],
- })
+ class Invoice(AttrDict):
+ @staticmethod
+ def list(customer, count):
+ return AttrDict({"data": []})
class Billing(object):
- def __init__(self, app=None):
- self.app = app
- if app is not None:
- self.state = self.init_app(app)
- else:
- self.state = None
+ def __init__(self, app=None):
+ self.app = app
+ if app is not None:
+ self.state = self.init_app(app)
+ else:
+ self.state = None
- def init_app(self, app):
- billing_type = app.config.get('BILLING_TYPE', 'FakeStripe')
+ def init_app(self, app):
+ billing_type = app.config.get("BILLING_TYPE", "FakeStripe")
- if billing_type == 'Stripe':
- billing = stripe
- stripe.api_key = app.config.get('STRIPE_SECRET_KEY', None)
+ if billing_type == "Stripe":
+ billing = stripe
+ stripe.api_key = app.config.get("STRIPE_SECRET_KEY", None)
- elif billing_type == 'FakeStripe':
- billing = FakeStripe
+ elif billing_type == "FakeStripe":
+ billing = FakeStripe
- else:
- raise RuntimeError('Unknown billing type: %s' % billing_type)
+ else:
+ raise RuntimeError("Unknown billing type: %s" % billing_type)
- # register extension with app
- app.extensions = getattr(app, 'extensions', {})
- app.extensions['billing'] = billing
- return billing
+ # register extension with app
+ app.extensions = getattr(app, "extensions", {})
+ app.extensions["billing"] = billing
+ return billing
- def __getattr__(self, name):
- return getattr(self.state, name, None)
+ def __getattr__(self, name):
+ return getattr(self.state, name, None)
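
Each PLANS entry is a plain dict and get_plan is a linear scan on stripeId. The superseded_by field chains a deprecated plan forward to its replacement; the resolve_current helper below is hypothetical (it does not exist in this module) and uses only the field names shown above, against a trimmed copy of the data:

PLANS_SAMPLE = [
    {"stripeId": "personal", "deprecated": True, "superseded_by": "personal-30"},
    {"stripeId": "personal-30", "deprecated": True, "superseded_by": "personal-2018"},
    {"stripeId": "personal-2018", "deprecated": False, "superseded_by": None},
]

def get_plan(plan_id):
    for plan in PLANS_SAMPLE:
        if plan["stripeId"] == plan_id:
            return plan
    return None

def resolve_current(plan_id):
    # Walk superseded_by links until a non-deprecated plan is reached.
    plan = get_plan(plan_id)
    while plan is not None and plan["deprecated"] and plan["superseded_by"]:
        plan = get_plan(plan["superseded_by"])
    return plan

print(resolve_current("personal")["stripeId"])  # -> "personal-2018"
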
diff --git a/data/buildlogs.py b/data/buildlogs.py
index b6b4d2652..3763cd569 100644
--- a/data/buildlogs.py
+++ b/data/buildlogs.py
@@ -13,167 +13,174 @@ SEVEN_DAYS = timedelta(days=7)
class BuildStatusRetrievalError(Exception):
- pass
+ pass
+
class RedisBuildLogs(object):
- ERROR = 'error'
- COMMAND = 'command'
- PHASE = 'phase'
+ ERROR = "error"
+ COMMAND = "command"
+ PHASE = "phase"
- def __init__(self, redis_config):
- self._redis_client = None
- self._redis_config = redis_config
+ def __init__(self, redis_config):
+ self._redis_client = None
+ self._redis_config = redis_config
- @property
- def _redis(self):
- if self._redis_client is not None:
- return self._redis_client
+ @property
+ def _redis(self):
+ if self._redis_client is not None:
+ return self._redis_client
- args = dict(self._redis_config)
- args.update({'socket_connect_timeout': 1,
- 'socket_timeout': 2,
- 'single_connection_client': True})
+ args = dict(self._redis_config)
+ args.update(
+ {
+ "socket_connect_timeout": 1,
+ "socket_timeout": 2,
+ "single_connection_client": True,
+ }
+ )
- self._redis_client = redis.StrictRedis(**args)
- return self._redis_client
+ self._redis_client = redis.StrictRedis(**args)
+ return self._redis_client
- @staticmethod
- def _logs_key(build_id):
- return 'builds/%s/logs' % build_id
+ @staticmethod
+ def _logs_key(build_id):
+ return "builds/%s/logs" % build_id
- def append_log_entry(self, build_id, log_obj):
- """
+ def append_log_entry(self, build_id, log_obj):
+ """
Appends the serialized form of log_obj to the end of the log entry list
and returns the new length of the list.
"""
- pipeline = self._redis.pipeline(transaction=False)
- pipeline.expire(self._logs_key(build_id), SEVEN_DAYS)
- pipeline.rpush(self._logs_key(build_id), json.dumps(log_obj))
- result = pipeline.execute()
- return result[1]
+ pipeline = self._redis.pipeline(transaction=False)
+ pipeline.expire(self._logs_key(build_id), SEVEN_DAYS)
+ pipeline.rpush(self._logs_key(build_id), json.dumps(log_obj))
+ result = pipeline.execute()
+ return result[1]
- def append_log_message(self, build_id, log_message, log_type=None, log_data=None):
- """
+ def append_log_message(self, build_id, log_message, log_type=None, log_data=None):
+ """
    Wraps the message in an envelope, pushes it to the end of the log entry
    list, and returns the index at which it was inserted.
"""
- log_obj = {
- 'message': log_message
- }
+ log_obj = {"message": log_message}
- if log_type:
- log_obj['type'] = log_type
+ if log_type:
+ log_obj["type"] = log_type
- if log_data:
- log_obj['data'] = log_data
+ if log_data:
+ log_obj["data"] = log_data
- return self.append_log_entry(build_id, log_obj) - 1
+ return self.append_log_entry(build_id, log_obj) - 1
- def get_log_entries(self, build_id, start_index):
- """
+ def get_log_entries(self, build_id, start_index):
+ """
Returns a tuple of the current length of the list and an iterable of the
requested log entries.
"""
- try:
- llen = self._redis.llen(self._logs_key(build_id))
- log_entries = self._redis.lrange(self._logs_key(build_id), start_index, -1)
- return (llen, (json.loads(entry) for entry in log_entries))
- except redis.RedisError as re:
- raise BuildStatusRetrievalError('Cannot retrieve build logs: %s' % re)
+ try:
+ llen = self._redis.llen(self._logs_key(build_id))
+ log_entries = self._redis.lrange(self._logs_key(build_id), start_index, -1)
+ return (llen, (json.loads(entry) for entry in log_entries))
+ except redis.RedisError as re:
+ raise BuildStatusRetrievalError("Cannot retrieve build logs: %s" % re)
- def expire_status(self, build_id):
- """
+ def expire_status(self, build_id):
+ """
Sets the status entry to expire in 1 day.
"""
- self._redis.expire(self._status_key(build_id), ONE_DAY)
+ self._redis.expire(self._status_key(build_id), ONE_DAY)
- def expire_log_entries(self, build_id):
- """
+ def expire_log_entries(self, build_id):
+ """
Sets the log entry to expire in 1 day.
"""
- self._redis.expire(self._logs_key(build_id), ONE_DAY)
+ self._redis.expire(self._logs_key(build_id), ONE_DAY)
- def delete_log_entries(self, build_id):
- """
+ def delete_log_entries(self, build_id):
+ """
    Deletes the log entries.
"""
- self._redis.delete(self._logs_key(build_id))
+ self._redis.delete(self._logs_key(build_id))
- @staticmethod
- def _status_key(build_id):
- return 'builds/%s/status' % build_id
+ @staticmethod
+ def _status_key(build_id):
+ return "builds/%s/status" % build_id
- def set_status(self, build_id, status_obj):
- """
+ def set_status(self, build_id, status_obj):
+ """
Sets the status key for this build to json serialized form of the supplied
obj.
"""
- self._redis.set(self._status_key(build_id), json.dumps(status_obj), ex=SEVEN_DAYS)
+ self._redis.set(
+ self._status_key(build_id), json.dumps(status_obj), ex=SEVEN_DAYS
+ )
- def get_status(self, build_id):
- """
+ def get_status(self, build_id):
+ """
Loads the status information for the specified build id.
"""
- try:
- fetched = self._redis.get(self._status_key(build_id))
- except redis.RedisError as re:
- raise BuildStatusRetrievalError('Cannot retrieve build status: %s' % re)
+ try:
+ fetched = self._redis.get(self._status_key(build_id))
+ except redis.RedisError as re:
+ raise BuildStatusRetrievalError("Cannot retrieve build status: %s" % re)
- return json.loads(fetched) if fetched else None
+ return json.loads(fetched) if fetched else None
- @staticmethod
- def _health_key():
- return '_health'
+ @staticmethod
+ def _health_key():
+ return "_health"
- def check_health(self):
- try:
- args = dict(self._redis_config)
- args.update({'socket_connect_timeout': 1,
- 'socket_timeout': 1,
- 'single_connection_client': True})
+ def check_health(self):
+ try:
+ args = dict(self._redis_config)
+ args.update(
+ {
+ "socket_connect_timeout": 1,
+ "socket_timeout": 1,
+ "single_connection_client": True,
+ }
+ )
- with closing(redis.StrictRedis(**args)) as connection:
- if not connection.ping():
- return (False, 'Could not ping redis')
+ with closing(redis.StrictRedis(**args)) as connection:
+ if not connection.ping():
+ return (False, "Could not ping redis")
- # Ensure we can write and read a key.
- connection.set(self._health_key(), time.time())
- connection.get(self._health_key())
- return (True, None)
- except redis.RedisError as re:
- return (False, 'Could not connect to redis: %s' % re.message)
+ # Ensure we can write and read a key.
+ connection.set(self._health_key(), time.time())
+ connection.get(self._health_key())
+ return (True, None)
+ except redis.RedisError as re:
+ return (False, "Could not connect to redis: %s" % re.message)
class BuildLogs(object):
- def __init__(self, app=None):
- self.app = app
- if app is not None:
- self.state = self.init_app(app)
- else:
- self.state = None
+ def __init__(self, app=None):
+ self.app = app
+ if app is not None:
+ self.state = self.init_app(app)
+ else:
+ self.state = None
- def init_app(self, app):
- buildlogs_config = app.config.get('BUILDLOGS_REDIS')
- if not buildlogs_config:
- # This is the old key name.
- buildlogs_config = {
- 'host': app.config.get('BUILDLOGS_REDIS_HOSTNAME')
- }
+ def init_app(self, app):
+ buildlogs_config = app.config.get("BUILDLOGS_REDIS")
+ if not buildlogs_config:
+ # This is the old key name.
+ buildlogs_config = {"host": app.config.get("BUILDLOGS_REDIS_HOSTNAME")}
- buildlogs_options = app.config.get('BUILDLOGS_OPTIONS', [])
- buildlogs_import = app.config.get('BUILDLOGS_MODULE_AND_CLASS', None)
+ buildlogs_options = app.config.get("BUILDLOGS_OPTIONS", [])
+ buildlogs_import = app.config.get("BUILDLOGS_MODULE_AND_CLASS", None)
- if buildlogs_import is None:
- klass = RedisBuildLogs
- else:
- klass = import_class(buildlogs_import[0], buildlogs_import[1])
+ if buildlogs_import is None:
+ klass = RedisBuildLogs
+ else:
+ klass = import_class(buildlogs_import[0], buildlogs_import[1])
- buildlogs = klass(buildlogs_config, *buildlogs_options)
+ buildlogs = klass(buildlogs_config, *buildlogs_options)
- # register extension with app
- app.extensions = getattr(app, 'extensions', {})
- app.extensions['buildlogs'] = buildlogs
- return buildlogs
+ # register extension with app
+ app.extensions = getattr(app, "extensions", {})
+ app.extensions["buildlogs"] = buildlogs
+ return buildlogs
- def __getattr__(self, name):
- return getattr(self.state, name, None)
+ def __getattr__(self, name):
+ return getattr(self.state, name, None)
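
append_log_message only wraps the message in a small envelope before append_log_entry RPUSHes the JSON-serialized dict onto builds/<id>/logs and refreshes the seven-day expiry. The envelope shape is easy to show without Redis; a minimal sketch (the sample message string is made up):

import json

def build_log_envelope(log_message, log_type=None, log_data=None):
    # Mirrors the optional fields handled by append_log_message above.
    log_obj = {"message": log_message}
    if log_type:
        log_obj["type"] = log_type
    if log_data:
        log_obj["data"] = log_data
    return log_obj

entry = build_log_envelope("Step 3/7 : RUN make", log_type="command")
print(json.dumps(entry))
# e.g. {"message": "Step 3/7 : RUN make", "type": "command"}
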
diff --git a/data/cache/__init__.py b/data/cache/__init__.py
index a7c44dadd..a604c63e3 100644
--- a/data/cache/__init__.py
+++ b/data/cache/__init__.py
@@ -1,23 +1,32 @@
-from data.cache.impl import NoopDataModelCache, InMemoryDataModelCache, MemcachedModelCache
+from data.cache.impl import (
+ NoopDataModelCache,
+ InMemoryDataModelCache,
+ MemcachedModelCache,
+)
+
def get_model_cache(config):
- """ Returns a data model cache matching the given configuration. """
- cache_config = config.get('DATA_MODEL_CACHE_CONFIG', {})
- engine = cache_config.get('engine', 'noop')
+ """ Returns a data model cache matching the given configuration. """
+ cache_config = config.get("DATA_MODEL_CACHE_CONFIG", {})
+ engine = cache_config.get("engine", "noop")
- if engine == 'noop':
- return NoopDataModelCache()
+ if engine == "noop":
+ return NoopDataModelCache()
- if engine == 'inmemory':
- return InMemoryDataModelCache()
+ if engine == "inmemory":
+ return InMemoryDataModelCache()
- if engine == 'memcached':
- endpoint = cache_config.get('endpoint', None)
- if endpoint is None:
- raise Exception('Missing `endpoint` for memcached model cache configuration')
+ if engine == "memcached":
+ endpoint = cache_config.get("endpoint", None)
+ if endpoint is None:
+ raise Exception(
+ "Missing `endpoint` for memcached model cache configuration"
+ )
- timeout = cache_config.get('timeout')
- connect_timeout = cache_config.get('connect_timeout')
- return MemcachedModelCache(endpoint, timeout=timeout, connect_timeout=connect_timeout)
+ timeout = cache_config.get("timeout")
+ connect_timeout = cache_config.get("connect_timeout")
+ return MemcachedModelCache(
+ endpoint, timeout=timeout, connect_timeout=connect_timeout
+ )
- raise Exception('Unknown model cache engine `%s`' % engine)
+ raise Exception("Unknown model cache engine `%s`" % engine)
diff --git a/data/cache/cache_key.py b/data/cache/cache_key.py
index 93aad65be..3c99d1d80 100644
--- a/data/cache/cache_key.py
+++ b/data/cache/cache_key.py
@@ -1,27 +1,33 @@
from collections import namedtuple
-class CacheKey(namedtuple('CacheKey', ['key', 'expiration'])):
- """ Defines a key into the data model cache. """
- pass
+
+class CacheKey(namedtuple("CacheKey", ["key", "expiration"])):
+ """ Defines a key into the data model cache. """
+
+ pass
def for_repository_blob(namespace_name, repo_name, digest, version):
- """ Returns a cache key for a blob in a repository. """
- return CacheKey('repo_blob__%s_%s_%s_%s' % (namespace_name, repo_name, digest, version), '60s')
+ """ Returns a cache key for a blob in a repository. """
+ return CacheKey(
+ "repo_blob__%s_%s_%s_%s" % (namespace_name, repo_name, digest, version), "60s"
+ )
def for_catalog_page(auth_context_key, start_id, limit):
- """ Returns a cache key for a single page of a catalog lookup for an authed context. """
- params = (auth_context_key or '(anon)', start_id or 0, limit or 0)
- return CacheKey('catalog_page__%s_%s_%s' % params, '60s')
+ """ Returns a cache key for a single page of a catalog lookup for an authed context. """
+ params = (auth_context_key or "(anon)", start_id or 0, limit or 0)
+ return CacheKey("catalog_page__%s_%s_%s" % params, "60s")
def for_namespace_geo_restrictions(namespace_name):
- """ Returns a cache key for the geo restrictions for a namespace. """
- return CacheKey('geo_restrictions__%s' % (namespace_name), '240s')
+ """ Returns a cache key for the geo restrictions for a namespace. """
+ return CacheKey("geo_restrictions__%s" % (namespace_name), "240s")
def for_active_repo_tags(repository_id, start_pagination_id, limit):
- """ Returns a cache key for the active tags in a repository. """
- return CacheKey('repo_active_tags__%s_%s_%s' % (repository_id, start_pagination_id, limit),
- '120s')
+ """ Returns a cache key for the active tags in a repository. """
+ return CacheKey(
+ "repo_active_tags__%s_%s_%s" % (repository_id, start_pagination_id, limit),
+ "120s",
+ )
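
The helpers above all return the same two-field CacheKey tuple; a standalone rendition of one of them shows the key/expiration shape (the namespace, repository and digest values are made up).

from collections import namedtuple

CacheKey = namedtuple("CacheKey", ["key", "expiration"])


def for_repository_blob(namespace_name, repo_name, digest, version):
    return CacheKey(
        "repo_blob__%s_%s_%s_%s" % (namespace_name, repo_name, digest, version), "60s"
    )


key = for_repository_blob("devtable", "simple", "sha256:abcd", 2)
print(key.key)         # repo_blob__devtable_simple_sha256:abcd_2
print(key.expiration)  # 60s
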
diff --git a/data/cache/impl.py b/data/cache/impl.py
index 982e950e9..c374f26e2 100644
--- a/data/cache/impl.py
+++ b/data/cache/impl.py
@@ -15,132 +15,185 @@ logger = logging.getLogger(__name__)
def is_not_none(value):
- return value is not None
+ return value is not None
@add_metaclass(ABCMeta)
class DataModelCache(object):
- """ Defines an interface for cache storing and returning tuple data model objects. """
+ """ Defines an interface for cache storing and returning tuple data model objects. """
- @abstractmethod
- def retrieve(self, cache_key, loader, should_cache=is_not_none):
- """ Checks the cache for the specified cache key and returns the value found (if any). If none
+ @abstractmethod
+ def retrieve(self, cache_key, loader, should_cache=is_not_none):
+ """ Checks the cache for the specified cache key and returns the value found (if any). If none
found, the loader is called to get a result and populate the cache.
"""
- pass
+ pass
class NoopDataModelCache(DataModelCache):
- """ Implementation of the data model cache which does nothing. """
+ """ Implementation of the data model cache which does nothing. """
- def retrieve(self, cache_key, loader, should_cache=is_not_none):
- return loader()
+ def retrieve(self, cache_key, loader, should_cache=is_not_none):
+ return loader()
class InMemoryDataModelCache(DataModelCache):
- """ Implementation of the data model cache backed by an in-memory dictionary. """
- def __init__(self):
- self.cache = ExpiresDict()
+ """ Implementation of the data model cache backed by an in-memory dictionary. """
- def retrieve(self, cache_key, loader, should_cache=is_not_none):
- not_found = [None]
- logger.debug('Checking cache for key %s', cache_key.key)
- result = self.cache.get(cache_key.key, default_value=not_found)
- if result != not_found:
- logger.debug('Found result in cache for key %s: %s', cache_key.key, result)
- return json.loads(result)
+ def __init__(self):
+ self.cache = ExpiresDict()
- logger.debug('Found no result in cache for key %s; calling loader', cache_key.key)
- result = loader()
- logger.debug('Got loaded result for key %s: %s', cache_key.key, result)
- if should_cache(result):
- logger.debug('Caching loaded result for key %s with expiration %s: %s', cache_key.key,
- result, cache_key.expiration)
- expires = convert_to_timedelta(cache_key.expiration) + datetime.now()
- self.cache.set(cache_key.key, json.dumps(result), expires=expires)
- logger.debug('Cached loaded result for key %s with expiration %s: %s', cache_key.key,
- result, cache_key.expiration)
- else:
- logger.debug('Not caching loaded result for key %s: %s', cache_key.key, result)
+ def retrieve(self, cache_key, loader, should_cache=is_not_none):
+ not_found = [None]
+ logger.debug("Checking cache for key %s", cache_key.key)
+ result = self.cache.get(cache_key.key, default_value=not_found)
+ if result != not_found:
+ logger.debug("Found result in cache for key %s: %s", cache_key.key, result)
+ return json.loads(result)
- return result
+ logger.debug(
+ "Found no result in cache for key %s; calling loader", cache_key.key
+ )
+ result = loader()
+ logger.debug("Got loaded result for key %s: %s", cache_key.key, result)
+ if should_cache(result):
+ logger.debug(
+ "Caching loaded result for key %s with expiration %s: %s",
+ cache_key.key,
+ result,
+ cache_key.expiration,
+ )
+ expires = convert_to_timedelta(cache_key.expiration) + datetime.now()
+ self.cache.set(cache_key.key, json.dumps(result), expires=expires)
+ logger.debug(
+ "Cached loaded result for key %s with expiration %s: %s",
+ cache_key.key,
+ result,
+ cache_key.expiration,
+ )
+ else:
+ logger.debug(
+ "Not caching loaded result for key %s: %s", cache_key.key, result
+ )
+
+ return result
-_DEFAULT_MEMCACHE_TIMEOUT = 1 # second
-_DEFAULT_MEMCACHE_CONNECT_TIMEOUT = 1 # second
+_DEFAULT_MEMCACHE_TIMEOUT = 1 # second
+_DEFAULT_MEMCACHE_CONNECT_TIMEOUT = 1 # second
_STRING_TYPE = 1
_JSON_TYPE = 2
+
class MemcachedModelCache(DataModelCache):
- """ Implementation of the data model cache backed by a memcached. """
- def __init__(self, endpoint, timeout=_DEFAULT_MEMCACHE_TIMEOUT,
- connect_timeout=_DEFAULT_MEMCACHE_CONNECT_TIMEOUT):
- self.endpoint = endpoint
- self.timeout = timeout
- self.connect_timeout = connect_timeout
- self.client = None
+ """ Implementation of the data model cache backed by a memcached. """
- def _get_client(self):
- client = self.client
- if client is not None:
- return client
+ def __init__(
+ self,
+ endpoint,
+ timeout=_DEFAULT_MEMCACHE_TIMEOUT,
+ connect_timeout=_DEFAULT_MEMCACHE_CONNECT_TIMEOUT,
+ ):
+ self.endpoint = endpoint
+ self.timeout = timeout
+ self.connect_timeout = connect_timeout
+ self.client = None
- try:
- # Copied from the doc comment for Client.
- def serialize_json(key, value):
- if type(value) == str:
- return value, _STRING_TYPE
+ def _get_client(self):
+ client = self.client
+ if client is not None:
+ return client
- return json.dumps(value), _JSON_TYPE
+ try:
+ # Copied from the doc comment for Client.
+ def serialize_json(key, value):
+ if type(value) == str:
+ return value, _STRING_TYPE
- def deserialize_json(key, value, flags):
- if flags == _STRING_TYPE:
- return value
+ return json.dumps(value), _JSON_TYPE
- if flags == _JSON_TYPE:
- return json.loads(value)
+ def deserialize_json(key, value, flags):
+ if flags == _STRING_TYPE:
+ return value
- raise Exception("Unknown flags for value: {1}".format(flags))
+ if flags == _JSON_TYPE:
+ return json.loads(value)
- self.client = Client(self.endpoint, no_delay=True, timeout=self.timeout,
- connect_timeout=self.connect_timeout,
- key_prefix='data_model_cache__',
- serializer=serialize_json,
- deserializer=deserialize_json,
- ignore_exc=True)
- return self.client
- except:
- logger.exception('Got exception when creating memcached client to %s', self.endpoint)
- return None
+ raise Exception("Unknown flags for value: {1}".format(flags))
- def retrieve(self, cache_key, loader, should_cache=is_not_none):
- not_found = [None]
- client = self._get_client()
- if client is not None:
- logger.debug('Checking cache for key %s', cache_key.key)
- try:
- result = client.get(cache_key.key, default=not_found)
- if result != not_found:
- logger.debug('Found result in cache for key %s: %s', cache_key.key, result)
- return result
- except:
- logger.exception('Got exception when trying to retrieve key %s', cache_key.key)
+ self.client = Client(
+ self.endpoint,
+ no_delay=True,
+ timeout=self.timeout,
+ connect_timeout=self.connect_timeout,
+ key_prefix="data_model_cache__",
+ serializer=serialize_json,
+ deserializer=deserialize_json,
+ ignore_exc=True,
+ )
+ return self.client
+ except:
+ logger.exception(
+ "Got exception when creating memcached client to %s", self.endpoint
+ )
+ return None
- logger.debug('Found no result in cache for key %s; calling loader', cache_key.key)
- result = loader()
- logger.debug('Got loaded result for key %s: %s', cache_key.key, result)
- if client is not None and should_cache(result):
- try:
- logger.debug('Caching loaded result for key %s with expiration %s: %s', cache_key.key,
- result, cache_key.expiration)
- expires = convert_to_timedelta(cache_key.expiration) if cache_key.expiration else None
- client.set(cache_key.key, result, expire=int(expires.total_seconds()) if expires else None)
- logger.debug('Cached loaded result for key %s with expiration %s: %s', cache_key.key,
- result, cache_key.expiration)
- except:
- logger.exception('Got exception when trying to set key %s to %s', cache_key.key, result)
- else:
- logger.debug('Not caching loaded result for key %s: %s', cache_key.key, result)
+ def retrieve(self, cache_key, loader, should_cache=is_not_none):
+ not_found = [None]
+ client = self._get_client()
+ if client is not None:
+ logger.debug("Checking cache for key %s", cache_key.key)
+ try:
+ result = client.get(cache_key.key, default=not_found)
+ if result != not_found:
+ logger.debug(
+ "Found result in cache for key %s: %s", cache_key.key, result
+ )
+ return result
+ except:
+ logger.exception(
+ "Got exception when trying to retrieve key %s", cache_key.key
+ )
- return result
+ logger.debug(
+ "Found no result in cache for key %s; calling loader", cache_key.key
+ )
+ result = loader()
+ logger.debug("Got loaded result for key %s: %s", cache_key.key, result)
+ if client is not None and should_cache(result):
+ try:
+ logger.debug(
+ "Caching loaded result for key %s with expiration %s: %s",
+ cache_key.key,
+ result,
+ cache_key.expiration,
+ )
+ expires = (
+ convert_to_timedelta(cache_key.expiration)
+ if cache_key.expiration
+ else None
+ )
+ client.set(
+ cache_key.key,
+ result,
+ expire=int(expires.total_seconds()) if expires else None,
+ )
+ logger.debug(
+ "Cached loaded result for key %s with expiration %s: %s",
+ cache_key.key,
+ result,
+ cache_key.expiration,
+ )
+ except:
+ logger.exception(
+ "Got exception when trying to set key %s to %s",
+ cache_key.key,
+ result,
+ )
+ else:
+ logger.debug(
+ "Not caching loaded result for key %s: %s", cache_key.key, result
+ )
+
+ return result
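
All three implementations above honor the same retrieve(cache_key, loader, should_cache) contract: serve a hit if present, otherwise call the loader and cache the result only when should_cache approves it. Below is a dependency-free sketch of that contract which ignores expiration for brevity; TinyCache is illustrative, not a Quay class.

import json
from collections import namedtuple

CacheKey = namedtuple("CacheKey", ["key", "expiration"])


def is_not_none(value):
    return value is not None


class TinyCache(object):
    def __init__(self):
        self._data = {}

    def retrieve(self, cache_key, loader, should_cache=is_not_none):
        if cache_key.key in self._data:
            return json.loads(self._data[cache_key.key])

        result = loader()
        if should_cache(result):
            self._data[cache_key.key] = json.dumps(result)
        return result


cache = TinyCache()
key = CacheKey("catalog_page__(anon)_0_50", "60s")
assert cache.retrieve(key, lambda: {"a": 1234}) == {"a": 1234}
assert cache.retrieve(key, lambda: {"a": 9999}) == {"a": 1234}  # served from the cache
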
diff --git a/data/cache/test/test_cache.py b/data/cache/test/test_cache.py
index bf0c4cccd..380765ce0 100644
--- a/data/cache/test/test_cache.py
+++ b/data/cache/test/test_cache.py
@@ -5,52 +5,50 @@ from mock import patch
from data.cache import InMemoryDataModelCache, NoopDataModelCache, MemcachedModelCache
from data.cache.cache_key import CacheKey
+
class MockClient(object):
- def __init__(self, server, **kwargs):
- self.data = {}
+ def __init__(self, server, **kwargs):
+ self.data = {}
- def get(self, key, default=None):
- return self.data.get(key, default)
+ def get(self, key, default=None):
+ return self.data.get(key, default)
- def set(self, key, value, expire=None):
- self.data[key] = value
+ def set(self, key, value, expire=None):
+ self.data[key] = value
-@pytest.mark.parametrize('cache_type', [
- (NoopDataModelCache),
- (InMemoryDataModelCache),
-])
+@pytest.mark.parametrize("cache_type", [(NoopDataModelCache), (InMemoryDataModelCache)])
def test_caching(cache_type):
- key = CacheKey('foo', '60m')
- cache = cache_type()
+ key = CacheKey("foo", "60m")
+ cache = cache_type()
- # Perform two retrievals, and make sure both return.
- assert cache.retrieve(key, lambda: {'a': 1234}) == {'a': 1234}
- assert cache.retrieve(key, lambda: {'a': 1234}) == {'a': 1234}
+ # Perform two retrievals, and make sure both return.
+ assert cache.retrieve(key, lambda: {"a": 1234}) == {"a": 1234}
+ assert cache.retrieve(key, lambda: {"a": 1234}) == {"a": 1234}
def test_memcache():
- key = CacheKey('foo', '60m')
- with patch('data.cache.impl.Client', MockClient):
- cache = MemcachedModelCache(('127.0.0.1', '-1'))
- assert cache.retrieve(key, lambda: {'a': 1234}) == {'a': 1234}
- assert cache.retrieve(key, lambda: {'a': 1234}) == {'a': 1234}
+ key = CacheKey("foo", "60m")
+ with patch("data.cache.impl.Client", MockClient):
+ cache = MemcachedModelCache(("127.0.0.1", "-1"))
+ assert cache.retrieve(key, lambda: {"a": 1234}) == {"a": 1234}
+ assert cache.retrieve(key, lambda: {"a": 1234}) == {"a": 1234}
def test_memcache_should_cache():
- key = CacheKey('foo', None)
+ key = CacheKey("foo", None)
- def sc(value):
- return value['a'] != 1234
+ def sc(value):
+ return value["a"] != 1234
- with patch('data.cache.impl.Client', MockClient):
- cache = MemcachedModelCache(('127.0.0.1', '-1'))
- assert cache.retrieve(key, lambda: {'a': 1234}, should_cache=sc) == {'a': 1234}
+ with patch("data.cache.impl.Client", MockClient):
+ cache = MemcachedModelCache(("127.0.0.1", "-1"))
+ assert cache.retrieve(key, lambda: {"a": 1234}, should_cache=sc) == {"a": 1234}
- # Ensure not cached since it was `1234`.
- assert cache._get_client().get(key.key) is None
+ # Ensure not cached since it was `1234`.
+ assert cache._get_client().get(key.key) is None
- # Ensure cached.
- assert cache.retrieve(key, lambda: {'a': 2345}, should_cache=sc) == {'a': 2345}
- assert cache._get_client().get(key.key) is not None
- assert cache.retrieve(key, lambda: {'a': 2345}, should_cache=sc) == {'a': 2345}
+ # Ensure cached.
+ assert cache.retrieve(key, lambda: {"a": 2345}, should_cache=sc) == {"a": 2345}
+ assert cache._get_client().get(key.key) is not None
+ assert cache.retrieve(key, lambda: {"a": 2345}, should_cache=sc) == {"a": 2345}
diff --git a/data/database.py b/data/database.py
index 62c59e6e0..2a0a90f6b 100644
--- a/data/database.py
+++ b/data/database.py
@@ -18,7 +18,11 @@ import toposort
from enum import IntEnum, Enum, unique
from peewee import *
from peewee import __exception_wrapper__, Function
-from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase, PooledSqliteDatabase
+from playhouse.pool import (
+ PooledMySQLDatabase,
+ PooledPostgresqlDatabase,
+ PooledSqliteDatabase,
+)
from sqlalchemy.engine.url import make_url
@@ -26,9 +30,18 @@ import resumablehashlib
from cachetools.func import lru_cache
from active_migration import ERTMigrationFlags, ActiveDataMigration
-from data.fields import (ResumableSHA256Field, ResumableSHA1Field, JSONField, Base64BinaryField,
- FullIndexedTextField, FullIndexedCharField, EnumField as ClientEnumField,
- EncryptedTextField, EncryptedCharField, CredentialField)
+from data.fields import (
+ ResumableSHA256Field,
+ ResumableSHA1Field,
+ JSONField,
+ Base64BinaryField,
+ FullIndexedTextField,
+ FullIndexedCharField,
+ EnumField as ClientEnumField,
+ EncryptedTextField,
+ EncryptedCharField,
+ CredentialField,
+)
from data.text import match_mysql, match_like
from data.encryption import FieldEncrypter
from data.readreplica import ReadReplicaSupportedModel, ReadOnlyConfig
@@ -38,86 +51,86 @@ from util.validation import validate_postgres_precondition
logger = logging.getLogger(__name__)
-DEFAULT_DB_CONNECT_TIMEOUT = 10 # seconds
+DEFAULT_DB_CONNECT_TIMEOUT = 10 # seconds
# IMAGE_NOT_SCANNED_ENGINE_VERSION is the version found in security_indexed_engine when the
# image has not yet been scanned.
IMAGE_NOT_SCANNED_ENGINE_VERSION = -1
-schemedriver = namedtuple('schemedriver', ['driver', 'pooled_driver'])
+schemedriver = namedtuple("schemedriver", ["driver", "pooled_driver"])
_SCHEME_DRIVERS = {
- 'mysql': schemedriver(MySQLDatabase, PooledMySQLDatabase),
- 'mysql+pymysql': schemedriver(MySQLDatabase, PooledMySQLDatabase),
- 'sqlite': schemedriver(SqliteDatabase, PooledSqliteDatabase),
- 'postgresql': schemedriver(PostgresqlDatabase, PooledPostgresqlDatabase),
- 'postgresql+psycopg2': schemedriver(PostgresqlDatabase, PooledPostgresqlDatabase),
+ "mysql": schemedriver(MySQLDatabase, PooledMySQLDatabase),
+ "mysql+pymysql": schemedriver(MySQLDatabase, PooledMySQLDatabase),
+ "sqlite": schemedriver(SqliteDatabase, PooledSqliteDatabase),
+ "postgresql": schemedriver(PostgresqlDatabase, PooledPostgresqlDatabase),
+ "postgresql+psycopg2": schemedriver(PostgresqlDatabase, PooledPostgresqlDatabase),
}
SCHEME_MATCH_FUNCTION = {
- 'mysql': match_mysql,
- 'mysql+pymysql': match_mysql,
- 'sqlite': match_like,
- 'postgresql': match_like,
- 'postgresql+psycopg2': match_like,
+ "mysql": match_mysql,
+ "mysql+pymysql": match_mysql,
+ "sqlite": match_like,
+ "postgresql": match_like,
+ "postgresql+psycopg2": match_like,
}
SCHEME_RANDOM_FUNCTION = {
- 'mysql': fn.Rand,
- 'mysql+pymysql': fn.Rand,
- 'sqlite': fn.Random,
- 'postgresql': fn.Random,
- 'postgresql+psycopg2': fn.Random,
+ "mysql": fn.Rand,
+ "mysql+pymysql": fn.Rand,
+ "sqlite": fn.Random,
+ "postgresql": fn.Random,
+ "postgresql+psycopg2": fn.Random,
}
PRECONDITION_VALIDATION = {
- 'postgresql': validate_postgres_precondition,
- 'postgresql+psycopg2': validate_postgres_precondition,
+ "postgresql": validate_postgres_precondition,
+ "postgresql+psycopg2": validate_postgres_precondition,
}
_EXTRA_ARGS = {
- 'mysql': dict(charset='utf8mb4'),
- 'mysql+pymysql': dict(charset='utf8mb4'),
+ "mysql": dict(charset="utf8mb4"),
+ "mysql+pymysql": dict(charset="utf8mb4"),
}
def pipes_concat(arg1, arg2, *extra_args):
- """ Concat function for sqlite, since it doesn't support fn.Concat.
+ """ Concat function for sqlite, since it doesn't support fn.Concat.
Concatenates clauses with || characters.
"""
- reduced = arg1.concat(arg2)
- for arg in extra_args:
- reduced = reduced.concat(arg)
- return reduced
+ reduced = arg1.concat(arg2)
+ for arg in extra_args:
+ reduced = reduced.concat(arg)
+ return reduced
def function_concat(arg1, arg2, *extra_args):
- """ Default implementation of concat which uses fn.Concat(). Used by all
+ """ Default implementation of concat which uses fn.Concat(). Used by all
database engines except sqlite.
"""
- return fn.Concat(arg1, arg2, *extra_args)
+ return fn.Concat(arg1, arg2, *extra_args)
-SCHEME_SPECIALIZED_CONCAT = {
- 'sqlite': pipes_concat,
-}
+SCHEME_SPECIALIZED_CONCAT = {"sqlite": pipes_concat}
def real_for_update(query):
- return query.for_update()
+ return query.for_update()
def null_for_update(query):
- return query
+ return query
-def delete_instance_filtered(instance, model_class, delete_nullable, skip_transitive_deletes):
- """ Deletes the DB instance recursively, skipping any models in the skip_transitive_deletes set.
+def delete_instance_filtered(
+ instance, model_class, delete_nullable, skip_transitive_deletes
+):
+ """ Deletes the DB instance recursively, skipping any models in the skip_transitive_deletes set.
Callers *must* ensure that any models listed in the skip_transitive_deletes must be capable
of being directly deleted when the instance is deleted (with automatic sorting handling
@@ -127,143 +140,147 @@ def delete_instance_filtered(instance, model_class, delete_nullable, skip_transi
*same* repository when RepositoryTag references Image, so we can safely skip
transitive deletion for the RepositoryTag table.
"""
- # We need to sort the ops so that models get cleaned in order of their dependencies
- ops = reversed(list(instance.dependencies(delete_nullable)))
- filtered_ops = []
+ # We need to sort the ops so that models get cleaned in order of their dependencies
+ ops = reversed(list(instance.dependencies(delete_nullable)))
+ filtered_ops = []
- dependencies = defaultdict(set)
+ dependencies = defaultdict(set)
- for query, fk in ops:
- # We only want to skip transitive deletes, which are done using subqueries in the form of
- # DELETE FROM <tablename> in <subquery>. If an op is not using a subquery, we allow it to be
- # applied directly.
- if fk.model not in skip_transitive_deletes or query.op.lower() != 'in':
- filtered_ops.append((query, fk))
+ for query, fk in ops:
+ # We only want to skip transitive deletes, which are done using subqueries in the form of
+ # DELETE FROM <tablename> in <subquery>. If an op is not using a subquery, we allow it to be
+ # applied directly.
+ if fk.model not in skip_transitive_deletes or query.op.lower() != "in":
+ filtered_ops.append((query, fk))
- if query.op.lower() == 'in':
- dependencies[fk.model.__name__].add(query.rhs.model.__name__)
- elif query.op == '=':
- dependencies[fk.model.__name__].add(model_class.__name__)
- else:
- raise RuntimeError('Unknown operator in recursive repository delete query')
+ if query.op.lower() == "in":
+ dependencies[fk.model.__name__].add(query.rhs.model.__name__)
+ elif query.op == "=":
+ dependencies[fk.model.__name__].add(model_class.__name__)
+ else:
+ raise RuntimeError("Unknown operator in recursive repository delete query")
- sorted_models = list(reversed(toposort.toposort_flatten(dependencies)))
- def sorted_model_key(query_fk_tuple):
- cmp_query, cmp_fk = query_fk_tuple
- if cmp_query.op.lower() == 'in':
- return -1
- return sorted_models.index(cmp_fk.model.__name__)
- filtered_ops.sort(key=sorted_model_key)
+ sorted_models = list(reversed(toposort.toposort_flatten(dependencies)))
- with db_transaction():
- for query, fk in filtered_ops:
- _model = fk.model
- if fk.null and not delete_nullable:
- _model.update(**{fk.name: None}).where(query).execute()
- else:
- _model.delete().where(query).execute()
+ def sorted_model_key(query_fk_tuple):
+ cmp_query, cmp_fk = query_fk_tuple
+ if cmp_query.op.lower() == "in":
+ return -1
+ return sorted_models.index(cmp_fk.model.__name__)
- return instance.delete().where(instance._pk_expr()).execute()
+ filtered_ops.sort(key=sorted_model_key)
+
+ with db_transaction():
+ for query, fk in filtered_ops:
+ _model = fk.model
+ if fk.null and not delete_nullable:
+ _model.update(**{fk.name: None}).where(query).execute()
+ else:
+ _model.delete().where(query).execute()
+
+ return instance.delete().where(instance._pk_expr()).execute()
-SCHEME_SPECIALIZED_FOR_UPDATE = {
- 'sqlite': null_for_update,
-}
+SCHEME_SPECIALIZED_FOR_UPDATE = {"sqlite": null_for_update}
class CallableProxy(Proxy):
- def __call__(self, *args, **kwargs):
- if self.obj is None:
- raise AttributeError('Cannot use uninitialized Proxy.')
- return self.obj(*args, **kwargs)
+ def __call__(self, *args, **kwargs):
+ if self.obj is None:
+ raise AttributeError("Cannot use uninitialized Proxy.")
+ return self.obj(*args, **kwargs)
class RetryOperationalError(object):
+ def execute_sql(self, sql, params=None, commit=True):
+ try:
+ cursor = super(RetryOperationalError, self).execute_sql(sql, params, commit)
+ except OperationalError:
+ if not self.is_closed():
+ self.close()
- def execute_sql(self, sql, params=None, commit=True):
- try:
- cursor = super(RetryOperationalError, self).execute_sql(sql, params, commit)
- except OperationalError:
- if not self.is_closed():
- self.close()
+ with __exception_wrapper__:
+ cursor = self.cursor()
+ cursor.execute(sql, params or ())
+ if commit and not self.in_transaction():
+ self.commit()
- with __exception_wrapper__:
- cursor = self.cursor()
- cursor.execute(sql, params or ())
- if commit and not self.in_transaction():
- self.commit()
-
- return cursor
+ return cursor
class CloseForLongOperation(object):
- """ Helper object which disconnects the database then reconnects after the nested operation
+ """ Helper object which disconnects the database then reconnects after the nested operation
completes.
"""
- def __init__(self, config_object):
- self.config_object = config_object
+ def __init__(self, config_object):
+ self.config_object = config_object
- def __enter__(self):
- if self.config_object.get('TESTING') is True:
- return
+ def __enter__(self):
+ if self.config_object.get("TESTING") is True:
+ return
- close_db_filter(None)
+ close_db_filter(None)
- def __exit__(self, typ, value, traceback):
- # Note: Nothing to do. The next SQL call will reconnect automatically.
- pass
+ def __exit__(self, typ, value, traceback):
+ # Note: Nothing to do. The next SQL call will reconnect automatically.
+ pass
class UseThenDisconnect(object):
- """ Helper object for conducting work with a database and then tearing it down. """
+ """ Helper object for conducting work with a database and then tearing it down. """
- def __init__(self, config_object):
- self.config_object = config_object
+ def __init__(self, config_object):
+ self.config_object = config_object
- def __enter__(self):
- pass
+ def __enter__(self):
+ pass
- def __exit__(self, typ, value, traceback):
- if self.config_object.get('TESTING') is True:
- return
+ def __exit__(self, typ, value, traceback):
+ if self.config_object.get("TESTING") is True:
+ return
- close_db_filter(None)
+ close_db_filter(None)
class TupleSelector(object):
- """ Helper class for selecting tuples from a peewee query and easily accessing
+ """ Helper class for selecting tuples from a peewee query and easily accessing
them as if they were objects.
"""
- class _TupleWrapper(object):
- def __init__(self, data, fields):
- self._data = data
- self._fields = fields
- def get(self, field):
- return self._data[self._fields.index(TupleSelector.tuple_reference_key(field))]
+ class _TupleWrapper(object):
+ def __init__(self, data, fields):
+ self._data = data
+ self._fields = fields
- @classmethod
- def tuple_reference_key(cls, field):
- """ Returns a string key for referencing a field in a TupleSelector. """
- if isinstance(field, Function):
- return field.name + ','.join([cls.tuple_reference_key(arg) for arg in field.arguments])
+ def get(self, field):
+ return self._data[
+ self._fields.index(TupleSelector.tuple_reference_key(field))
+ ]
- if isinstance(field, Field):
- return field.name + ':' + field.model.__name__
+ @classmethod
+ def tuple_reference_key(cls, field):
+ """ Returns a string key for referencing a field in a TupleSelector. """
+ if isinstance(field, Function):
+ return field.name + ",".join(
+ [cls.tuple_reference_key(arg) for arg in field.arguments]
+ )
- raise Exception('Unknown field type %s in TupleSelector' % field._node_type)
+ if isinstance(field, Field):
+ return field.name + ":" + field.model.__name__
- def __init__(self, query, fields):
- self._query = query.select(*fields).tuples()
- self._fields = [TupleSelector.tuple_reference_key(field) for field in fields]
+ raise Exception("Unknown field type %s in TupleSelector" % field._node_type)
- def __iter__(self):
- return self._build_iterator()
+ def __init__(self, query, fields):
+ self._query = query.select(*fields).tuples()
+ self._fields = [TupleSelector.tuple_reference_key(field) for field in fields]
- def _build_iterator(self):
- for tuple_data in self._query:
- yield TupleSelector._TupleWrapper(tuple_data, self._fields)
+ def __iter__(self):
+ return self._build_iterator()
+
+ def _build_iterator(self):
+ for tuple_data in self._query:
+ yield TupleSelector._TupleWrapper(tuple_data, self._fields)
db = Proxy()
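
delete_instance_filtered() above builds a model-dependency mapping and relies on toposort to order the deletes so that dependents are removed before the rows they reference. A small sketch of that ordering step, with made-up model names:

import toposort

# Each entry maps a model to the models it references via foreign keys.
dependencies = {
    "RepositoryTag": {"Repository"},
    "RepositoryPermission": {"Repository", "User"},
    "Repository": {"User"},
}

# toposort_flatten() lists referenced models before the models that point at them ...
order = toposort.toposort_flatten(dependencies)
print(order)  # ['User', 'Repository', 'RepositoryPermission', 'RepositoryTag']

# ... so the code above reverses it to get a safe delete order: dependents first.
print(list(reversed(order)))
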
@@ -278,141 +295,170 @@ ensure_under_transaction = CallableProxy()
def validate_database_url(url, db_kwargs, connect_timeout=5):
- """ Validates that we can connect to the given database URL, with the given kwargs. Raises
+ """ Validates that we can connect to the given database URL, with the given kwargs. Raises
an exception if the validation fails. """
- db_kwargs = db_kwargs.copy()
+ db_kwargs = db_kwargs.copy()
- try:
- driver = _db_from_url(url, db_kwargs, connect_timeout=connect_timeout, allow_retry=False,
- allow_pooling=False)
- driver.connect()
- finally:
try:
- driver.close()
- except:
- pass
+ driver = _db_from_url(
+ url,
+ db_kwargs,
+ connect_timeout=connect_timeout,
+ allow_retry=False,
+ allow_pooling=False,
+ )
+ driver.connect()
+ finally:
+ try:
+ driver.close()
+ except:
+ pass
def validate_database_precondition(url, db_kwargs, connect_timeout=5):
- """ Validates that we can connect to the given database URL and the database meets our
+ """ Validates that we can connect to the given database URL and the database meets our
precondition. Raises an exception if the validation fails. """
- db_kwargs = db_kwargs.copy()
- try:
- driver = _db_from_url(url, db_kwargs, connect_timeout=connect_timeout, allow_retry=False,
- allow_pooling=False)
- driver.connect()
- pre_condition_check = PRECONDITION_VALIDATION.get(make_url(url).drivername)
- if pre_condition_check:
- pre_condition_check(driver)
-
- finally:
+ db_kwargs = db_kwargs.copy()
try:
- driver.close()
- except:
- pass
+ driver = _db_from_url(
+ url,
+ db_kwargs,
+ connect_timeout=connect_timeout,
+ allow_retry=False,
+ allow_pooling=False,
+ )
+ driver.connect()
+ pre_condition_check = PRECONDITION_VALIDATION.get(make_url(url).drivername)
+ if pre_condition_check:
+ pre_condition_check(driver)
+
+ finally:
+ try:
+ driver.close()
+ except:
+ pass
def _wrap_for_retry(driver):
- return type('Retrying' + driver.__name__, (RetryOperationalError, driver), {})
+ return type("Retrying" + driver.__name__, (RetryOperationalError, driver), {})
-def _db_from_url(url, db_kwargs, connect_timeout=DEFAULT_DB_CONNECT_TIMEOUT,
- allow_pooling=True, allow_retry=True):
- parsed_url = make_url(url)
+def _db_from_url(
+ url,
+ db_kwargs,
+ connect_timeout=DEFAULT_DB_CONNECT_TIMEOUT,
+ allow_pooling=True,
+ allow_retry=True,
+):
+ parsed_url = make_url(url)
- if parsed_url.host:
- db_kwargs['host'] = parsed_url.host
- if parsed_url.port:
- db_kwargs['port'] = parsed_url.port
- if parsed_url.username:
- db_kwargs['user'] = parsed_url.username
- if parsed_url.password:
- db_kwargs['password'] = parsed_url.password
+ if parsed_url.host:
+ db_kwargs["host"] = parsed_url.host
+ if parsed_url.port:
+ db_kwargs["port"] = parsed_url.port
+ if parsed_url.username:
+ db_kwargs["user"] = parsed_url.username
+ if parsed_url.password:
+ db_kwargs["password"] = parsed_url.password
- # Remove threadlocals. It used to be required.
- db_kwargs.pop('threadlocals', None)
+ # Remove threadlocals. It used to be required.
+ db_kwargs.pop("threadlocals", None)
- # Note: sqlite does not support connect_timeout.
- if parsed_url.drivername != 'sqlite':
- db_kwargs['connect_timeout'] = db_kwargs.get('connect_timeout', connect_timeout)
+ # Note: sqlite does not support connect_timeout.
+ if parsed_url.drivername != "sqlite":
+ db_kwargs["connect_timeout"] = db_kwargs.get("connect_timeout", connect_timeout)
- drivers = _SCHEME_DRIVERS[parsed_url.drivername]
- driver = drivers.driver
- if allow_pooling and os.getenv('DB_CONNECTION_POOLING', 'false').lower() == 'true':
- driver = drivers.pooled_driver
- db_kwargs['stale_timeout'] = db_kwargs.get('stale_timeout', None)
- db_kwargs['max_connections'] = db_kwargs.get('max_connections', None)
- logger.info('Connection pooling enabled for %s; stale timeout: %s; max connection count: %s',
- parsed_url.drivername, db_kwargs['stale_timeout'], db_kwargs['max_connections'])
- else:
- logger.info('Connection pooling disabled for %s', parsed_url.drivername)
- db_kwargs.pop('stale_timeout', None)
- db_kwargs.pop('max_connections', None)
+ drivers = _SCHEME_DRIVERS[parsed_url.drivername]
+ driver = drivers.driver
+ if allow_pooling and os.getenv("DB_CONNECTION_POOLING", "false").lower() == "true":
+ driver = drivers.pooled_driver
+ db_kwargs["stale_timeout"] = db_kwargs.get("stale_timeout", None)
+ db_kwargs["max_connections"] = db_kwargs.get("max_connections", None)
+ logger.info(
+ "Connection pooling enabled for %s; stale timeout: %s; max connection count: %s",
+ parsed_url.drivername,
+ db_kwargs["stale_timeout"],
+ db_kwargs["max_connections"],
+ )
+ else:
+ logger.info("Connection pooling disabled for %s", parsed_url.drivername)
+ db_kwargs.pop("stale_timeout", None)
+ db_kwargs.pop("max_connections", None)
- for key, value in _EXTRA_ARGS.get(parsed_url.drivername, {}).iteritems():
- if key not in db_kwargs:
- db_kwargs[key] = value
+ for key, value in _EXTRA_ARGS.get(parsed_url.drivername, {}).iteritems():
+ if key not in db_kwargs:
+ db_kwargs[key] = value
- if allow_retry:
- driver = _wrap_for_retry(driver)
+ if allow_retry:
+ driver = _wrap_for_retry(driver)
- created = driver(parsed_url.database, **db_kwargs)
-
- # Revert the behavior "fixed" in:
- # https://github.com/coleifer/peewee/commit/36bd887ac07647c60dfebe610b34efabec675706
- if parsed_url.drivername.find("mysql") >= 0:
- created.compound_select_parentheses = 0
- return created
+ created = driver(parsed_url.database, **db_kwargs)
+
+ # Revert the behavior "fixed" in:
+ # https://github.com/coleifer/peewee/commit/36bd887ac07647c60dfebe610b34efabec675706
+ if parsed_url.drivername.find("mysql") >= 0:
+ created.compound_select_parentheses = 0
+ return created
def configure(config_object, testing=False):
- logger.debug('Configuring database')
- db_kwargs = dict(config_object['DB_CONNECTION_ARGS'])
- write_db_uri = config_object['DB_URI']
- db.initialize(_db_from_url(write_db_uri, db_kwargs))
+ logger.debug("Configuring database")
+ db_kwargs = dict(config_object["DB_CONNECTION_ARGS"])
+ write_db_uri = config_object["DB_URI"]
+ db.initialize(_db_from_url(write_db_uri, db_kwargs))
- parsed_write_uri = make_url(write_db_uri)
- db_random_func.initialize(SCHEME_RANDOM_FUNCTION[parsed_write_uri.drivername])
- db_match_func.initialize(SCHEME_MATCH_FUNCTION[parsed_write_uri.drivername])
- db_for_update.initialize(SCHEME_SPECIALIZED_FOR_UPDATE.get(parsed_write_uri.drivername,
- real_for_update))
- db_concat_func.initialize(SCHEME_SPECIALIZED_CONCAT.get(parsed_write_uri.drivername,
- function_concat))
- db_encrypter.initialize(FieldEncrypter(config_object.get('DATABASE_SECRET_KEY')))
+ parsed_write_uri = make_url(write_db_uri)
+ db_random_func.initialize(SCHEME_RANDOM_FUNCTION[parsed_write_uri.drivername])
+ db_match_func.initialize(SCHEME_MATCH_FUNCTION[parsed_write_uri.drivername])
+ db_for_update.initialize(
+ SCHEME_SPECIALIZED_FOR_UPDATE.get(parsed_write_uri.drivername, real_for_update)
+ )
+ db_concat_func.initialize(
+ SCHEME_SPECIALIZED_CONCAT.get(parsed_write_uri.drivername, function_concat)
+ )
+ db_encrypter.initialize(FieldEncrypter(config_object.get("DATABASE_SECRET_KEY")))
- read_replicas = config_object.get('DB_READ_REPLICAS', None)
- is_read_only = config_object.get('REGISTRY_STATE', 'normal') == 'readonly'
+ read_replicas = config_object.get("DB_READ_REPLICAS", None)
+ is_read_only = config_object.get("REGISTRY_STATE", "normal") == "readonly"
- read_replica_dbs = []
- if read_replicas:
- read_replica_dbs = [_db_from_url(config['DB_URI'], db_kwargs) for config in read_replicas]
+ read_replica_dbs = []
+ if read_replicas:
+ read_replica_dbs = [
+ _db_from_url(config["DB_URI"], db_kwargs) for config in read_replicas
+ ]
- read_only_config.initialize(ReadOnlyConfig(is_read_only, read_replica_dbs))
+ read_only_config.initialize(ReadOnlyConfig(is_read_only, read_replica_dbs))
- def _db_transaction():
- return config_object['DB_TRANSACTION_FACTORY'](db)
+ def _db_transaction():
+ return config_object["DB_TRANSACTION_FACTORY"](db)
- @contextmanager
- def _ensure_under_transaction():
- if not testing and not config_object['TESTING']:
- if db.transaction_depth() == 0:
- raise Exception('Expected to be under a transaction')
+ @contextmanager
+ def _ensure_under_transaction():
+ if not testing and not config_object["TESTING"]:
+ if db.transaction_depth() == 0:
+ raise Exception("Expected to be under a transaction")
- yield
+ yield
+
+ db_transaction.initialize(_db_transaction)
+ ensure_under_transaction.initialize(_ensure_under_transaction)
- db_transaction.initialize(_db_transaction)
- ensure_under_transaction.initialize(_ensure_under_transaction)
def random_string_generator(length=16):
- def random_string():
- random = SystemRandom()
- return ''.join([random.choice(string.ascii_uppercase + string.digits)
- for _ in range(length)])
- return random_string
+ def random_string():
+ random = SystemRandom()
+ return "".join(
+ [
+ random.choice(string.ascii_uppercase + string.digits)
+ for _ in range(length)
+ ]
+ )
+
+ return random_string
def uuid_generator():
- return str(uuid.uuid4())
+ return str(uuid.uuid4())
get_epoch_timestamp = lambda: int(time.time())
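
The configure()/_db_from_url() hunk above picks a peewee driver from the DB_URI scheme and switches to the pooled variant when DB_CONNECTION_POOLING is enabled. A simplified, standalone sketch of that selection follows, with driver names as strings and the scheme parsed by a plain split; the URI is an example value.

import os

_SCHEME_DRIVERS = {
    "mysql+pymysql": ("MySQLDatabase", "PooledMySQLDatabase"),
    "postgresql": ("PostgresqlDatabase", "PooledPostgresqlDatabase"),
    "sqlite": ("SqliteDatabase", "PooledSqliteDatabase"),
}


def pick_driver(db_uri):
    scheme = db_uri.split("://", 1)[0]
    plain, pooled = _SCHEME_DRIVERS[scheme]
    if os.getenv("DB_CONNECTION_POOLING", "false").lower() == "true":
        return pooled
    return plain


print(pick_driver("postgresql://quay:secret@db.example.invalid:5432/quay"))
# -> PostgresqlDatabase, or PooledPostgresqlDatabase when DB_CONNECTION_POOLING=true
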
@@ -420,1374 +466,1494 @@ get_epoch_timestamp_ms = lambda: int(time.time() * 1000)
def close_db_filter(_):
- if db.obj is not None and not db.is_closed():
- logger.debug('Disconnecting from database.')
- db.close()
+ if db.obj is not None and not db.is_closed():
+ logger.debug("Disconnecting from database.")
+ db.close()
- if read_only_config.obj is not None:
- for read_replica in read_only_config.obj.read_replicas:
- if not read_replica.is_closed():
- logger.debug('Disconnecting from read replica.')
- read_replica.close()
+ if read_only_config.obj is not None:
+ for read_replica in read_only_config.obj.read_replicas:
+ if not read_replica.is_closed():
+ logger.debug("Disconnecting from read replica.")
+ read_replica.close()
class QuayUserField(ForeignKeyField):
- def __init__(self, allows_robots=False, robot_null_delete=False, *args, **kwargs):
- self.allows_robots = allows_robots
- self.robot_null_delete = robot_null_delete
- if 'model' not in kwargs:
- kwargs['model'] = User
- super(QuayUserField, self).__init__(*args, **kwargs)
+ def __init__(self, allows_robots=False, robot_null_delete=False, *args, **kwargs):
+ self.allows_robots = allows_robots
+ self.robot_null_delete = robot_null_delete
+ if "model" not in kwargs:
+ kwargs["model"] = User
+ super(QuayUserField, self).__init__(*args, **kwargs)
@lru_cache(maxsize=16)
def _get_enum_field_values(enum_field):
- values = []
- for row in enum_field.rel_model.select():
- key = getattr(row, enum_field.enum_key_field)
- value = getattr(row, 'id')
- values.append((key, value))
- return Enum(enum_field.rel_model.__name__, values)
+ values = []
+ for row in enum_field.rel_model.select():
+ key = getattr(row, enum_field.enum_key_field)
+ value = getattr(row, "id")
+ values.append((key, value))
+ return Enum(enum_field.rel_model.__name__, values)
class EnumField(ForeignKeyField):
- """ Create a cached python Enum from an EnumTable """
- def __init__(self, model, enum_key_field='name', *args, **kwargs):
- """
+ """ Create a cached python Enum from an EnumTable """
+
+ def __init__(self, model, enum_key_field="name", *args, **kwargs):
+ """
model is the EnumTable model-class (see ForeignKeyField)
enum_key_field is the field from the EnumTable to use as the enum name
"""
- self.enum_key_field = enum_key_field
- super(EnumField, self).__init__(model, *args, **kwargs)
+ self.enum_key_field = enum_key_field
+ super(EnumField, self).__init__(model, *args, **kwargs)
- @property
- def enum(self):
- """ Returns a python enun.Enum generated from the associated EnumTable """
- return _get_enum_field_values(self)
+ @property
+ def enum(self):
+ """ Returns a python enun.Enum generated from the associated EnumTable """
+ return _get_enum_field_values(self)
- def get_id(self, name):
- """ Returns the ForeignKeyId from the name field
+ def get_id(self, name):
+ """ Returns the ForeignKeyId from the name field
Example:
>>> Repository.repo_kind.get_id("application")
2
"""
- try:
- return self.enum[name].value
- except KeyError:
- raise self.rel_model.DoesNotExist
+ try:
+ return self.enum[name].value
+ except KeyError:
+ raise self.rel_model.DoesNotExist
- def get_name(self, value):
- """ Returns the name value from the ForeignKeyId
+ def get_name(self, value):
+ """ Returns the name value from the ForeignKeyId
Example:
>>> Repository.repo_kind.get_name(2)
"application"
"""
- try:
- return self.enum(value).name
- except ValueError:
- raise self.rel_model.DoesNotExist
+ try:
+ return self.enum(value).name
+ except ValueError:
+ raise self.rel_model.DoesNotExist
def deprecated_field(field, flag):
- """ Marks a field as deprecated and removes it from the peewee model if the
+ """ Marks a field as deprecated and removes it from the peewee model if the
flag is not set. A flag is defined in the active_migration module and will
be associated with one or more migration phases.
"""
- if ActiveDataMigration.has_flag(flag):
- return field
+ if ActiveDataMigration.has_flag(flag):
+ return field
- return None
+ return None
class BaseModel(ReadReplicaSupportedModel):
- class Meta:
- database = db
- encrypter = db_encrypter
- read_only_config = read_only_config
+ class Meta:
+ database = db
+ encrypter = db_encrypter
+ read_only_config = read_only_config
- def __getattribute__(self, name):
- """ Adds _id accessors so that foreign key field IDs can be looked up without making
+ def __getattribute__(self, name):
+ """ Adds _id accessors so that foreign key field IDs can be looked up without making
a database roundtrip.
"""
- if name.endswith('_id'):
- field_name = name[0:len(name) - 3]
- if field_name in self._meta.fields:
- return self.__data__.get(field_name)
+ if name.endswith("_id"):
+ field_name = name[0 : len(name) - 3]
+ if field_name in self._meta.fields:
+ return self.__data__.get(field_name)
- return super(BaseModel, self).__getattribute__(name)
+ return super(BaseModel, self).__getattribute__(name)
class User(BaseModel):
- uuid = CharField(default=uuid_generator, max_length=36, null=True, index=True)
- username = CharField(unique=True, index=True)
- password_hash = CharField(null=True)
- email = CharField(unique=True, index=True,
- default=random_string_generator(length=64))
- verified = BooleanField(default=False)
- stripe_id = CharField(index=True, null=True)
- organization = BooleanField(default=False, index=True)
- robot = BooleanField(default=False, index=True)
- invoice_email = BooleanField(default=False)
- invalid_login_attempts = IntegerField(default=0)
- last_invalid_login = DateTimeField(default=datetime.utcnow)
- removed_tag_expiration_s = IntegerField(default=1209600) # Two weeks
- enabled = BooleanField(default=True)
- invoice_email_address = CharField(null=True, index=True)
+ uuid = CharField(default=uuid_generator, max_length=36, null=True, index=True)
+ username = CharField(unique=True, index=True)
+ password_hash = CharField(null=True)
+ email = CharField(
+ unique=True, index=True, default=random_string_generator(length=64)
+ )
+ verified = BooleanField(default=False)
+ stripe_id = CharField(index=True, null=True)
+ organization = BooleanField(default=False, index=True)
+ robot = BooleanField(default=False, index=True)
+ invoice_email = BooleanField(default=False)
+ invalid_login_attempts = IntegerField(default=0)
+ last_invalid_login = DateTimeField(default=datetime.utcnow)
+ removed_tag_expiration_s = IntegerField(default=1209600) # Two weeks
+ enabled = BooleanField(default=True)
+ invoice_email_address = CharField(null=True, index=True)
- given_name = CharField(null=True)
- family_name = CharField(null=True)
- company = CharField(null=True)
- location = CharField(null=True)
+ given_name = CharField(null=True)
+ family_name = CharField(null=True)
+ company = CharField(null=True)
+ location = CharField(null=True)
- maximum_queued_builds_count = IntegerField(null=True)
- creation_date = DateTimeField(default=datetime.utcnow, null=True)
- last_accessed = DateTimeField(null=True, index=True)
+ maximum_queued_builds_count = IntegerField(null=True)
+ creation_date = DateTimeField(default=datetime.utcnow, null=True)
+ last_accessed = DateTimeField(null=True, index=True)
- def delete_instance(self, recursive=False, delete_nullable=False):
- # If we are deleting a robot account, only execute the subset of queries necessary.
- if self.robot:
- # For all the model dependencies, only delete those that allow robots.
- for query, fk in reversed(list(self.dependencies(search_nullable=True))):
- if isinstance(fk, QuayUserField) and fk.allows_robots:
- _model = fk.model
+ def delete_instance(self, recursive=False, delete_nullable=False):
+ # If we are deleting a robot account, only execute the subset of queries necessary.
+ if self.robot:
+ # For all the model dependencies, only delete those that allow robots.
+ for query, fk in reversed(list(self.dependencies(search_nullable=True))):
+ if isinstance(fk, QuayUserField) and fk.allows_robots:
+ _model = fk.model
- if fk.robot_null_delete:
- _model.update(**{fk.name: None}).where(query).execute()
- else:
- _model.delete().where(query).execute()
+ if fk.robot_null_delete:
+ _model.update(**{fk.name: None}).where(query).execute()
+ else:
+ _model.delete().where(query).execute()
- # Delete the instance itself.
- super(User, self).delete_instance(recursive=False, delete_nullable=False)
- else:
- if not recursive:
- raise RuntimeError('Non-recursive delete on user.')
+ # Delete the instance itself.
+ super(User, self).delete_instance(recursive=False, delete_nullable=False)
+ else:
+ if not recursive:
+ raise RuntimeError("Non-recursive delete on user.")
- # These models don't need to use transitive deletes, because the referenced objects
- # are cleaned up directly in the model.
- skip_transitive_deletes = {Image, Repository, Team, RepositoryBuild, ServiceKeyApproval,
- RepositoryBuildTrigger, ServiceKey, RepositoryPermission,
- TeamMemberInvite, Star, RepositoryAuthorizedEmail, TeamMember,
- RepositoryTag, PermissionPrototype, DerivedStorageForImage,
- TagManifest, AccessToken, OAuthAccessToken, BlobUpload,
- RepositoryNotification, OAuthAuthorizationCode,
- RepositoryActionCount, TagManifestLabel,
- TeamSync, RepositorySearchScore,
- DeletedNamespace, RepoMirrorRule,
- NamespaceGeoRestriction} | appr_classes | v22_classes | transition_classes
- delete_instance_filtered(self, User, delete_nullable, skip_transitive_deletes)
+ # These models don't need to use transitive deletes, because the referenced objects
+ # are cleaned up directly in the model.
+ skip_transitive_deletes = (
+ {
+ Image,
+ Repository,
+ Team,
+ RepositoryBuild,
+ ServiceKeyApproval,
+ RepositoryBuildTrigger,
+ ServiceKey,
+ RepositoryPermission,
+ TeamMemberInvite,
+ Star,
+ RepositoryAuthorizedEmail,
+ TeamMember,
+ RepositoryTag,
+ PermissionPrototype,
+ DerivedStorageForImage,
+ TagManifest,
+ AccessToken,
+ OAuthAccessToken,
+ BlobUpload,
+ RepositoryNotification,
+ OAuthAuthorizationCode,
+ RepositoryActionCount,
+ TagManifestLabel,
+ TeamSync,
+ RepositorySearchScore,
+ DeletedNamespace,
+ RepoMirrorRule,
+ NamespaceGeoRestriction,
+ }
+ | appr_classes
+ | v22_classes
+ | transition_classes
+ )
+ delete_instance_filtered(
+ self, User, delete_nullable, skip_transitive_deletes
+ )
Namespace = User.alias()
class RobotAccountMetadata(BaseModel):
- robot_account = QuayUserField(index=True, allows_robots=True, unique=True)
- description = CharField()
- unstructured_json = JSONField()
+ robot_account = QuayUserField(index=True, allows_robots=True, unique=True)
+ description = CharField()
+ unstructured_json = JSONField()
class RobotAccountToken(BaseModel):
- robot_account = QuayUserField(index=True, allows_robots=True, unique=True)
- token = EncryptedCharField(default_token_length=64)
- fully_migrated = BooleanField(default=False)
+ robot_account = QuayUserField(index=True, allows_robots=True, unique=True)
+ token = EncryptedCharField(default_token_length=64)
+ fully_migrated = BooleanField(default=False)
class DeletedNamespace(BaseModel):
- namespace = QuayUserField(index=True, allows_robots=False, unique=True)
- marked = DateTimeField(default=datetime.now)
- original_username = CharField(index=True)
- original_email = CharField(index=True)
- queue_id = CharField(null=True, index=True)
+ namespace = QuayUserField(index=True, allows_robots=False, unique=True)
+ marked = DateTimeField(default=datetime.now)
+ original_username = CharField(index=True)
+ original_email = CharField(index=True)
+ queue_id = CharField(null=True, index=True)
class NamespaceGeoRestriction(BaseModel):
- namespace = QuayUserField(index=True, allows_robots=False)
- added = DateTimeField(default=datetime.utcnow)
- description = CharField()
- unstructured_json = JSONField()
- restricted_region_iso_code = CharField(index=True)
+ namespace = QuayUserField(index=True, allows_robots=False)
+ added = DateTimeField(default=datetime.utcnow)
+ description = CharField()
+ unstructured_json = JSONField()
+ restricted_region_iso_code = CharField(index=True)
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- (('namespace', 'restricted_region_iso_code'), True),
- )
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = ((("namespace", "restricted_region_iso_code"), True),)
class UserPromptTypes(object):
- CONFIRM_USERNAME = 'confirm_username'
- ENTER_NAME = 'enter_name'
- ENTER_COMPANY = 'enter_company'
+ CONFIRM_USERNAME = "confirm_username"
+ ENTER_NAME = "enter_name"
+ ENTER_COMPANY = "enter_company"
class UserPromptKind(BaseModel):
- name = CharField(index=True)
+ name = CharField(index=True)
class UserPrompt(BaseModel):
- user = QuayUserField(allows_robots=False, index=True)
- kind = ForeignKeyField(UserPromptKind)
+ user = QuayUserField(allows_robots=False, index=True)
+ kind = ForeignKeyField(UserPromptKind)
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- (('user', 'kind'), True),
- )
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = ((("user", "kind"), True),)
class TeamRole(BaseModel):
- name = CharField(index=True)
+ name = CharField(index=True)
class Team(BaseModel):
- name = CharField(index=True)
- organization = QuayUserField(index=True)
- role = EnumField(TeamRole)
- description = TextField(default='')
+ name = CharField(index=True)
+ organization = QuayUserField(index=True)
+ role = EnumField(TeamRole)
+ description = TextField(default="")
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- # A team name must be unique within an organization
- (('name', 'organization'), True),
- )
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = (
+ # A team name must be unique within an organization
+ (("name", "organization"), True),
+ )
class TeamMember(BaseModel):
- user = QuayUserField(allows_robots=True, index=True)
- team = ForeignKeyField(Team)
+ user = QuayUserField(allows_robots=True, index=True)
+ team = ForeignKeyField(Team)
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- # A user may belong to a team only once
- (('user', 'team'), True),
- )
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = (
+ # A user may belong to a team only once
+ (("user", "team"), True),
+ )
class TeamMemberInvite(BaseModel):
- # Note: Either user OR email will be filled in, but not both.
- user = QuayUserField(index=True, null=True)
- email = CharField(null=True)
- team = ForeignKeyField(Team)
- inviter = ForeignKeyField(User, backref='inviter')
- invite_token = CharField(default=urn_generator(['teaminvite']))
+ # Note: Either user OR email will be filled in, but not both.
+ user = QuayUserField(index=True, null=True)
+ email = CharField(null=True)
+ team = ForeignKeyField(Team)
+ inviter = ForeignKeyField(User, backref="inviter")
+ invite_token = CharField(default=urn_generator(["teaminvite"]))
class LoginService(BaseModel):
- name = CharField(unique=True, index=True)
+ name = CharField(unique=True, index=True)
class TeamSync(BaseModel):
- team = ForeignKeyField(Team, unique=True)
+ team = ForeignKeyField(Team, unique=True)
- transaction_id = CharField()
- last_updated = DateTimeField(null=True, index=True)
- service = ForeignKeyField(LoginService)
- config = JSONField()
+ transaction_id = CharField()
+ last_updated = DateTimeField(null=True, index=True)
+ service = ForeignKeyField(LoginService)
+ config = JSONField()
class FederatedLogin(BaseModel):
- user = QuayUserField(allows_robots=True, index=True)
- service = ForeignKeyField(LoginService)
- service_ident = CharField()
- metadata_json = TextField(default='{}')
+ user = QuayUserField(allows_robots=True, index=True)
+ service = ForeignKeyField(LoginService)
+ service_ident = CharField()
+ metadata_json = TextField(default="{}")
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- # create a unique index on service and the local service id
- (('service', 'service_ident'), True),
-
- # a user may only have one federated login per service
- (('service', 'user'), True),
- )
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = (
+ # create a unique index on service and the local service id
+ (("service", "service_ident"), True),
+ # a user may only have one federated login per service
+ (("service", "user"), True),
+ )
class Visibility(BaseModel):
- name = CharField(index=True, unique=True)
+ name = CharField(index=True, unique=True)
class RepositoryKind(BaseModel):
- name = CharField(index=True, unique=True)
+ name = CharField(index=True, unique=True)
@unique
class RepositoryState(IntEnum):
- """
+ """
Possible states of a repository.
NORMAL: Regular repo where all actions are possible
READ_ONLY: Only read actions, such as pull, are allowed regardless of specific user permissions
MIRROR: Equivalent to READ_ONLY except that mirror robot has write permission
"""
- NORMAL = 0
- READ_ONLY = 1
- MIRROR = 2
+
+ NORMAL = 0
+ READ_ONLY = 1
+ MIRROR = 2
class Repository(BaseModel):
- namespace_user = QuayUserField(null=True)
- name = FullIndexedCharField(match_function=db_match_func)
- visibility = EnumField(Visibility)
- description = FullIndexedTextField(match_function=db_match_func, null=True)
- badge_token = CharField(default=uuid_generator)
- kind = EnumField(RepositoryKind)
- trust_enabled = BooleanField(default=False)
- state = ClientEnumField(RepositoryState, default=RepositoryState.NORMAL)
+ namespace_user = QuayUserField(null=True)
+ name = FullIndexedCharField(match_function=db_match_func)
+ visibility = EnumField(Visibility)
+ description = FullIndexedTextField(match_function=db_match_func, null=True)
+ badge_token = CharField(default=uuid_generator)
+ kind = EnumField(RepositoryKind)
+ trust_enabled = BooleanField(default=False)
+ state = ClientEnumField(RepositoryState, default=RepositoryState.NORMAL)
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- # create a unique index on namespace and name
- (('namespace_user', 'name'), True),
- )
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = (
+ # create a unique index on namespace and name
+ (("namespace_user", "name"), True),
+ )
- def delete_instance(self, recursive=False, delete_nullable=False):
- if not recursive:
- raise RuntimeError('Non-recursive delete on repository.')
+ def delete_instance(self, recursive=False, delete_nullable=False):
+ if not recursive:
+ raise RuntimeError("Non-recursive delete on repository.")
- # These models don't need to use transitive deletes, because the referenced objects
- # are cleaned up directly
- skip_transitive_deletes = ({RepositoryTag, RepositoryBuild, RepositoryBuildTrigger, BlobUpload,
- Image, TagManifest, TagManifestLabel, Label, DerivedStorageForImage,
- RepositorySearchScore, RepoMirrorConfig, RepoMirrorRule}
- | appr_classes | v22_classes | transition_classes)
+ # These models don't need to use transitive deletes, because the referenced objects
+ # are cleaned up directly
+ skip_transitive_deletes = (
+ {
+ RepositoryTag,
+ RepositoryBuild,
+ RepositoryBuildTrigger,
+ BlobUpload,
+ Image,
+ TagManifest,
+ TagManifestLabel,
+ Label,
+ DerivedStorageForImage,
+ RepositorySearchScore,
+ RepoMirrorConfig,
+ RepoMirrorRule,
+ }
+ | appr_classes
+ | v22_classes
+ | transition_classes
+ )
- delete_instance_filtered(self, Repository, delete_nullable, skip_transitive_deletes)
+ delete_instance_filtered(
+ self, Repository, delete_nullable, skip_transitive_deletes
+ )
class RepositorySearchScore(BaseModel):
- repository = ForeignKeyField(Repository, unique=True)
- score = BigIntegerField(index=True, default=0)
- last_updated = DateTimeField(null=True)
+ repository = ForeignKeyField(Repository, unique=True)
+ score = BigIntegerField(index=True, default=0)
+ last_updated = DateTimeField(null=True)
class Star(BaseModel):
- user = ForeignKeyField(User)
- repository = ForeignKeyField(Repository)
- created = DateTimeField(default=datetime.now)
+ user = ForeignKeyField(User)
+ repository = ForeignKeyField(Repository)
+ created = DateTimeField(default=datetime.now)
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- # create a unique index on user and repository
- (('user', 'repository'), True),
- )
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = (
+ # create a unique index on user and repository
+ (("user", "repository"), True),
+ )
class Role(BaseModel):
- name = CharField(index=True, unique=True)
+ name = CharField(index=True, unique=True)
class RepositoryPermission(BaseModel):
- team = ForeignKeyField(Team, null=True)
- user = QuayUserField(allows_robots=True, null=True)
- repository = ForeignKeyField(Repository)
- role = ForeignKeyField(Role)
+ team = ForeignKeyField(Team, null=True)
+ user = QuayUserField(allows_robots=True, null=True)
+ repository = ForeignKeyField(Repository)
+ role = ForeignKeyField(Role)
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- (('team', 'repository'), True),
- (('user', 'repository'), True),
- )
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = ((("team", "repository"), True), (("user", "repository"), True))
class PermissionPrototype(BaseModel):
- org = QuayUserField(index=True, backref='orgpermissionproto')
- uuid = CharField(default=uuid_generator, index=True)
- activating_user = QuayUserField(allows_robots=True, index=True, null=True,
- backref='userpermissionproto')
- delegate_user = QuayUserField(allows_robots=True, backref='receivingpermission',
- null=True)
- delegate_team = ForeignKeyField(Team, backref='receivingpermission',
- null=True)
- role = ForeignKeyField(Role)
-
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- (('org', 'activating_user'), False),
+ org = QuayUserField(index=True, backref="orgpermissionproto")
+ uuid = CharField(default=uuid_generator, index=True)
+ activating_user = QuayUserField(
+ allows_robots=True, index=True, null=True, backref="userpermissionproto"
)
+ delegate_user = QuayUserField(
+ allows_robots=True, backref="receivingpermission", null=True
+ )
+ delegate_team = ForeignKeyField(Team, backref="receivingpermission", null=True)
+ role = ForeignKeyField(Role)
+
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = ((("org", "activating_user"), False),)
class AccessTokenKind(BaseModel):
- name = CharField(unique=True, index=True)
+ name = CharField(unique=True, index=True)
class AccessToken(BaseModel):
- friendly_name = CharField(null=True)
+ friendly_name = CharField(null=True)
- # TODO(remove-unenc): This field is deprecated and should be removed soon.
- code = deprecated_field(
- CharField(default=random_string_generator(length=64), unique=True, index=True, null=True),
- ERTMigrationFlags.WRITE_OLD_FIELDS)
+ # TODO(remove-unenc): This field is deprecated and should be removed soon.
+ code = deprecated_field(
+ CharField(
+ default=random_string_generator(length=64),
+ unique=True,
+ index=True,
+ null=True,
+ ),
+ ERTMigrationFlags.WRITE_OLD_FIELDS,
+ )
- token_name = CharField(default=random_string_generator(length=32), unique=True, index=True)
- token_code = EncryptedCharField(default_token_length=32)
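+    # token_name is stored in the clear for indexed lookups; the secret remainder
+    # is kept encrypted in token_code and the two are recombined in get_code().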
+ token_name = CharField(
+ default=random_string_generator(length=32), unique=True, index=True
+ )
+ token_code = EncryptedCharField(default_token_length=32)
- repository = ForeignKeyField(Repository)
- created = DateTimeField(default=datetime.now)
- role = ForeignKeyField(Role)
- temporary = BooleanField(default=True)
- kind = ForeignKeyField(AccessTokenKind, null=True)
+ repository = ForeignKeyField(Repository)
+ created = DateTimeField(default=datetime.now)
+ role = ForeignKeyField(Role)
+ temporary = BooleanField(default=True)
+ kind = ForeignKeyField(AccessTokenKind, null=True)
- def get_code(self):
- if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
- return self.code
- else:
- return self.token_name + self.token_code.decrypt()
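+    # While the encrypted-token migration still reads the old fields, the legacy
+    # `code` column is returned; afterwards the full code is the token_name prefix
+    # plus the decrypted token_code.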
+ def get_code(self):
+ if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
+ return self.code
+ else:
+ return self.token_name + self.token_code.decrypt()
class BuildTriggerService(BaseModel):
- name = CharField(index=True, unique=True)
+ name = CharField(index=True, unique=True)
class DisableReason(BaseModel):
- name = CharField(index=True, unique=True)
+ name = CharField(index=True, unique=True)
class RepositoryBuildTrigger(BaseModel):
- uuid = CharField(default=uuid_generator, index=True)
- service = ForeignKeyField(BuildTriggerService)
- repository = ForeignKeyField(Repository)
- connected_user = QuayUserField()
+ uuid = CharField(default=uuid_generator, index=True)
+ service = ForeignKeyField(BuildTriggerService)
+ repository = ForeignKeyField(Repository)
+ connected_user = QuayUserField()
- # TODO(remove-unenc): These fields are deprecated and should be removed soon.
- auth_token = deprecated_field(CharField(null=True), ERTMigrationFlags.WRITE_OLD_FIELDS)
- private_key = deprecated_field(TextField(null=True), ERTMigrationFlags.WRITE_OLD_FIELDS)
+ # TODO(remove-unenc): These fields are deprecated and should be removed soon.
+ auth_token = deprecated_field(
+ CharField(null=True), ERTMigrationFlags.WRITE_OLD_FIELDS
+ )
+ private_key = deprecated_field(
+ TextField(null=True), ERTMigrationFlags.WRITE_OLD_FIELDS
+ )
- secure_auth_token = EncryptedCharField(null=True)
- secure_private_key = EncryptedTextField(null=True)
- fully_migrated = BooleanField(default=False)
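+    # Encrypted replacements for the deprecated auth_token/private_key columns above.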
+ secure_auth_token = EncryptedCharField(null=True)
+ secure_private_key = EncryptedTextField(null=True)
+ fully_migrated = BooleanField(default=False)
- config = TextField(default='{}')
- write_token = ForeignKeyField(AccessToken, null=True)
- pull_robot = QuayUserField(allows_robots=True, null=True, backref='triggerpullrobot',
- robot_null_delete=True)
+ config = TextField(default="{}")
+ write_token = ForeignKeyField(AccessToken, null=True)
+ pull_robot = QuayUserField(
+ allows_robots=True,
+ null=True,
+ backref="triggerpullrobot",
+ robot_null_delete=True,
+ )
- enabled = BooleanField(default=True)
- disabled_reason = EnumField(DisableReason, null=True)
- disabled_datetime = DateTimeField(default=datetime.utcnow, null=True, index=True)
- successive_failure_count = IntegerField(default=0)
- successive_internal_error_count = IntegerField(default=0)
+ enabled = BooleanField(default=True)
+ disabled_reason = EnumField(DisableReason, null=True)
+ disabled_datetime = DateTimeField(default=datetime.utcnow, null=True, index=True)
+ successive_failure_count = IntegerField(default=0)
+ successive_internal_error_count = IntegerField(default=0)
class EmailConfirmation(BaseModel):
- code = CharField(default=random_string_generator(), unique=True, index=True)
- verification_code = CredentialField(null=True)
- user = QuayUserField()
- pw_reset = BooleanField(default=False)
- new_email = CharField(null=True)
- email_confirm = BooleanField(default=False)
- created = DateTimeField(default=datetime.now)
+ code = CharField(default=random_string_generator(), unique=True, index=True)
+ verification_code = CredentialField(null=True)
+ user = QuayUserField()
+ pw_reset = BooleanField(default=False)
+ new_email = CharField(null=True)
+ email_confirm = BooleanField(default=False)
+ created = DateTimeField(default=datetime.now)
class ImageStorage(BaseModel):
- uuid = CharField(default=uuid_generator, index=True, unique=True)
- image_size = BigIntegerField(null=True)
- uncompressed_size = BigIntegerField(null=True)
- uploading = BooleanField(default=True, null=True)
- cas_path = BooleanField(default=True)
- content_checksum = CharField(null=True, index=True)
+ uuid = CharField(default=uuid_generator, index=True, unique=True)
+ image_size = BigIntegerField(null=True)
+ uncompressed_size = BigIntegerField(null=True)
+ uploading = BooleanField(default=True, null=True)
+ cas_path = BooleanField(default=True)
+ content_checksum = CharField(null=True, index=True)
class ImageStorageTransformation(BaseModel):
- name = CharField(index=True, unique=True)
+ name = CharField(index=True, unique=True)
class ImageStorageSignatureKind(BaseModel):
- name = CharField(index=True, unique=True)
+ name = CharField(index=True, unique=True)
class ImageStorageSignature(BaseModel):
- storage = ForeignKeyField(ImageStorage)
- kind = ForeignKeyField(ImageStorageSignatureKind)
- signature = TextField(null=True)
- uploading = BooleanField(default=True, null=True)
+ storage = ForeignKeyField(ImageStorage)
+ kind = ForeignKeyField(ImageStorageSignatureKind)
+ signature = TextField(null=True)
+ uploading = BooleanField(default=True, null=True)
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- (('kind', 'storage'), True),
- )
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = ((("kind", "storage"), True),)
class ImageStorageLocation(BaseModel):
- name = CharField(unique=True, index=True)
+ name = CharField(unique=True, index=True)
class ImageStoragePlacement(BaseModel):
- storage = ForeignKeyField(ImageStorage)
- location = ForeignKeyField(ImageStorageLocation)
+ storage = ForeignKeyField(ImageStorage)
+ location = ForeignKeyField(ImageStorageLocation)
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- # An image can only be placed in the same place once
- (('storage', 'location'), True),
- )
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = (
+ # An image can only be placed in the same place once
+ (("storage", "location"), True),
+ )
class UserRegion(BaseModel):
- user = QuayUserField(index=True, allows_robots=False)
- location = ForeignKeyField(ImageStorageLocation)
+ user = QuayUserField(index=True, allows_robots=False)
+ location = ForeignKeyField(ImageStorageLocation)
- indexes = (
- (('user', 'location'), True),
- )
+ indexes = ((("user", "location"), True),)
class Image(BaseModel):
- # This class is intentionally denormalized. Even though images are supposed
- # to be globally unique we can't treat them as such for permissions and
- # security reasons. So rather than Repository <-> Image being many to many
- # each image now belongs to exactly one repository.
- docker_image_id = CharField(index=True)
- repository = ForeignKeyField(Repository)
+ # This class is intentionally denormalized. Even though images are supposed
+ # to be globally unique we can't treat them as such for permissions and
+ # security reasons. So rather than Repository <-> Image being many to many
+ # each image now belongs to exactly one repository.
+ docker_image_id = CharField(index=True)
+ repository = ForeignKeyField(Repository)
- # '/' separated list of ancestory ids, e.g. /1/2/6/7/10/
- ancestors = CharField(index=True, default='/', max_length=64535, null=True)
+    # '/' separated list of ancestor ids, e.g. /1/2/6/7/10/
+ ancestors = CharField(index=True, default="/", max_length=64535, null=True)
- storage = ForeignKeyField(ImageStorage, null=True)
+ storage = ForeignKeyField(ImageStorage, null=True)
- created = DateTimeField(null=True)
- comment = TextField(null=True)
- command = TextField(null=True)
- aggregate_size = BigIntegerField(null=True)
- v1_json_metadata = TextField(null=True)
- v1_checksum = CharField(null=True)
+ created = DateTimeField(null=True)
+ comment = TextField(null=True)
+ command = TextField(null=True)
+ aggregate_size = BigIntegerField(null=True)
+ v1_json_metadata = TextField(null=True)
+ v1_checksum = CharField(null=True)
- security_indexed = BooleanField(default=False, index=True)
- security_indexed_engine = IntegerField(default=IMAGE_NOT_SCANNED_ENGINE_VERSION, index=True)
-
- # We use a proxy here instead of 'self' in order to disable the foreign key constraint
- parent = DeferredForeignKey('Image', null=True, backref='children')
-
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- # we don't really want duplicates
- (('repository', 'docker_image_id'), True),
-
- (('security_indexed_engine', 'security_indexed'), False),
+ security_indexed = BooleanField(default=False, index=True)
+ security_indexed_engine = IntegerField(
+ default=IMAGE_NOT_SCANNED_ENGINE_VERSION, index=True
)
- def ancestor_id_list(self):
- """ Returns an integer list of ancestor ids, ordered chronologically from
+ # We use a proxy here instead of 'self' in order to disable the foreign key constraint
+ parent = DeferredForeignKey("Image", null=True, backref="children")
+
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = (
+ # we don't really want duplicates
+ (("repository", "docker_image_id"), True),
+ (("security_indexed_engine", "security_indexed"), False),
+ )
+
+ def ancestor_id_list(self):
+ """ Returns an integer list of ancestor ids, ordered chronologically from
root to direct parent.
"""
- return map(int, self.ancestors.split('/')[1:-1])
+ return map(int, self.ancestors.split("/")[1:-1])
class DerivedStorageForImage(BaseModel):
- source_image = ForeignKeyField(Image)
- derivative = ForeignKeyField(ImageStorage)
- transformation = ForeignKeyField(ImageStorageTransformation)
- uniqueness_hash = CharField(null=True)
+ source_image = ForeignKeyField(Image)
+ derivative = ForeignKeyField(ImageStorage)
+ transformation = ForeignKeyField(ImageStorageTransformation)
+ uniqueness_hash = CharField(null=True)
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- (('source_image', 'transformation', 'uniqueness_hash'), True),
- )
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = ((("source_image", "transformation", "uniqueness_hash"), True),)
class RepositoryTag(BaseModel):
- name = CharField()
- image = ForeignKeyField(Image)
- repository = ForeignKeyField(Repository)
- lifetime_start_ts = IntegerField(default=get_epoch_timestamp)
- lifetime_end_ts = IntegerField(null=True, index=True)
- hidden = BooleanField(default=False)
- reversion = BooleanField(default=False)
+ name = CharField()
+ image = ForeignKeyField(Image)
+ repository = ForeignKeyField(Repository)
+ lifetime_start_ts = IntegerField(default=get_epoch_timestamp)
+ lifetime_end_ts = IntegerField(null=True, index=True)
+ hidden = BooleanField(default=False)
+ reversion = BooleanField(default=False)
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- (('repository', 'name'), False),
- (('repository', 'lifetime_start_ts'), False),
- (('repository', 'lifetime_end_ts'), False),
-
- # This unique index prevents deadlocks when concurrently moving and deleting tags
- (('repository', 'name', 'lifetime_end_ts'), True),
- )
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = (
+ (("repository", "name"), False),
+ (("repository", "lifetime_start_ts"), False),
+ (("repository", "lifetime_end_ts"), False),
+ # This unique index prevents deadlocks when concurrently moving and deleting tags
+ (("repository", "name", "lifetime_end_ts"), True),
+ )
class BUILD_PHASE(object):
- """ Build phases enum """
- ERROR = 'error'
- INTERNAL_ERROR = 'internalerror'
- BUILD_SCHEDULED = 'build-scheduled'
- UNPACKING = 'unpacking'
- PULLING = 'pulling'
- BUILDING = 'building'
- PUSHING = 'pushing'
- WAITING = 'waiting'
- COMPLETE = 'complete'
- CANCELLED = 'cancelled'
+ """ Build phases enum """
- @classmethod
- def is_terminal_phase(cls, phase):
- return (phase == cls.COMPLETE or
- phase == cls.ERROR or
- phase == cls.INTERNAL_ERROR or
- phase == cls.CANCELLED)
+ ERROR = "error"
+ INTERNAL_ERROR = "internalerror"
+ BUILD_SCHEDULED = "build-scheduled"
+ UNPACKING = "unpacking"
+ PULLING = "pulling"
+ BUILDING = "building"
+ PUSHING = "pushing"
+ WAITING = "waiting"
+ COMPLETE = "complete"
+ CANCELLED = "cancelled"
+
+ @classmethod
+ def is_terminal_phase(cls, phase):
+ return (
+ phase == cls.COMPLETE
+ or phase == cls.ERROR
+ or phase == cls.INTERNAL_ERROR
+ or phase == cls.CANCELLED
+ )
class TRIGGER_DISABLE_REASON(object):
- """ Build trigger disable reason enum """
- BUILD_FALURES = 'successive_build_failures'
- INTERNAL_ERRORS = 'successive_build_internal_errors'
- USER_TOGGLED = 'user_toggled'
+ """ Build trigger disable reason enum """
+
+ BUILD_FALURES = "successive_build_failures"
+ INTERNAL_ERRORS = "successive_build_internal_errors"
+ USER_TOGGLED = "user_toggled"
class QueueItem(BaseModel):
- queue_name = CharField(index=True, max_length=1024)
- body = TextField()
- available_after = DateTimeField(default=datetime.utcnow)
- available = BooleanField(default=True)
- processing_expires = DateTimeField(null=True)
- retries_remaining = IntegerField(default=5)
- state_id = CharField(default=uuid_generator, index=True, unique=True)
+ queue_name = CharField(index=True, max_length=1024)
+ body = TextField()
+ available_after = DateTimeField(default=datetime.utcnow)
+ available = BooleanField(default=True)
+ processing_expires = DateTimeField(null=True)
+ retries_remaining = IntegerField(default=5)
+ state_id = CharField(default=uuid_generator, index=True, unique=True)
- class Meta:
- database = db
- read_only_config = read_only_config
- only_save_dirty = True
- indexes = (
- (('processing_expires', 'available'), False),
- (('processing_expires', 'queue_name', 'available'), False),
- (('processing_expires', 'available_after', 'retries_remaining', 'available'), False),
- (('processing_expires', 'available_after', 'queue_name', 'retries_remaining', 'available'), False),
- )
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ only_save_dirty = True
+ indexes = (
+ (("processing_expires", "available"), False),
+ (("processing_expires", "queue_name", "available"), False),
+ (
+ (
+ "processing_expires",
+ "available_after",
+ "retries_remaining",
+ "available",
+ ),
+ False,
+ ),
+ (
+ (
+ "processing_expires",
+ "available_after",
+ "queue_name",
+ "retries_remaining",
+ "available",
+ ),
+ False,
+ ),
+ )
- def save(self, *args, **kwargs):
- # Always change the queue item's state ID when we update it.
- self.state_id = str(uuid.uuid4())
- super(QueueItem, self).save(*args, **kwargs)
+ def save(self, *args, **kwargs):
+ # Always change the queue item's state ID when we update it.
+ self.state_id = str(uuid.uuid4())
+ super(QueueItem, self).save(*args, **kwargs)
class RepositoryBuild(BaseModel):
- uuid = CharField(default=uuid_generator, index=True)
- repository = ForeignKeyField(Repository)
- access_token = ForeignKeyField(AccessToken)
- resource_key = CharField(index=True, null=True)
- job_config = TextField()
- phase = CharField(default=BUILD_PHASE.WAITING)
- started = DateTimeField(default=datetime.now, index=True)
- display_name = CharField()
- trigger = ForeignKeyField(RepositoryBuildTrigger, null=True)
- pull_robot = QuayUserField(null=True, backref='buildpullrobot', allows_robots=True,
- robot_null_delete=True)
- logs_archived = BooleanField(default=False, index=True)
- queue_id = CharField(null=True, index=True)
-
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- (('repository', 'started', 'phase'), False),
- (('started', 'logs_archived', 'phase'), False),
+ uuid = CharField(default=uuid_generator, index=True)
+ repository = ForeignKeyField(Repository)
+ access_token = ForeignKeyField(AccessToken)
+ resource_key = CharField(index=True, null=True)
+ job_config = TextField()
+ phase = CharField(default=BUILD_PHASE.WAITING)
+ started = DateTimeField(default=datetime.now, index=True)
+ display_name = CharField()
+ trigger = ForeignKeyField(RepositoryBuildTrigger, null=True)
+ pull_robot = QuayUserField(
+ null=True, backref="buildpullrobot", allows_robots=True, robot_null_delete=True
)
+ logs_archived = BooleanField(default=False, index=True)
+ queue_id = CharField(null=True, index=True)
+
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = (
+ (("repository", "started", "phase"), False),
+ (("started", "logs_archived", "phase"), False),
+ )
class LogEntryKind(BaseModel):
- name = CharField(index=True, unique=True)
+ name = CharField(index=True, unique=True)
class LogEntry(BaseModel):
- id = BigAutoField()
- kind = ForeignKeyField(LogEntryKind)
- account = IntegerField(index=True, column_name='account_id')
- performer = IntegerField(index=True, null=True, column_name='performer_id')
- repository = IntegerField(index=True, null=True, column_name='repository_id')
- datetime = DateTimeField(default=datetime.now, index=True)
- ip = CharField(null=True)
- metadata_json = TextField(default='{}')
+ id = BigAutoField()
+ kind = ForeignKeyField(LogEntryKind)
+ account = IntegerField(index=True, column_name="account_id")
+ performer = IntegerField(index=True, null=True, column_name="performer_id")
+ repository = IntegerField(index=True, null=True, column_name="repository_id")
+ datetime = DateTimeField(default=datetime.now, index=True)
+ ip = CharField(null=True)
+ metadata_json = TextField(default="{}")
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- (('account', 'datetime'), False),
- (('performer', 'datetime'), False),
- (('repository', 'datetime'), False),
- (('repository', 'datetime', 'kind'), False),
- )
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = (
+ (("account", "datetime"), False),
+ (("performer", "datetime"), False),
+ (("repository", "datetime"), False),
+ (("repository", "datetime", "kind"), False),
+ )
class LogEntry2(BaseModel):
- """ TEMP FOR QUAY.IO ONLY. DO NOT RELEASE INTO QUAY ENTERPRISE. """
- kind = ForeignKeyField(LogEntryKind)
- account = IntegerField(index=True, db_column='account_id')
- performer = IntegerField(index=True, null=True, db_column='performer_id')
- repository = IntegerField(index=True, null=True, db_column='repository_id')
- datetime = DateTimeField(default=datetime.now, index=True)
- ip = CharField(null=True)
- metadata_json = TextField(default='{}')
+ """ TEMP FOR QUAY.IO ONLY. DO NOT RELEASE INTO QUAY ENTERPRISE. """
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- (('account', 'datetime'), False),
- (('performer', 'datetime'), False),
- (('repository', 'datetime'), False),
- (('repository', 'datetime', 'kind'), False),
- )
+ kind = ForeignKeyField(LogEntryKind)
+ account = IntegerField(index=True, db_column="account_id")
+ performer = IntegerField(index=True, null=True, db_column="performer_id")
+ repository = IntegerField(index=True, null=True, db_column="repository_id")
+ datetime = DateTimeField(default=datetime.now, index=True)
+ ip = CharField(null=True)
+ metadata_json = TextField(default="{}")
+
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = (
+ (("account", "datetime"), False),
+ (("performer", "datetime"), False),
+ (("repository", "datetime"), False),
+ (("repository", "datetime", "kind"), False),
+ )
class LogEntry3(BaseModel):
- id = BigAutoField()
- kind = IntegerField(db_column='kind_id')
- account = IntegerField(db_column='account_id')
- performer = IntegerField(null=True, db_column='performer_id')
- repository = IntegerField(null=True, db_column='repository_id')
- datetime = DateTimeField(default=datetime.now, index=True)
- ip = CharField(null=True)
- metadata_json = TextField(default='{}')
+ id = BigAutoField()
+ kind = IntegerField(db_column="kind_id")
+ account = IntegerField(db_column="account_id")
+ performer = IntegerField(null=True, db_column="performer_id")
+ repository = IntegerField(null=True, db_column="repository_id")
+ datetime = DateTimeField(default=datetime.now, index=True)
+ ip = CharField(null=True)
+ metadata_json = TextField(default="{}")
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- (('account', 'datetime'), False),
- (('performer', 'datetime'), False),
- (('repository', 'datetime', 'kind'), False),
- )
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = (
+ (("account", "datetime"), False),
+ (("performer", "datetime"), False),
+ (("repository", "datetime", "kind"), False),
+ )
class RepositoryActionCount(BaseModel):
- repository = ForeignKeyField(Repository)
- count = IntegerField()
- date = DateField(index=True)
+ repository = ForeignKeyField(Repository)
+ count = IntegerField()
+ date = DateField(index=True)
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- # create a unique index on repository and date
- (('repository', 'date'), True),
- )
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = (
+ # create a unique index on repository and date
+ (("repository", "date"), True),
+ )
class OAuthApplication(BaseModel):
- client_id = CharField(index=True, default=random_string_generator(length=20))
- secure_client_secret = EncryptedCharField(default_token_length=40, null=True)
- fully_migrated = BooleanField(default=False)
+ client_id = CharField(index=True, default=random_string_generator(length=20))
+ secure_client_secret = EncryptedCharField(default_token_length=40, null=True)
+ fully_migrated = BooleanField(default=False)
- # TODO(remove-unenc): This field is deprecated and should be removed soon.
- client_secret = deprecated_field(
- CharField(default=random_string_generator(length=40), null=True),
- ERTMigrationFlags.WRITE_OLD_FIELDS)
+ # TODO(remove-unenc): This field is deprecated and should be removed soon.
+ client_secret = deprecated_field(
+ CharField(default=random_string_generator(length=40), null=True),
+ ERTMigrationFlags.WRITE_OLD_FIELDS,
+ )
- redirect_uri = CharField()
- application_uri = CharField()
- organization = QuayUserField()
+ redirect_uri = CharField()
+ application_uri = CharField()
+ organization = QuayUserField()
- name = CharField()
- description = TextField(default='')
- avatar_email = CharField(null=True, column_name='gravatar_email')
+ name = CharField()
+ description = TextField(default="")
+ avatar_email = CharField(null=True, column_name="gravatar_email")
class OAuthAuthorizationCode(BaseModel):
- application = ForeignKeyField(OAuthApplication)
+ application = ForeignKeyField(OAuthApplication)
- # TODO(remove-unenc): This field is deprecated and should be removed soon.
- code = deprecated_field(
- CharField(index=True, unique=True, null=True),
- ERTMigrationFlags.WRITE_OLD_FIELDS)
+ # TODO(remove-unenc): This field is deprecated and should be removed soon.
+ code = deprecated_field(
+ CharField(index=True, unique=True, null=True),
+ ERTMigrationFlags.WRITE_OLD_FIELDS,
+ )
- code_name = CharField(index=True, unique=True)
- code_credential = CredentialField()
+ code_name = CharField(index=True, unique=True)
+ code_credential = CredentialField()
- scope = CharField()
- data = TextField() # Context for the code, such as the user
+ scope = CharField()
+ data = TextField() # Context for the code, such as the user
class OAuthAccessToken(BaseModel):
- uuid = CharField(default=uuid_generator, index=True)
- application = ForeignKeyField(OAuthApplication)
- authorized_user = QuayUserField()
- scope = CharField()
- token_name = CharField(index=True, unique=True)
- token_code = CredentialField()
+ uuid = CharField(default=uuid_generator, index=True)
+ application = ForeignKeyField(OAuthApplication)
+ authorized_user = QuayUserField()
+ scope = CharField()
+ token_name = CharField(index=True, unique=True)
+ token_code = CredentialField()
- # TODO(remove-unenc): This field is deprecated and should be removed soon.
- access_token = deprecated_field(
- CharField(index=True, null=True),
- ERTMigrationFlags.WRITE_OLD_FIELDS)
+ # TODO(remove-unenc): This field is deprecated and should be removed soon.
+ access_token = deprecated_field(
+ CharField(index=True, null=True), ERTMigrationFlags.WRITE_OLD_FIELDS
+ )
- token_type = CharField(default='Bearer')
- expires_at = DateTimeField()
- data = TextField() # This is context for which this token was generated, such as the user
+ token_type = CharField(default="Bearer")
+ expires_at = DateTimeField()
+ data = (
+ TextField()
+ ) # This is context for which this token was generated, such as the user
class NotificationKind(BaseModel):
- name = CharField(index=True, unique=True)
+ name = CharField(index=True, unique=True)
class Notification(BaseModel):
- uuid = CharField(default=uuid_generator, index=True)
- kind = ForeignKeyField(NotificationKind)
- target = QuayUserField(index=True, allows_robots=True)
- metadata_json = TextField(default='{}')
- created = DateTimeField(default=datetime.now, index=True)
- dismissed = BooleanField(default=False)
- lookup_path = CharField(null=True, index=True)
+ uuid = CharField(default=uuid_generator, index=True)
+ kind = ForeignKeyField(NotificationKind)
+ target = QuayUserField(index=True, allows_robots=True)
+ metadata_json = TextField(default="{}")
+ created = DateTimeField(default=datetime.now, index=True)
+ dismissed = BooleanField(default=False)
+ lookup_path = CharField(null=True, index=True)
class ExternalNotificationEvent(BaseModel):
- name = CharField(index=True, unique=True)
+ name = CharField(index=True, unique=True)
class ExternalNotificationMethod(BaseModel):
- name = CharField(index=True, unique=True)
+ name = CharField(index=True, unique=True)
class RepositoryNotification(BaseModel):
- uuid = CharField(default=uuid_generator, index=True)
- repository = ForeignKeyField(Repository)
- event = ForeignKeyField(ExternalNotificationEvent)
- method = ForeignKeyField(ExternalNotificationMethod)
- title = CharField(null=True)
- config_json = TextField()
- event_config_json = TextField(default='{}')
- number_of_failures = IntegerField(default=0)
+ uuid = CharField(default=uuid_generator, index=True)
+ repository = ForeignKeyField(Repository)
+ event = ForeignKeyField(ExternalNotificationEvent)
+ method = ForeignKeyField(ExternalNotificationMethod)
+ title = CharField(null=True)
+ config_json = TextField()
+ event_config_json = TextField(default="{}")
+ number_of_failures = IntegerField(default=0)
class RepositoryAuthorizedEmail(BaseModel):
- repository = ForeignKeyField(Repository)
- email = CharField()
- code = CharField(default=random_string_generator(), unique=True, index=True)
- confirmed = BooleanField(default=False)
+ repository = ForeignKeyField(Repository)
+ email = CharField()
+ code = CharField(default=random_string_generator(), unique=True, index=True)
+ confirmed = BooleanField(default=False)
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- # create a unique index on email and repository
- (('email', 'repository'), True),
- )
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = (
+ # create a unique index on email and repository
+ (("email", "repository"), True),
+ )
class BlobUpload(BaseModel):
- repository = ForeignKeyField(Repository)
- uuid = CharField(index=True, unique=True)
- byte_count = BigIntegerField(default=0)
- sha_state = ResumableSHA256Field(null=True, default=resumablehashlib.sha256)
- location = ForeignKeyField(ImageStorageLocation)
- storage_metadata = JSONField(null=True, default={})
- chunk_count = IntegerField(default=0)
- uncompressed_byte_count = BigIntegerField(null=True)
- created = DateTimeField(default=datetime.now, index=True)
- piece_sha_state = ResumableSHA1Field(null=True)
- piece_hashes = Base64BinaryField(null=True)
+ repository = ForeignKeyField(Repository)
+ uuid = CharField(index=True, unique=True)
+ byte_count = BigIntegerField(default=0)
+ sha_state = ResumableSHA256Field(null=True, default=resumablehashlib.sha256)
+ location = ForeignKeyField(ImageStorageLocation)
+ storage_metadata = JSONField(null=True, default={})
+ chunk_count = IntegerField(default=0)
+ uncompressed_byte_count = BigIntegerField(null=True)
+ created = DateTimeField(default=datetime.now, index=True)
+ piece_sha_state = ResumableSHA1Field(null=True)
+ piece_hashes = Base64BinaryField(null=True)
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- # create a unique index on email and repository
- (('repository', 'uuid'), True),
- )
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = (
+            # create a unique index on repository and uuid
+ (("repository", "uuid"), True),
+ )
class QuayService(BaseModel):
- name = CharField(index=True, unique=True)
+ name = CharField(index=True, unique=True)
class QuayRegion(BaseModel):
- name = CharField(index=True, unique=True)
+ name = CharField(index=True, unique=True)
class QuayRelease(BaseModel):
- service = ForeignKeyField(QuayService)
- version = CharField()
- region = ForeignKeyField(QuayRegion)
- reverted = BooleanField(default=False)
- created = DateTimeField(default=datetime.now, index=True)
+ service = ForeignKeyField(QuayService)
+ version = CharField()
+ region = ForeignKeyField(QuayRegion)
+ reverted = BooleanField(default=False)
+ created = DateTimeField(default=datetime.now, index=True)
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- # unique release per region
- (('service', 'version', 'region'), True),
-
- # get recent releases
- (('service', 'region', 'created'), False),
- )
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = (
+ # unique release per region
+ (("service", "version", "region"), True),
+ # get recent releases
+ (("service", "region", "created"), False),
+ )
class TorrentInfo(BaseModel):
- storage = ForeignKeyField(ImageStorage)
- piece_length = IntegerField()
- pieces = Base64BinaryField()
+ storage = ForeignKeyField(ImageStorage)
+ piece_length = IntegerField()
+ pieces = Base64BinaryField()
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- # we may want to compute the piece hashes multiple times with different piece lengths
- (('storage', 'piece_length'), True),
- )
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = (
+ # we may want to compute the piece hashes multiple times with different piece lengths
+ (("storage", "piece_length"), True),
+ )
class ServiceKeyApprovalType(Enum):
- SUPERUSER = 'Super User API'
- KEY_ROTATION = 'Key Rotation'
- AUTOMATIC = 'Automatic'
+ SUPERUSER = "Super User API"
+ KEY_ROTATION = "Key Rotation"
+ AUTOMATIC = "Automatic"
class ServiceKeyApproval(BaseModel):
- approver = QuayUserField(null=True)
- approval_type = CharField(index=True)
- approved_date = DateTimeField(default=datetime.utcnow)
- notes = TextField(default='')
+ approver = QuayUserField(null=True)
+ approval_type = CharField(index=True)
+ approved_date = DateTimeField(default=datetime.utcnow)
+ notes = TextField(default="")
class ServiceKey(BaseModel):
- name = CharField()
- kid = CharField(unique=True, index=True)
- service = CharField(index=True)
- jwk = JSONField()
- metadata = JSONField()
- created_date = DateTimeField(default=datetime.utcnow)
- expiration_date = DateTimeField(null=True)
- rotation_duration = IntegerField(null=True)
- approval = ForeignKeyField(ServiceKeyApproval, null=True)
+ name = CharField()
+ kid = CharField(unique=True, index=True)
+ service = CharField(index=True)
+ jwk = JSONField()
+ metadata = JSONField()
+ created_date = DateTimeField(default=datetime.utcnow)
+ expiration_date = DateTimeField(null=True)
+ rotation_duration = IntegerField(null=True)
+ approval = ForeignKeyField(ServiceKeyApproval, null=True)
class MediaType(BaseModel):
- """ MediaType is an enumeration of the possible formats of various objects in the data model.
+ """ MediaType is an enumeration of the possible formats of various objects in the data model.
"""
- name = CharField(index=True, unique=True)
+
+ name = CharField(index=True, unique=True)
class Messages(BaseModel):
- content = TextField()
- uuid = CharField(default=uuid_generator, max_length=36, index=True)
- severity = CharField(default='info', index=True)
- media_type = ForeignKeyField(MediaType)
+ content = TextField()
+ uuid = CharField(default=uuid_generator, max_length=36, index=True)
+ severity = CharField(default="info", index=True)
+ media_type = ForeignKeyField(MediaType)
class LabelSourceType(BaseModel):
- """ LabelSourceType is an enumeration of the possible sources for a label.
+ """ LabelSourceType is an enumeration of the possible sources for a label.
"""
- name = CharField(index=True, unique=True)
- mutable = BooleanField(default=False)
+
+ name = CharField(index=True, unique=True)
+ mutable = BooleanField(default=False)
class Label(BaseModel):
- """ Label represents user-facing metadata associated with another entry in the database (e.g. a
+ """ Label represents user-facing metadata associated with another entry in the database (e.g. a
Manifest).
"""
- uuid = CharField(default=uuid_generator, index=True, unique=True)
- key = CharField(index=True)
- value = TextField()
- media_type = EnumField(MediaType)
- source_type = EnumField(LabelSourceType)
+
+ uuid = CharField(default=uuid_generator, index=True, unique=True)
+ key = CharField(index=True)
+ value = TextField()
+ media_type = EnumField(MediaType)
+ source_type = EnumField(LabelSourceType)
class ApprBlob(BaseModel):
- """ ApprBlob represents a content-addressable object stored outside of the database.
+ """ ApprBlob represents a content-addressable object stored outside of the database.
"""
- digest = CharField(index=True, unique=True)
- media_type = EnumField(MediaType)
- size = BigIntegerField()
- uncompressed_size = BigIntegerField(null=True)
+
+ digest = CharField(index=True, unique=True)
+ media_type = EnumField(MediaType)
+ size = BigIntegerField()
+ uncompressed_size = BigIntegerField(null=True)
class ApprBlobPlacementLocation(BaseModel):
- """ ApprBlobPlacementLocation is an enumeration of the possible storage locations for ApprBlobs.
+ """ ApprBlobPlacementLocation is an enumeration of the possible storage locations for ApprBlobs.
"""
- name = CharField(index=True, unique=True)
+
+ name = CharField(index=True, unique=True)
class ApprBlobPlacement(BaseModel):
- """ ApprBlobPlacement represents the location of a Blob.
+ """ ApprBlobPlacement represents the location of a Blob.
"""
- blob = ForeignKeyField(ApprBlob)
- location = EnumField(ApprBlobPlacementLocation)
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- (('blob', 'location'), True),
- )
+ blob = ForeignKeyField(ApprBlob)
+ location = EnumField(ApprBlobPlacementLocation)
+
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = ((("blob", "location"), True),)
class ApprManifest(BaseModel):
- """ ApprManifest represents the metadata and collection of blobs that comprise an Appr image.
+ """ ApprManifest represents the metadata and collection of blobs that comprise an Appr image.
"""
- digest = CharField(index=True, unique=True)
- media_type = EnumField(MediaType)
- manifest_json = JSONField()
+
+ digest = CharField(index=True, unique=True)
+ media_type = EnumField(MediaType)
+ manifest_json = JSONField()
class ApprManifestBlob(BaseModel):
- """ ApprManifestBlob is a many-to-many relation table linking ApprManifests and ApprBlobs.
+ """ ApprManifestBlob is a many-to-many relation table linking ApprManifests and ApprBlobs.
"""
- manifest = ForeignKeyField(ApprManifest, index=True)
- blob = ForeignKeyField(ApprBlob, index=True)
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- (('manifest', 'blob'), True),
- )
+ manifest = ForeignKeyField(ApprManifest, index=True)
+ blob = ForeignKeyField(ApprBlob, index=True)
+
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = ((("manifest", "blob"), True),)
class ApprManifestList(BaseModel):
- """ ApprManifestList represents all of the various Appr manifests that compose an ApprTag.
+ """ ApprManifestList represents all of the various Appr manifests that compose an ApprTag.
"""
- digest = CharField(index=True, unique=True)
- manifest_list_json = JSONField()
- schema_version = CharField()
- media_type = EnumField(MediaType)
+ digest = CharField(index=True, unique=True)
+ manifest_list_json = JSONField()
+ schema_version = CharField()
+ media_type = EnumField(MediaType)
class ApprTagKind(BaseModel):
- """ ApprTagKind is a enumtable to reference tag kinds.
+    """ ApprTagKind is an enum table used to reference tag kinds.
"""
- name = CharField(index=True, unique=True)
+
+ name = CharField(index=True, unique=True)
class ApprTag(BaseModel):
- """ ApprTag represents a user-facing alias for referencing an ApprManifestList.
+ """ ApprTag represents a user-facing alias for referencing an ApprManifestList.
"""
- name = CharField()
- repository = ForeignKeyField(Repository)
- manifest_list = ForeignKeyField(ApprManifestList, null=True)
- lifetime_start = BigIntegerField(default=get_epoch_timestamp_ms)
- lifetime_end = BigIntegerField(null=True, index=True)
- hidden = BooleanField(default=False)
- reverted = BooleanField(default=False)
- protected = BooleanField(default=False)
- tag_kind = EnumField(ApprTagKind)
- linked_tag = ForeignKeyField('self', null=True, backref='tag_parents')
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- (('repository', 'name'), False),
- (('repository', 'name', 'hidden'), False),
- # This unique index prevents deadlocks when concurrently moving and deleting tags
- (('repository', 'name', 'lifetime_end'), True),
- )
+ name = CharField()
+ repository = ForeignKeyField(Repository)
+ manifest_list = ForeignKeyField(ApprManifestList, null=True)
+ lifetime_start = BigIntegerField(default=get_epoch_timestamp_ms)
+ lifetime_end = BigIntegerField(null=True, index=True)
+ hidden = BooleanField(default=False)
+ reverted = BooleanField(default=False)
+ protected = BooleanField(default=False)
+ tag_kind = EnumField(ApprTagKind)
+ linked_tag = ForeignKeyField("self", null=True, backref="tag_parents")
+
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = (
+ (("repository", "name"), False),
+ (("repository", "name", "hidden"), False),
+ # This unique index prevents deadlocks when concurrently moving and deleting tags
+ (("repository", "name", "lifetime_end"), True),
+ )
+
ApprChannel = ApprTag.alias()
class ApprManifestListManifest(BaseModel):
- """ ApprManifestListManifest is a many-to-many relation table linking ApprManifestLists and
+ """ ApprManifestListManifest is a many-to-many relation table linking ApprManifestLists and
ApprManifests.
"""
- manifest_list = ForeignKeyField(ApprManifestList, index=True)
- manifest = ForeignKeyField(ApprManifest, index=True)
- operating_system = CharField(null=True)
- architecture = CharField(null=True)
- platform_json = JSONField(null=True)
- media_type = EnumField(MediaType)
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- (('manifest_list', 'media_type'), False),
- )
+ manifest_list = ForeignKeyField(ApprManifestList, index=True)
+ manifest = ForeignKeyField(ApprManifest, index=True)
+ operating_system = CharField(null=True)
+ architecture = CharField(null=True)
+ platform_json = JSONField(null=True)
+ media_type = EnumField(MediaType)
+
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = ((("manifest_list", "media_type"), False),)
class AppSpecificAuthToken(BaseModel):
- """ AppSpecificAuthToken represents a token generated by a user for use with an external
+ """ AppSpecificAuthToken represents a token generated by a user for use with an external
application where putting the user's credentials, even encrypted, is deemed too risky.
"""
- user = QuayUserField()
- uuid = CharField(default=uuid_generator, max_length=36, index=True)
- title = CharField()
- token_name = CharField(index=True, unique=True, default=random_string_generator(60))
- token_secret = EncryptedCharField(default_token_length=60)
- # TODO(remove-unenc): This field is deprecated and should be removed soon.
- token_code = deprecated_field(
- CharField(default=random_string_generator(length=120), unique=True, index=True, null=True),
- ERTMigrationFlags.WRITE_OLD_FIELDS)
+ user = QuayUserField()
+ uuid = CharField(default=uuid_generator, max_length=36, index=True)
+ title = CharField()
+ token_name = CharField(index=True, unique=True, default=random_string_generator(60))
+ token_secret = EncryptedCharField(default_token_length=60)
- created = DateTimeField(default=datetime.now)
- expiration = DateTimeField(null=True)
- last_accessed = DateTimeField(null=True)
-
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- (('user', 'expiration'), False),
+ # TODO(remove-unenc): This field is deprecated and should be removed soon.
+ token_code = deprecated_field(
+ CharField(
+ default=random_string_generator(length=120),
+ unique=True,
+ index=True,
+ null=True,
+ ),
+ ERTMigrationFlags.WRITE_OLD_FIELDS,
)
+ created = DateTimeField(default=datetime.now)
+ expiration = DateTimeField(null=True)
+ last_accessed = DateTimeField(null=True)
+
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = ((("user", "expiration"), False),)
+
class Manifest(BaseModel):
- """ Manifest represents a single manifest under a repository. Within a repository,
+ """ Manifest represents a single manifest under a repository. Within a repository,
there can only be one manifest with the same digest.
"""
- repository = ForeignKeyField(Repository)
- digest = CharField(index=True)
- media_type = EnumField(MediaType)
- manifest_bytes = TextField()
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- (('repository', 'digest'), True),
- (('repository', 'media_type'), False),
- )
+ repository = ForeignKeyField(Repository)
+ digest = CharField(index=True)
+ media_type = EnumField(MediaType)
+ manifest_bytes = TextField()
+
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = (
+ (("repository", "digest"), True),
+ (("repository", "media_type"), False),
+ )
class TagKind(BaseModel):
- """ TagKind describes the various kinds of tags that can be found in the registry.
+ """ TagKind describes the various kinds of tags that can be found in the registry.
"""
- name = CharField(index=True, unique=True)
+
+ name = CharField(index=True, unique=True)
class Tag(BaseModel):
- """ Tag represents a user-facing alias for referencing a Manifest or as an alias to another tag.
+ """ Tag represents a user-facing alias for referencing a Manifest or as an alias to another tag.
"""
- name = CharField()
- repository = ForeignKeyField(Repository)
- manifest = ForeignKeyField(Manifest, null=True)
- lifetime_start_ms = BigIntegerField(default=get_epoch_timestamp_ms)
- lifetime_end_ms = BigIntegerField(null=True, index=True)
- hidden = BooleanField(default=False)
- reversion = BooleanField(default=False)
- tag_kind = EnumField(TagKind)
- linked_tag = ForeignKeyField('self', null=True, backref='tag_parents')
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- (('repository', 'name'), False),
- (('repository', 'name', 'hidden'), False),
- (('repository', 'name', 'tag_kind'), False),
+ name = CharField()
+ repository = ForeignKeyField(Repository)
+ manifest = ForeignKeyField(Manifest, null=True)
+ lifetime_start_ms = BigIntegerField(default=get_epoch_timestamp_ms)
+ lifetime_end_ms = BigIntegerField(null=True, index=True)
+ hidden = BooleanField(default=False)
+ reversion = BooleanField(default=False)
+ tag_kind = EnumField(TagKind)
+ linked_tag = ForeignKeyField("self", null=True, backref="tag_parents")
- (('repository', 'lifetime_start_ms'), False),
- (('repository', 'lifetime_end_ms'), False),
-
- # This unique index prevents deadlocks when concurrently moving and deleting tags
- (('repository', 'name', 'lifetime_end_ms'), True),
- )
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = (
+ (("repository", "name"), False),
+ (("repository", "name", "hidden"), False),
+ (("repository", "name", "tag_kind"), False),
+ (("repository", "lifetime_start_ms"), False),
+ (("repository", "lifetime_end_ms"), False),
+ # This unique index prevents deadlocks when concurrently moving and deleting tags
+ (("repository", "name", "lifetime_end_ms"), True),
+ )
class ManifestChild(BaseModel):
- """ ManifestChild represents a relationship between a manifest and its child manifest(s).
+ """ ManifestChild represents a relationship between a manifest and its child manifest(s).
Multiple manifests can share the same children. Note that since Manifests are stored
per-repository, the repository here is a bit redundant, but we do so to make cleanup easier.
"""
- repository = ForeignKeyField(Repository)
- manifest = ForeignKeyField(Manifest)
- child_manifest = ForeignKeyField(Manifest)
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- (('repository', 'manifest'), False),
- (('repository', 'child_manifest'), False),
- (('repository', 'manifest', 'child_manifest'), False),
- (('manifest', 'child_manifest'), True),
- )
+ repository = ForeignKeyField(Repository)
+ manifest = ForeignKeyField(Manifest)
+ child_manifest = ForeignKeyField(Manifest)
+
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = (
+ (("repository", "manifest"), False),
+ (("repository", "child_manifest"), False),
+ (("repository", "manifest", "child_manifest"), False),
+ (("manifest", "child_manifest"), True),
+ )
class ManifestLabel(BaseModel):
- """ ManifestLabel represents a label applied to a Manifest, within a repository.
+ """ ManifestLabel represents a label applied to a Manifest, within a repository.
Note that since Manifests are stored per-repository, the repository here is
a bit redundant, but we do so to make cleanup easier.
"""
- repository = ForeignKeyField(Repository, index=True)
- manifest = ForeignKeyField(Manifest)
- label = ForeignKeyField(Label)
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- (('manifest', 'label'), True),
- )
+ repository = ForeignKeyField(Repository, index=True)
+ manifest = ForeignKeyField(Manifest)
+ label = ForeignKeyField(Label)
+
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = ((("manifest", "label"), True),)
class ManifestBlob(BaseModel):
- """ ManifestBlob represents a blob that is used by a manifest. """
- repository = ForeignKeyField(Repository, index=True)
- manifest = ForeignKeyField(Manifest)
- blob = ForeignKeyField(ImageStorage)
+ """ ManifestBlob represents a blob that is used by a manifest. """
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- (('manifest', 'blob'), True),
- )
+ repository = ForeignKeyField(Repository, index=True)
+ manifest = ForeignKeyField(Manifest)
+ blob = ForeignKeyField(ImageStorage)
+
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = ((("manifest", "blob"), True),)
class ManifestLegacyImage(BaseModel):
- """ For V1-compatible manifests only, this table maps from the manifest to its associated
+ """ For V1-compatible manifests only, this table maps from the manifest to its associated
Docker image.
"""
- repository = ForeignKeyField(Repository, index=True)
- manifest = ForeignKeyField(Manifest, unique=True)
- image = ForeignKeyField(Image)
+
+ repository = ForeignKeyField(Repository, index=True)
+ manifest = ForeignKeyField(Manifest, unique=True)
+ image = ForeignKeyField(Image)
class TagManifest(BaseModel):
- """ TO BE DEPRECATED: The manifest for a tag. """
- tag = ForeignKeyField(RepositoryTag, unique=True)
- digest = CharField(index=True)
- json_data = TextField()
+ """ TO BE DEPRECATED: The manifest for a tag. """
+
+ tag = ForeignKeyField(RepositoryTag, unique=True)
+ digest = CharField(index=True)
+ json_data = TextField()
class TagManifestToManifest(BaseModel):
- """ NOTE: Only used for the duration of the migrations. """
- tag_manifest = ForeignKeyField(TagManifest, index=True, unique=True)
- manifest = ForeignKeyField(Manifest, index=True)
- broken = BooleanField(index=True, default=False)
+ """ NOTE: Only used for the duration of the migrations. """
+
+ tag_manifest = ForeignKeyField(TagManifest, index=True, unique=True)
+ manifest = ForeignKeyField(Manifest, index=True)
+ broken = BooleanField(index=True, default=False)
class TagManifestLabel(BaseModel):
- """ TO BE DEPRECATED: Mapping from a tag manifest to a label.
+ """ TO BE DEPRECATED: Mapping from a tag manifest to a label.
"""
- repository = ForeignKeyField(Repository, index=True)
- annotated = ForeignKeyField(TagManifest, index=True)
- label = ForeignKeyField(Label)
- class Meta:
- database = db
- read_only_config = read_only_config
- indexes = (
- (('annotated', 'label'), True),
- )
+ repository = ForeignKeyField(Repository, index=True)
+ annotated = ForeignKeyField(TagManifest, index=True)
+ label = ForeignKeyField(Label)
+
+ class Meta:
+ database = db
+ read_only_config = read_only_config
+ indexes = ((("annotated", "label"), True),)
class TagManifestLabelMap(BaseModel):
- """ NOTE: Only used for the duration of the migrations. """
- tag_manifest = ForeignKeyField(TagManifest, index=True)
- manifest = ForeignKeyField(Manifest, null=True, index=True)
+ """ NOTE: Only used for the duration of the migrations. """
- label = ForeignKeyField(Label, index=True)
+ tag_manifest = ForeignKeyField(TagManifest, index=True)
+ manifest = ForeignKeyField(Manifest, null=True, index=True)
- tag_manifest_label = ForeignKeyField(TagManifestLabel, index=True)
- manifest_label = ForeignKeyField(ManifestLabel, null=True, index=True)
+ label = ForeignKeyField(Label, index=True)
- broken_manifest = BooleanField(index=True, default=False)
+ tag_manifest_label = ForeignKeyField(TagManifestLabel, index=True)
+ manifest_label = ForeignKeyField(ManifestLabel, null=True, index=True)
+
+ broken_manifest = BooleanField(index=True, default=False)
class TagToRepositoryTag(BaseModel):
- """ NOTE: Only used for the duration of the migrations. """
- repository = ForeignKeyField(Repository, index=True)
- tag = ForeignKeyField(Tag, index=True, unique=True)
- repository_tag = ForeignKeyField(RepositoryTag, index=True, unique=True)
+ """ NOTE: Only used for the duration of the migrations. """
+
+ repository = ForeignKeyField(Repository, index=True)
+ tag = ForeignKeyField(Tag, index=True, unique=True)
+ repository_tag = ForeignKeyField(RepositoryTag, index=True, unique=True)
@unique
class RepoMirrorRuleType(IntEnum):
- """
+ """
Types of mirroring rules.
TAG_GLOB_CSV: Comma separated glob values (eg. "7.6,7.6-1.*")
"""
- TAG_GLOB_CSV = 1
+
+ TAG_GLOB_CSV = 1
class RepoMirrorRule(BaseModel):
- """
+ """
Determines how a given Repository should be mirrored.
"""
- uuid = CharField(default=uuid_generator, max_length=36, index=True)
- repository = ForeignKeyField(Repository, index=True)
- creation_date = DateTimeField(default=datetime.utcnow)
- rule_type = ClientEnumField(RepoMirrorRuleType, default=RepoMirrorRuleType.TAG_GLOB_CSV)
- rule_value = JSONField()
+ uuid = CharField(default=uuid_generator, max_length=36, index=True)
+ repository = ForeignKeyField(Repository, index=True)
+ creation_date = DateTimeField(default=datetime.utcnow)
- # Optional associations to allow the generation of a ruleset tree
- left_child = ForeignKeyField('self', null=True, backref='left_child')
- right_child = ForeignKeyField('self', null=True, backref='right_child')
+ rule_type = ClientEnumField(
+ RepoMirrorRuleType, default=RepoMirrorRuleType.TAG_GLOB_CSV
+ )
+ rule_value = JSONField()
+
+ # Optional associations to allow the generation of a ruleset tree
+ left_child = ForeignKeyField("self", null=True, backref="left_child")
+ right_child = ForeignKeyField("self", null=True, backref="right_child")
@unique
class RepoMirrorType(IntEnum):
- """
+ """
Types of repository mirrors.
"""
- PULL = 1 # Pull images from the external repo
+
+ PULL = 1 # Pull images from the external repo
@unique
class RepoMirrorStatus(IntEnum):
- """
+ """
Possible statuses of repository mirroring.
"""
- FAIL = -1
- NEVER_RUN = 0
- SUCCESS = 1
- SYNCING = 2
- SYNC_NOW = 3
+
+ FAIL = -1
+ NEVER_RUN = 0
+ SUCCESS = 1
+ SYNCING = 2
+ SYNC_NOW = 3
class RepoMirrorConfig(BaseModel):
- """
+ """
Represents a repository to be mirrored and any additional configuration
required to perform the mirroring.
"""
- repository = ForeignKeyField(Repository, index=True, unique=True, backref='mirror')
- creation_date = DateTimeField(default=datetime.utcnow)
- is_enabled = BooleanField(default=True)
- # Mirror Configuration
- mirror_type = ClientEnumField(RepoMirrorType, default=RepoMirrorType.PULL)
- internal_robot = QuayUserField(allows_robots=True, null=True, backref='mirrorpullrobot',
- robot_null_delete=True)
- external_reference = CharField()
- external_registry_username = EncryptedCharField(max_length=2048, null=True)
- external_registry_password = EncryptedCharField(max_length=2048, null=True)
- external_registry_config = JSONField(default={})
+ repository = ForeignKeyField(Repository, index=True, unique=True, backref="mirror")
+ creation_date = DateTimeField(default=datetime.utcnow)
+ is_enabled = BooleanField(default=True)
- # Worker Queuing
- sync_interval = IntegerField() # seconds between syncs
- sync_start_date = DateTimeField(null=True) # next start time
- sync_expiration_date = DateTimeField(null=True) # max duration
- sync_retries_remaining = IntegerField(default=3)
- sync_status = ClientEnumField(RepoMirrorStatus, default=RepoMirrorStatus.NEVER_RUN)
- sync_transaction_id = CharField(default=uuid_generator, max_length=36)
+ # Mirror Configuration
+ mirror_type = ClientEnumField(RepoMirrorType, default=RepoMirrorType.PULL)
+ internal_robot = QuayUserField(
+ allows_robots=True, null=True, backref="mirrorpullrobot", robot_null_delete=True
+ )
+ external_reference = CharField()
+ external_registry_username = EncryptedCharField(max_length=2048, null=True)
+ external_registry_password = EncryptedCharField(max_length=2048, null=True)
+ external_registry_config = JSONField(default={})
- # Tag-Matching Rules
- root_rule = ForeignKeyField(RepoMirrorRule)
+ # Worker Queuing
+ sync_interval = IntegerField() # seconds between syncs
+ sync_start_date = DateTimeField(null=True) # next start time
+ sync_expiration_date = DateTimeField(null=True) # max duration
+ sync_retries_remaining = IntegerField(default=3)
+ sync_status = ClientEnumField(RepoMirrorStatus, default=RepoMirrorStatus.NEVER_RUN)
+ sync_transaction_id = CharField(default=uuid_generator, max_length=36)
+
+ # Tag-Matching Rules
+ root_rule = ForeignKeyField(RepoMirrorRule)
-appr_classes = set([ApprTag, ApprTagKind, ApprBlobPlacementLocation, ApprManifestList,
- ApprManifestBlob, ApprBlob, ApprManifestListManifest, ApprManifest,
- ApprBlobPlacement])
-v22_classes = set([Manifest, ManifestLabel, ManifestBlob, ManifestLegacyImage, TagKind,
- ManifestChild, Tag])
-transition_classes = set([TagManifestToManifest, TagManifestLabelMap, TagToRepositoryTag])
+appr_classes = set(
+ [
+ ApprTag,
+ ApprTagKind,
+ ApprBlobPlacementLocation,
+ ApprManifestList,
+ ApprManifestBlob,
+ ApprBlob,
+ ApprManifestListManifest,
+ ApprManifest,
+ ApprBlobPlacement,
+ ]
+)
+v22_classes = set(
+ [
+ Manifest,
+ ManifestLabel,
+ ManifestBlob,
+ ManifestLegacyImage,
+ TagKind,
+ ManifestChild,
+ Tag,
+ ]
+)
+transition_classes = set(
+ [TagManifestToManifest, TagManifestLabelMap, TagToRepositoryTag]
+)
-is_model = lambda x: inspect.isclass(x) and issubclass(x, BaseModel) and x is not BaseModel
+is_model = (
+ lambda x: inspect.isclass(x) and issubclass(x, BaseModel) and x is not BaseModel
+)
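+# Build the list of all model classes (BaseModel subclasses) defined in this module via introspection.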
all_models = [model[1] for model in inspect.getmembers(sys.modules[__name__], is_model)]
diff --git a/data/encryption.py b/data/encryption.py
index 429f09827..199373ca1 100644
--- a/data/encryption.py
+++ b/data/encryption.py
@@ -7,81 +7,86 @@ from cryptography.hazmat.primitives.ciphers.aead import AESCCM
from util.security.secret import convert_secret_key
+
class DecryptionFailureException(Exception):
- """ Exception raised if a field could not be decrypted. """
+ """ Exception raised if a field could not be decrypted. """
-EncryptionVersion = namedtuple('EncryptionVersion', ['prefix', 'encrypt', 'decrypt'])
+EncryptionVersion = namedtuple("EncryptionVersion", ["prefix", "encrypt", "decrypt"])
logger = logging.getLogger(__name__)
-_SEPARATOR = '$$'
+_SEPARATOR = "$$"
AES_CCM_NONCE_LENGTH = 13
def _encrypt_ccm(secret_key, value, field_max_length=None):
- aesccm = AESCCM(secret_key)
- nonce = os.urandom(AES_CCM_NONCE_LENGTH)
- ct = aesccm.encrypt(nonce, value.encode('utf-8'), None)
- encrypted = base64.b64encode(nonce + ct)
- if field_max_length:
- msg = 'Tried to encode a value too large for this field'
- assert (len(encrypted) + _RESERVED_FIELD_SPACE) <= field_max_length, msg
+ aesccm = AESCCM(secret_key)
+ nonce = os.urandom(AES_CCM_NONCE_LENGTH)
+ ct = aesccm.encrypt(nonce, value.encode("utf-8"), None)
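+ # Prepend the nonce to the ciphertext so decryption can recover it, then base64-encode for text-safe storage.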
+ encrypted = base64.b64encode(nonce + ct)
+ if field_max_length:
+ msg = "Tried to encode a value too large for this field"
+ assert (len(encrypted) + _RESERVED_FIELD_SPACE) <= field_max_length, msg
- return encrypted
+ return encrypted
def _decrypt_ccm(secret_key, value):
- aesccm = AESCCM(secret_key)
- try:
- decoded = base64.b64decode(value)
- nonce = decoded[:AES_CCM_NONCE_LENGTH]
- ct = decoded[AES_CCM_NONCE_LENGTH:]
- decrypted = aesccm.decrypt(nonce, ct, None)
- return decrypted.decode('utf-8')
- except Exception:
- logger.exception('Got exception when trying to decrypt value `%s`', value)
- raise DecryptionFailureException()
+ aesccm = AESCCM(secret_key)
+ try:
+ decoded = base64.b64decode(value)
+ nonce = decoded[:AES_CCM_NONCE_LENGTH]
+ ct = decoded[AES_CCM_NONCE_LENGTH:]
+ decrypted = aesccm.decrypt(nonce, ct, None)
+ return decrypted.decode("utf-8")
+ except Exception:
+ logger.exception("Got exception when trying to decrypt value `%s`", value)
+ raise DecryptionFailureException()
# Defines the versions of encryptions we support. This will allow us to upgrade to newer encryption
# protocols (fairly seamlessly) if need be in the future.
-_VERSIONS = {
- 'v0': EncryptionVersion('v0', _encrypt_ccm, _decrypt_ccm),
-}
+_VERSIONS = {"v0": EncryptionVersion("v0", _encrypt_ccm, _decrypt_ccm)}
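+# Space reserved in each encrypted column for the version prefix and separator prepended to the payload.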
_RESERVED_FIELD_SPACE = len(_SEPARATOR) + max([len(k) for k in _VERSIONS.keys()])
class FieldEncrypter(object):
- """ Helper object for defining how fields are encrypted and decrypted between the database
+ """ Helper object for defining how fields are encrypted and decrypted between the database
and the application.
"""
- def __init__(self, secret_key, version='v0'):
- # NOTE: secret_key will be None when the system is being first initialized, so we allow that
- # case here, but make sure to assert that it is *not* None below if any encryption is actually
- # needed.
- self._secret_key = convert_secret_key(secret_key) if secret_key is not None else None
- self._encryption_version = _VERSIONS[version]
- def encrypt_value(self, value, field_max_length=None):
- """ Encrypts the value using the current version of encryption. """
- assert self._secret_key is not None
- encrypted_value = self._encryption_version.encrypt(self._secret_key, value, field_max_length)
- return '%s%s%s' % (self._encryption_version.prefix, _SEPARATOR, encrypted_value)
+ def __init__(self, secret_key, version="v0"):
+ # NOTE: secret_key will be None when the system is being first initialized, so we allow that
+ # case here, but make sure to assert that it is *not* None below if any encryption is actually
+ # needed.
+ self._secret_key = (
+ convert_secret_key(secret_key) if secret_key is not None else None
+ )
+ self._encryption_version = _VERSIONS[version]
- def decrypt_value(self, value):
- """ Decrypts the value, returning it. If the value cannot be decrypted
+ def encrypt_value(self, value, field_max_length=None):
+ """ Encrypts the value using the current version of encryption. """
+ assert self._secret_key is not None
+ encrypted_value = self._encryption_version.encrypt(
+ self._secret_key, value, field_max_length
+ )
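+ # Stored values take the form "<version prefix>$$<base64(nonce + ciphertext)>".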
+ return "%s%s%s" % (self._encryption_version.prefix, _SEPARATOR, encrypted_value)
+
+ def decrypt_value(self, value):
+ """ Decrypts the value, returning it. If the value cannot be decrypted
raises a DecryptionFailureException.
"""
- assert self._secret_key is not None
- if _SEPARATOR not in value:
- raise DecryptionFailureException('Invalid encrypted value')
+ assert self._secret_key is not None
+ if _SEPARATOR not in value:
+ raise DecryptionFailureException("Invalid encrypted value")
- version_prefix, data = value.split(_SEPARATOR, 1)
- if version_prefix not in _VERSIONS:
- raise DecryptionFailureException('Unknown version prefix %s' % version_prefix)
-
- return _VERSIONS[version_prefix].decrypt(self._secret_key, data)
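+ # The version prefix selects which registered encryption scheme to use for decryption.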
+ version_prefix, data = value.split(_SEPARATOR, 1)
+ if version_prefix not in _VERSIONS:
+ raise DecryptionFailureException(
+ "Unknown version prefix %s" % version_prefix
+ )
+ return _VERSIONS[version_prefix].decrypt(self._secret_key, data)
diff --git a/data/fields.py b/data/fields.py
index c79a7e6bd..e72fddd93 100644
--- a/data/fields.py
+++ b/data/fields.py
@@ -12,176 +12,182 @@ from data.text import prefix_search
def random_string(length=16):
- random = SystemRandom()
- return ''.join([random.choice(string.ascii_uppercase + string.digits)
- for _ in range(length)])
+ random = SystemRandom()
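+ # SystemRandom is backed by os.urandom, so the generated string is suitable for tokens and secrets.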
+ return "".join(
+ [random.choice(string.ascii_uppercase + string.digits) for _ in range(length)]
+ )
class _ResumableSHAField(TextField):
- def _create_sha(self):
- raise NotImplementedError
+ def _create_sha(self):
+ raise NotImplementedError
- def db_value(self, value):
- if value is None:
- return None
+ def db_value(self, value):
+ if value is None:
+ return None
- sha_state = value.state()
+ sha_state = value.state()
- # One of the fields is a byte string, let's base64 encode it to make sure
- # we can store and fetch it regardless of default collocation.
- sha_state[3] = base64.b64encode(sha_state[3])
+ # One of the fields is a byte string, let's base64 encode it to make sure
+ # we can store and fetch it regardless of default collation.
+ sha_state[3] = base64.b64encode(sha_state[3])
- return json.dumps(sha_state)
+ return json.dumps(sha_state)
- def python_value(self, value):
- if value is None:
- return None
+ def python_value(self, value):
+ if value is None:
+ return None
- sha_state = json.loads(value)
+ sha_state = json.loads(value)
- # We need to base64 decode the data bytestring.
- sha_state[3] = base64.b64decode(sha_state[3])
- to_resume = self._create_sha()
- to_resume.set_state(sha_state)
- return to_resume
+ # We need to base64 decode the data bytestring.
+ sha_state[3] = base64.b64decode(sha_state[3])
+ to_resume = self._create_sha()
+ to_resume.set_state(sha_state)
+ return to_resume
class ResumableSHA256Field(_ResumableSHAField):
- def _create_sha(self):
- return resumablehashlib.sha256()
+ def _create_sha(self):
+ return resumablehashlib.sha256()
class ResumableSHA1Field(_ResumableSHAField):
- def _create_sha(self):
- return resumablehashlib.sha1()
+ def _create_sha(self):
+ return resumablehashlib.sha1()
class JSONField(TextField):
- def db_value(self, value):
- return json.dumps(value)
+ def db_value(self, value):
+ return json.dumps(value)
- def python_value(self, value):
- if value is None or value == "":
- return {}
- return json.loads(value)
+ def python_value(self, value):
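+ # NULL or empty-string database values deserialize to an empty dict.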
+ if value is None or value == "":
+ return {}
+ return json.loads(value)
class Base64BinaryField(TextField):
- def db_value(self, value):
- if value is None:
- return None
- return base64.b64encode(value)
+ def db_value(self, value):
+ if value is None:
+ return None
+ return base64.b64encode(value)
- def python_value(self, value):
- if value is None:
- return None
- return base64.b64decode(value)
+ def python_value(self, value):
+ if value is None:
+ return None
+ return base64.b64decode(value)
class DecryptedValue(object):
- """ Wrapper around an already decrypted value to be placed into an encrypted field. """
- def __init__(self, decrypted_value):
- assert decrypted_value is not None
- self.value = decrypted_value
+ """ Wrapper around an already decrypted value to be placed into an encrypted field. """
- def decrypt(self):
- return self.value
+ def __init__(self, decrypted_value):
+ assert decrypted_value is not None
+ self.value = decrypted_value
- def matches(self, unencrypted_value):
- """ Returns whether the value of this field matches the unencrypted_value. """
- return self.decrypt() == unencrypted_value
+ def decrypt(self):
+ return self.value
+
+ def matches(self, unencrypted_value):
+ """ Returns whether the value of this field matches the unencrypted_value. """
+ return self.decrypt() == unencrypted_value
class LazyEncryptedValue(object):
- """ Wrapper around an encrypted value in an encrypted field. Will decrypt lazily. """
- def __init__(self, encrypted_value, field):
- self.encrypted_value = encrypted_value
- self._field = field
+ """ Wrapper around an encrypted value in an encrypted field. Will decrypt lazily. """
- def decrypt(self):
- """ Decrypts the value. """
- return self._field.model._meta.encrypter.decrypt_value(self.encrypted_value)
+ def __init__(self, encrypted_value, field):
+ self.encrypted_value = encrypted_value
+ self._field = field
- def matches(self, unencrypted_value):
- """ Returns whether the value of this field matches the unencrypted_value. """
- return self.decrypt() == unencrypted_value
+ def decrypt(self):
+ """ Decrypts the value. """
+ return self._field.model._meta.encrypter.decrypt_value(self.encrypted_value)
- def __eq__(self, _):
- raise Exception('Disallowed operation; use `matches`')
+ def matches(self, unencrypted_value):
+ """ Returns whether the value of this field matches the unencrypted_value. """
+ return self.decrypt() == unencrypted_value
- def __mod__(self, _):
- raise Exception('Disallowed operation; use `matches`')
+ def __eq__(self, _):
+ raise Exception("Disallowed operation; use `matches`")
- def __pow__(self, _):
- raise Exception('Disallowed operation; use `matches`')
+ def __mod__(self, _):
+ raise Exception("Disallowed operation; use `matches`")
- def __contains__(self, _):
- raise Exception('Disallowed operation; use `matches`')
+ def __pow__(self, _):
+ raise Exception("Disallowed operation; use `matches`")
- def contains(self, _):
- raise Exception('Disallowed operation; use `matches`')
+ def __contains__(self, _):
+ raise Exception("Disallowed operation; use `matches`")
- def startswith(self, _):
- raise Exception('Disallowed operation; use `matches`')
+ def contains(self, _):
+ raise Exception("Disallowed operation; use `matches`")
- def endswith(self, _):
- raise Exception('Disallowed operation; use `matches`')
+ def startswith(self, _):
+ raise Exception("Disallowed operation; use `matches`")
+
+ def endswith(self, _):
+ raise Exception("Disallowed operation; use `matches`")
def _add_encryption(field_class, requires_length_check=True):
- """ Adds support for encryption and decryption to the given field class. """
- class indexed_class(field_class):
- def __init__(self, default_token_length=None, *args, **kwargs):
- def _generate_default():
- return DecryptedValue(random_string(default_token_length))
+ """ Adds support for encryption and decryption to the given field class. """
- if default_token_length is not None:
- kwargs['default'] = _generate_default
+ class indexed_class(field_class):
+ def __init__(self, default_token_length=None, *args, **kwargs):
+ def _generate_default():
+ return DecryptedValue(random_string(default_token_length))
- field_class.__init__(self, *args, **kwargs)
- assert not self.index
+ if default_token_length is not None:
+ kwargs["default"] = _generate_default
- def db_value(self, value):
- if value is None:
- return None
+ field_class.__init__(self, *args, **kwargs)
+ assert not self.index
- if isinstance(value, LazyEncryptedValue):
- return value.encrypted_value
+ def db_value(self, value):
+ if value is None:
+ return None
- if isinstance(value, DecryptedValue):
- value = value.value
+ if isinstance(value, LazyEncryptedValue):
+ return value.encrypted_value
- meta = self.model._meta
- return meta.encrypter.encrypt_value(value, self.max_length if requires_length_check else None)
+ if isinstance(value, DecryptedValue):
+ value = value.value
- def python_value(self, value):
- if value is None:
- return None
+ meta = self.model._meta
+ return meta.encrypter.encrypt_value(
+ value, self.max_length if requires_length_check else None
+ )
- return LazyEncryptedValue(value, self)
+ def python_value(self, value):
+ if value is None:
+ return None
- def __eq__(self, _):
- raise Exception('Disallowed operation; use `matches`')
+ return LazyEncryptedValue(value, self)
- def __mod__(self, _):
- raise Exception('Disallowed operation; use `matches`')
+ def __eq__(self, _):
+ raise Exception("Disallowed operation; use `matches`")
- def __pow__(self, _):
- raise Exception('Disallowed operation; use `matches`')
+ def __mod__(self, _):
+ raise Exception("Disallowed operation; use `matches`")
- def __contains__(self, _):
- raise Exception('Disallowed operation; use `matches`')
+ def __pow__(self, _):
+ raise Exception("Disallowed operation; use `matches`")
- def contains(self, _):
- raise Exception('Disallowed operation; use `matches`')
+ def __contains__(self, _):
+ raise Exception("Disallowed operation; use `matches`")
- def startswith(self, _):
- raise Exception('Disallowed operation; use `matches`')
+ def contains(self, _):
+ raise Exception("Disallowed operation; use `matches`")
- def endswith(self, _):
- raise Exception('Disallowed operation; use `matches`')
+ def startswith(self, _):
+ raise Exception("Disallowed operation; use `matches`")
- return indexed_class
+ def endswith(self, _):
+ raise Exception("Disallowed operation; use `matches`")
+
+ return indexed_class
EncryptedCharField = _add_encryption(CharField)
@@ -189,61 +195,60 @@ EncryptedTextField = _add_encryption(TextField, requires_length_check=False)
class EnumField(SmallIntegerField):
- def __init__(self, enum_type, *args, **kwargs):
- kwargs.pop('index', None)
+ def __init__(self, enum_type, *args, **kwargs):
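+ # Enum columns are always indexed; any caller-supplied "index" kwarg is discarded.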
+ kwargs.pop("index", None)
- super(EnumField, self).__init__(index=True, *args, **kwargs)
- self.enum_type = enum_type
+ super(EnumField, self).__init__(index=True, *args, **kwargs)
+ self.enum_type = enum_type
- def db_value(self, value):
- """Convert the python value for storage in the database."""
- return int(value.value)
+ def db_value(self, value):
+ """Convert the python value for storage in the database."""
+ return int(value.value)
- def python_value(self, value):
- """Convert the database value to a pythonic value."""
- return self.enum_type(value) if value is not None else None
+ def python_value(self, value):
+ """Convert the database value to a pythonic value."""
+ return self.enum_type(value) if value is not None else None
- def clone_base(self, **kwargs):
- return super(EnumField, self).clone_base(
- enum_type=self.enum_type,
- **kwargs)
+ def clone_base(self, **kwargs):
+ return super(EnumField, self).clone_base(enum_type=self.enum_type, **kwargs)
def _add_fulltext(field_class):
- """ Adds support for full text indexing and lookup to the given field class. """
- class indexed_class(field_class):
- # Marker used by SQLAlchemy translation layer to add the proper index for full text searching.
- __fulltext__ = True
+ """ Adds support for full text indexing and lookup to the given field class. """
- def __init__(self, match_function, *args, **kwargs):
- field_class.__init__(self, *args, **kwargs)
- self.match_function = match_function
+ class indexed_class(field_class):
+ # Marker used by SQLAlchemy translation layer to add the proper index for full text searching.
+ __fulltext__ = True
- def match(self, query):
- return self.match_function(self, query)
+ def __init__(self, match_function, *args, **kwargs):
+ field_class.__init__(self, *args, **kwargs)
+ self.match_function = match_function
- def match_prefix(self, query):
- return prefix_search(self, query)
+ def match(self, query):
+ return self.match_function(self, query)
- def __mod__(self, _):
- raise Exception('Unsafe operation: Use `match` or `match_prefix`')
+ def match_prefix(self, query):
+ return prefix_search(self, query)
- def __pow__(self, _):
- raise Exception('Unsafe operation: Use `match` or `match_prefix`')
+ def __mod__(self, _):
+ raise Exception("Unsafe operation: Use `match` or `match_prefix`")
- def __contains__(self, _):
- raise Exception('Unsafe operation: Use `match` or `match_prefix`')
+ def __pow__(self, _):
+ raise Exception("Unsafe operation: Use `match` or `match_prefix`")
- def contains(self, _):
- raise Exception('Unsafe operation: Use `match` or `match_prefix`')
+ def __contains__(self, _):
+ raise Exception("Unsafe operation: Use `match` or `match_prefix`")
- def startswith(self, _):
- raise Exception('Unsafe operation: Use `match` or `match_prefix`')
+ def contains(self, _):
+ raise Exception("Unsafe operation: Use `match` or `match_prefix`")
- def endswith(self, _):
- raise Exception('Unsafe operation: Use `match` or `match_prefix`')
+ def startswith(self, _):
+ raise Exception("Unsafe operation: Use `match` or `match_prefix`")
- return indexed_class
+ def endswith(self, _):
+ raise Exception("Unsafe operation: Use `match` or `match_prefix`")
+
+ return indexed_class
FullIndexedCharField = _add_fulltext(CharField)
@@ -251,47 +256,51 @@ FullIndexedTextField = _add_fulltext(TextField)
class Credential(object):
- """ Credential represents a hashed credential. """
- def __init__(self, hashed):
- self.hashed = hashed
+ """ Credential represents a hashed credential. """
- def matches(self, value):
- """ Returns true if this credential matches the unhashed value given. """
- return bcrypt.hashpw(value.encode('utf-8'), self.hashed) == self.hashed
+ def __init__(self, hashed):
+ self.hashed = hashed
- @classmethod
- def from_string(cls, string_value):
- """ Returns a Credential object from an unhashed string value. """
- return Credential(bcrypt.hashpw(string_value.encode('utf-8'), bcrypt.gensalt()))
+ def matches(self, value):
+ """ Returns true if this credential matches the unhashed value given. """
+ return bcrypt.hashpw(value.encode("utf-8"), self.hashed) == self.hashed
- @classmethod
- def generate(cls, length=20):
- """ Generates a new credential and returns it, along with its unhashed form. """
- token = random_string(length)
- return Credential.from_string(token), token
+ @classmethod
+ def from_string(cls, string_value):
+ """ Returns a Credential object from an unhashed string value. """
+ return Credential(bcrypt.hashpw(string_value.encode("utf-8"), bcrypt.gensalt()))
+
+ @classmethod
+ def generate(cls, length=20):
+ """ Generates a new credential and returns it, along with its unhashed form. """
+ token = random_string(length)
+ return Credential.from_string(token), token
class CredentialField(CharField):
- """ A character field that stores crytographically hashed credentials that should never be
+ """ A character field that stores cryptographically hashed credentials that should never be
available to the user in plaintext after initial creation. This field automatically
provides verification.
"""
- def __init__(self, *args, **kwargs):
- CharField.__init__(self, *args, **kwargs)
- assert 'default' not in kwargs
- assert not self.index
- def db_value(self, value):
- if value is None:
- return None
+ def __init__(self, *args, **kwargs):
+ CharField.__init__(self, *args, **kwargs)
+ assert "default" not in kwargs
+ assert not self.index
- if isinstance(value, basestring):
- raise Exception('A string cannot be given to a CredentialField; please wrap in a Credential')
+ def db_value(self, value):
+ if value is None:
+ return None
- return value.hashed
+ if isinstance(value, basestring):
+ raise Exception(
+ "A string cannot be given to a CredentialField; please wrap in a Credential"
+ )
- def python_value(self, value):
- if value is None:
- return None
+ return value.hashed
- return Credential(value)
+ def python_value(self, value):
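+ # Resume pagination with search_after, using the (datetime, random_id) sort keys as the cursor.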
+ if value is None:
+ return None
+
+ return Credential(value)
diff --git a/data/logs_model/__init__.py b/data/logs_model/__init__.py
index be8cc9402..75c1023f3 100644
--- a/data/logs_model/__init__.py
+++ b/data/logs_model/__init__.py
@@ -8,57 +8,61 @@ logger = logging.getLogger(__name__)
def _transition_model(*args, **kwargs):
- return CombinedLogsModel(
- DocumentLogsModel(*args, **kwargs),
- TableLogsModel(*args, **kwargs),
- )
+ return CombinedLogsModel(
+ DocumentLogsModel(*args, **kwargs), TableLogsModel(*args, **kwargs)
+ )
_LOG_MODELS = {
- 'database': TableLogsModel,
- 'transition_reads_both_writes_es': _transition_model,
- 'elasticsearch': DocumentLogsModel,
+ "database": TableLogsModel,
+ "transition_reads_both_writes_es": _transition_model,
+ "elasticsearch": DocumentLogsModel,
}
-_PULL_LOG_KINDS = {'pull_repo', 'repo_verb'}
+_PULL_LOG_KINDS = {"pull_repo", "repo_verb"}
+
class LogsModelProxy(object):
- def __init__(self):
- self._model = None
+ def __init__(self):
+ self._model = None
- def initialize(self, model):
- self._model = model
- logger.info('===============================')
- logger.info('Using logs model `%s`', self._model)
- logger.info('===============================')
+ def initialize(self, model):
+ self._model = model
+ logger.info("===============================")
+ logger.info("Using logs model `%s`", self._model)
+ logger.info("===============================")
- def __getattr__(self, attr):
- if not self._model:
- raise AttributeError("LogsModelProxy is not initialized")
- return getattr(self._model, attr)
+ def __getattr__(self, attr):
+ if not self._model:
+ raise AttributeError("LogsModelProxy is not initialized")
+ return getattr(self._model, attr)
logs_model = LogsModelProxy()
def configure(app_config):
- logger.debug('Configuring log lodel')
- model_name = app_config.get('LOGS_MODEL', 'database')
- model_config = app_config.get('LOGS_MODEL_CONFIG', {})
+ logger.debug("Configuring log model")
+ model_name = app_config.get("LOGS_MODEL", "database")
+ model_config = app_config.get("LOGS_MODEL_CONFIG", {})
- def should_skip_logging(kind_name, namespace_name, is_free_namespace):
- if namespace_name and namespace_name in app_config.get('DISABLED_FOR_AUDIT_LOGS', {}):
- return True
+ def should_skip_logging(kind_name, namespace_name, is_free_namespace):
+ if namespace_name and namespace_name in app_config.get(
+ "DISABLED_FOR_AUDIT_LOGS", {}
+ ):
+ return True
- if kind_name in _PULL_LOG_KINDS:
- if namespace_name and namespace_name in app_config.get('DISABLED_FOR_PULL_LOGS', {}):
- return True
+ if kind_name in _PULL_LOG_KINDS:
+ if namespace_name and namespace_name in app_config.get(
+ "DISABLED_FOR_PULL_LOGS", {}
+ ):
+ return True
- if app_config.get('FEATURE_DISABLE_PULL_LOGS_FOR_FREE_NAMESPACES'):
- if is_free_namespace:
- return True
+ if app_config.get("FEATURE_DISABLE_PULL_LOGS_FOR_FREE_NAMESPACES"):
+ if is_free_namespace:
+ return True
- return False
+ return False
- model_config['should_skip_logging'] = should_skip_logging
- logs_model.initialize(_LOG_MODELS[model_name](**model_config))
+ model_config["should_skip_logging"] = should_skip_logging
+ logs_model.initialize(_LOG_MODELS[model_name](**model_config))
diff --git a/data/logs_model/combined_model.py b/data/logs_model/combined_model.py
index 735101601..ea62ff7a4 100644
--- a/data/logs_model/combined_model.py
+++ b/data/logs_model/combined_model.py
@@ -9,124 +9,201 @@ logger = logging.getLogger(__name__)
def _merge_aggregated_log_counts(*args):
- """ Merge two lists of AggregatedLogCount based on the value of their kind_id and datetime.
+ """ Merge two lists of AggregatedLogCount based on the value of their kind_id and datetime.
"""
- matching_keys = {}
- aggregated_log_counts_list = itertools.chain.from_iterable(args)
+ matching_keys = {}
+ aggregated_log_counts_list = itertools.chain.from_iterable(args)
- def canonical_key_from_kind_date_tuple(kind_id, dt):
- """ Return a comma separated key from an AggregatedLogCount's kind_id and datetime. """
- return str(kind_id) + ',' + str(dt)
+ def canonical_key_from_kind_date_tuple(kind_id, dt):
+ """ Return a comma separated key from an AggregatedLogCount's kind_id and datetime. """
+ return str(kind_id) + "," + str(dt)
- for kind_id, count, dt in aggregated_log_counts_list:
- kind_date_key = canonical_key_from_kind_date_tuple(kind_id, dt)
- if kind_date_key in matching_keys:
- existing_count = matching_keys[kind_date_key][2]
- matching_keys[kind_date_key] = (kind_id, dt, existing_count + count)
- else:
- matching_keys[kind_date_key] = (kind_id, dt, count)
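+ # Sum counts that share the same (kind_id, datetime) key across the merged lists.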
+ for kind_id, count, dt in aggregated_log_counts_list:
+ kind_date_key = canonical_key_from_kind_date_tuple(kind_id, dt)
+ if kind_date_key in matching_keys:
+ existing_count = matching_keys[kind_date_key][2]
+ matching_keys[kind_date_key] = (kind_id, dt, existing_count + count)
+ else:
+ matching_keys[kind_date_key] = (kind_id, dt, count)
- return [AggregatedLogCount(kind_id, count, dt) for (kind_id, dt, count) in matching_keys.values()]
+ return [
+ AggregatedLogCount(kind_id, count, dt)
+ for (kind_id, dt, count) in matching_keys.values()
+ ]
class CombinedLogsModel(SharedModel, ActionLogsDataInterface):
- """
+ """
CombinedLogsModel implements the data model that logs to the first logs model and reads from
both.
"""
- def __init__(self, read_write_logs_model, read_only_logs_model):
- self.read_write_logs_model = read_write_logs_model
- self.read_only_logs_model = read_only_logs_model
+ def __init__(self, read_write_logs_model, read_only_logs_model):
+ self.read_write_logs_model = read_write_logs_model
+ self.read_only_logs_model = read_only_logs_model
- def log_action(self, kind_name, namespace_name=None, performer=None, ip=None, metadata=None,
- repository=None, repository_name=None, timestamp=None, is_free_namespace=False):
- return self.read_write_logs_model.log_action(kind_name, namespace_name, performer, ip, metadata,
- repository, repository_name, timestamp,
- is_free_namespace)
+ def log_action(
+ self,
+ kind_name,
+ namespace_name=None,
+ performer=None,
+ ip=None,
+ metadata=None,
+ repository=None,
+ repository_name=None,
+ timestamp=None,
+ is_free_namespace=False,
+ ):
+ return self.read_write_logs_model.log_action(
+ kind_name,
+ namespace_name,
+ performer,
+ ip,
+ metadata,
+ repository,
+ repository_name,
+ timestamp,
+ is_free_namespace,
+ )
- def count_repository_actions(self, repository, day):
- rw_count = self.read_write_logs_model.count_repository_actions(repository, day)
- ro_count = self.read_only_logs_model.count_repository_actions(repository, day)
- return rw_count + ro_count
+ def count_repository_actions(self, repository, day):
+ rw_count = self.read_write_logs_model.count_repository_actions(repository, day)
+ ro_count = self.read_only_logs_model.count_repository_actions(repository, day)
+ return rw_count + ro_count
- def get_aggregated_log_counts(self, start_datetime, end_datetime, performer_name=None,
- repository_name=None, namespace_name=None, filter_kinds=None):
- rw_model = self.read_write_logs_model
- ro_model = self.read_only_logs_model
- rw_count = rw_model.get_aggregated_log_counts(start_datetime, end_datetime,
- performer_name=performer_name,
- repository_name=repository_name,
- namespace_name=namespace_name,
- filter_kinds=filter_kinds)
- ro_count = ro_model.get_aggregated_log_counts(start_datetime, end_datetime,
- performer_name=performer_name,
- repository_name=repository_name,
- namespace_name=namespace_name,
- filter_kinds=filter_kinds)
- return _merge_aggregated_log_counts(rw_count, ro_count)
+ def get_aggregated_log_counts(
+ self,
+ start_datetime,
+ end_datetime,
+ performer_name=None,
+ repository_name=None,
+ namespace_name=None,
+ filter_kinds=None,
+ ):
+ rw_model = self.read_write_logs_model
+ ro_model = self.read_only_logs_model
+ rw_count = rw_model.get_aggregated_log_counts(
+ start_datetime,
+ end_datetime,
+ performer_name=performer_name,
+ repository_name=repository_name,
+ namespace_name=namespace_name,
+ filter_kinds=filter_kinds,
+ )
+ ro_count = ro_model.get_aggregated_log_counts(
+ start_datetime,
+ end_datetime,
+ performer_name=performer_name,
+ repository_name=repository_name,
+ namespace_name=namespace_name,
+ filter_kinds=filter_kinds,
+ )
+ return _merge_aggregated_log_counts(rw_count, ro_count)
- def yield_logs_for_export(self, start_datetime, end_datetime, repository_id=None,
- namespace_id=None, max_query_time=None):
- rw_model = self.read_write_logs_model
- ro_model = self.read_only_logs_model
- rw_logs = rw_model.yield_logs_for_export(start_datetime, end_datetime, repository_id,
- namespace_id, max_query_time)
- ro_logs = ro_model.yield_logs_for_export(start_datetime, end_datetime, repository_id,
- namespace_id, max_query_time)
- for batch in itertools.chain(rw_logs, ro_logs):
- yield batch
+ def yield_logs_for_export(
+ self,
+ start_datetime,
+ end_datetime,
+ repository_id=None,
+ namespace_id=None,
+ max_query_time=None,
+ ):
+ rw_model = self.read_write_logs_model
+ ro_model = self.read_only_logs_model
+ rw_logs = rw_model.yield_logs_for_export(
+ start_datetime, end_datetime, repository_id, namespace_id, max_query_time
+ )
+ ro_logs = ro_model.yield_logs_for_export(
+ start_datetime, end_datetime, repository_id, namespace_id, max_query_time
+ )
+ for batch in itertools.chain(rw_logs, ro_logs):
+ yield batch
- def lookup_logs(self, start_datetime, end_datetime, performer_name=None, repository_name=None,
- namespace_name=None, filter_kinds=None, page_token=None, max_page_count=None):
- rw_model = self.read_write_logs_model
- ro_model = self.read_only_logs_model
+ def lookup_logs(
+ self,
+ start_datetime,
+ end_datetime,
+ performer_name=None,
+ repository_name=None,
+ namespace_name=None,
+ filter_kinds=None,
+ page_token=None,
+ max_page_count=None,
+ ):
+ rw_model = self.read_write_logs_model
+ ro_model = self.read_only_logs_model
- page_token = page_token or {}
+ page_token = page_token or {}
- new_page_token = {}
- if page_token is None or not page_token.get('under_readonly_model', False):
- rw_page_token = page_token.get('readwrite_page_token')
- rw_logs = rw_model.lookup_logs(start_datetime, end_datetime, performer_name,
- repository_name, namespace_name, filter_kinds,
- rw_page_token, max_page_count)
- logs, next_page_token = rw_logs
- new_page_token['under_readonly_model'] = next_page_token is None
- new_page_token['readwrite_page_token'] = next_page_token
- return LogEntriesPage(logs, new_page_token)
- else:
- readonly_page_token = page_token.get('readonly_page_token')
- ro_logs = ro_model.lookup_logs(start_datetime, end_datetime, performer_name,
- repository_name, namespace_name, filter_kinds,
- readonly_page_token, max_page_count)
- logs, next_page_token = ro_logs
- if next_page_token is None:
- return LogEntriesPage(logs, None)
+ new_page_token = {}
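+ # Page through the read-write model first; once it is exhausted, continue with the read-only model.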
+ if page_token is None or not page_token.get("under_readonly_model", False):
+ rw_page_token = page_token.get("readwrite_page_token")
+ rw_logs = rw_model.lookup_logs(
+ start_datetime,
+ end_datetime,
+ performer_name,
+ repository_name,
+ namespace_name,
+ filter_kinds,
+ rw_page_token,
+ max_page_count,
+ )
+ logs, next_page_token = rw_logs
+ new_page_token["under_readonly_model"] = next_page_token is None
+ new_page_token["readwrite_page_token"] = next_page_token
+ return LogEntriesPage(logs, new_page_token)
+ else:
+ readonly_page_token = page_token.get("readonly_page_token")
+ ro_logs = ro_model.lookup_logs(
+ start_datetime,
+ end_datetime,
+ performer_name,
+ repository_name,
+ namespace_name,
+ filter_kinds,
+ readonly_page_token,
+ max_page_count,
+ )
+ logs, next_page_token = ro_logs
+ if next_page_token is None:
+ return LogEntriesPage(logs, None)
- new_page_token['under_readonly_model'] = True
- new_page_token['readonly_page_token'] = next_page_token
- return LogEntriesPage(logs, new_page_token)
+ new_page_token["under_readonly_model"] = True
+ new_page_token["readonly_page_token"] = next_page_token
+ return LogEntriesPage(logs, new_page_token)
- def lookup_latest_logs(self, performer_name=None, repository_name=None, namespace_name=None,
- filter_kinds=None, size=20):
- latest_logs = []
- rw_model = self.read_write_logs_model
- ro_model = self.read_only_logs_model
+ def lookup_latest_logs(
+ self,
+ performer_name=None,
+ repository_name=None,
+ namespace_name=None,
+ filter_kinds=None,
+ size=20,
+ ):
+ latest_logs = []
+ rw_model = self.read_write_logs_model
+ ro_model = self.read_only_logs_model
- rw_logs = rw_model.lookup_latest_logs(performer_name, repository_name, namespace_name,
- filter_kinds, size)
- latest_logs.extend(rw_logs)
- if len(latest_logs) < size:
- ro_logs = ro_model.lookup_latest_logs(performer_name, repository_name, namespace_name,
- filter_kinds, size - len(latest_logs))
- latest_logs.extend(ro_logs)
+ rw_logs = rw_model.lookup_latest_logs(
+ performer_name, repository_name, namespace_name, filter_kinds, size
+ )
+ latest_logs.extend(rw_logs)
+ if len(latest_logs) < size:
+ ro_logs = ro_model.lookup_latest_logs(
+ performer_name,
+ repository_name,
+ namespace_name,
+ filter_kinds,
+ size - len(latest_logs),
+ )
+ latest_logs.extend(ro_logs)
- return latest_logs
+ return latest_logs
- def yield_log_rotation_context(self, cutoff_date, min_logs_per_rotation):
- ro_model = self.read_only_logs_model
- rw_model = self.read_write_logs_model
- ro_ctx = ro_model.yield_log_rotation_context(cutoff_date, min_logs_per_rotation)
- rw_ctx = rw_model.yield_log_rotation_context(cutoff_date, min_logs_per_rotation)
- for ctx in itertools.chain(ro_ctx, rw_ctx):
- yield ctx
+ def yield_log_rotation_context(self, cutoff_date, min_logs_per_rotation):
+ ro_model = self.read_only_logs_model
+ rw_model = self.read_write_logs_model
+ ro_ctx = ro_model.yield_log_rotation_context(cutoff_date, min_logs_per_rotation)
+ rw_ctx = rw_model.yield_log_rotation_context(cutoff_date, min_logs_per_rotation)
+ for ctx in itertools.chain(ro_ctx, rw_ctx):
+ yield ctx
diff --git a/data/logs_model/datatypes.py b/data/logs_model/datatypes.py
index 03db6756f..f670c6397 100644
--- a/data/logs_model/datatypes.py
+++ b/data/logs_model/datatypes.py
@@ -11,145 +11,187 @@ from util.morecollections import AttrDict
def _format_date(date):
- """ Output an RFC822 date format. """
- if date is None:
- return None
+ """ Output an RFC822 date format. """
+ if date is None:
+ return None
- return formatdate(timegm(date.utctimetuple()))
+ return formatdate(timegm(date.utctimetuple()))
@lru_cache(maxsize=1)
def _kinds():
- return model.log.get_log_entry_kinds()
+ return model.log.get_log_entry_kinds()
-class LogEntriesPage(namedtuple('LogEntriesPage', ['logs', 'next_page_token'])):
- """ Represents a page returned by the lookup_logs call. The `logs` contains the logs
+class LogEntriesPage(namedtuple("LogEntriesPage", ["logs", "next_page_token"])):
+ """ Represents a page returned by the lookup_logs call. The `logs` contains the logs
found for the page and `next_page_token`, if not None, contains the token to be
encoded and returned for the followup call.
"""
-class Log(namedtuple('Log', [
- 'metadata_json', 'ip', 'datetime', 'performer_email', 'performer_username', 'performer_robot',
- 'account_organization', 'account_username', 'account_email', 'account_robot', 'kind_id'])):
- """ Represents a single log entry returned by the logs model. """
+class Log(
+ namedtuple(
+ "Log",
+ [
+ "metadata_json",
+ "ip",
+ "datetime",
+ "performer_email",
+ "performer_username",
+ "performer_robot",
+ "account_organization",
+ "account_username",
+ "account_email",
+ "account_robot",
+ "kind_id",
+ ],
+ )
+):
+ """ Represents a single log entry returned by the logs model. """
- @classmethod
- def for_logentry(cls, log):
- account_organization = None
- account_username = None
- account_email = None
- account_robot = None
+ @classmethod
+ def for_logentry(cls, log):
+ account_organization = None
+ account_username = None
+ account_email = None
+ account_robot = None
- try:
- account_organization = log.account.organization
- account_username = log.account.username
- account_email = log.account.email
- account_robot = log.account.robot
- except AttributeError:
- pass
+ try:
+ account_organization = log.account.organization
+ account_username = log.account.username
+ account_email = log.account.email
+ account_robot = log.account.robot
+ except AttributeError:
+ pass
- performer_robot = None
- performer_username = None
- performer_email = None
+ performer_robot = None
+ performer_username = None
+ performer_email = None
- try:
- performer_robot = log.performer.robot
- performer_username = log.performer.username
- performer_email = log.performer.email
- except AttributeError:
- pass
+ try:
+ performer_robot = log.performer.robot
+ performer_username = log.performer.username
+ performer_email = log.performer.email
+ except AttributeError:
+ pass
- return Log(log.metadata_json, log.ip, log.datetime, performer_email, performer_username,
- performer_robot, account_organization, account_username, account_email,
- account_robot, log.kind_id)
+ return Log(
+ log.metadata_json,
+ log.ip,
+ log.datetime,
+ performer_email,
+ performer_username,
+ performer_robot,
+ account_organization,
+ account_username,
+ account_email,
+ account_robot,
+ log.kind_id,
+ )
- @classmethod
- def for_elasticsearch_log(cls, log, id_user_map):
- account_organization = None
- account_username = None
- account_email = None
- account_robot = None
+ @classmethod
+ def for_elasticsearch_log(cls, log, id_user_map):
+ account_organization = None
+ account_username = None
+ account_email = None
+ account_robot = None
- try:
- if log.account_id:
- account = id_user_map[log.account_id]
- account_organization = account.organization
- account_username = account.username
- account_email = account.email
- account_robot = account.robot
- except AttributeError:
- pass
+ try:
+ if log.account_id:
+ account = id_user_map[log.account_id]
+ account_organization = account.organization
+ account_username = account.username
+ account_email = account.email
+ account_robot = account.robot
+ except AttributeError:
+ pass
- performer_robot = None
- performer_username = None
- performer_email = None
+ performer_robot = None
+ performer_username = None
+ performer_email = None
- try:
- if log.performer_id:
- performer = id_user_map[log.performer_id]
- performer_robot = performer.robot
- performer_username = performer.username
- performer_email = performer.email
- except AttributeError:
- pass
+ try:
+ if log.performer_id:
+ performer = id_user_map[log.performer_id]
+ performer_robot = performer.robot
+ performer_username = performer.username
+ performer_email = performer.email
+ except AttributeError:
+ pass
- return Log(log.metadata_json, str(log.ip), log.datetime, performer_email, performer_username,
- performer_robot, account_organization, account_username, account_email,
- account_robot, log.kind_id)
+ return Log(
+ log.metadata_json,
+ str(log.ip),
+ log.datetime,
+ performer_email,
+ performer_username,
+ performer_robot,
+ account_organization,
+ account_username,
+ account_email,
+ account_robot,
+ log.kind_id,
+ )
- def to_dict(self, avatar, include_namespace=False):
- view = {
- 'kind': _kinds()[self.kind_id],
- 'metadata': json.loads(self.metadata_json),
- 'ip': self.ip,
- 'datetime': _format_date(self.datetime),
- }
+ def to_dict(self, avatar, include_namespace=False):
+ view = {
+ "kind": _kinds()[self.kind_id],
+ "metadata": json.loads(self.metadata_json),
+ "ip": self.ip,
+ "datetime": _format_date(self.datetime),
+ }
- if self.performer_username:
- performer = AttrDict({'username': self.performer_username, 'email': self.performer_email})
- performer.robot = None
- if self.performer_robot:
- performer.robot = self.performer_robot
+ if self.performer_username:
+ performer = AttrDict(
+ {"username": self.performer_username, "email": self.performer_email}
+ )
+ performer.robot = None
+ if self.performer_robot:
+ performer.robot = self.performer_robot
- view['performer'] = {
- 'kind': 'user',
- 'name': self.performer_username,
- 'is_robot': self.performer_robot,
- 'avatar': avatar.get_data_for_user(performer),
- }
+ view["performer"] = {
+ "kind": "user",
+ "name": self.performer_username,
+ "is_robot": self.performer_robot,
+ "avatar": avatar.get_data_for_user(performer),
+ }
- if include_namespace:
- if self.account_username:
- account = AttrDict({'username': self.account_username, 'email': self.account_email})
- if self.account_organization:
+ if include_namespace:
+ if self.account_username:
+ account = AttrDict(
+ {"username": self.account_username, "email": self.account_email}
+ )
+ if self.account_organization:
- view['namespace'] = {
- 'kind': 'org',
- 'name': self.account_username,
- 'avatar': avatar.get_data_for_org(account),
- }
- else:
- account.robot = None
- if self.account_robot:
- account.robot = self.account_robot
- view['namespace'] = {
- 'kind': 'user',
- 'name': self.account_username,
- 'avatar': avatar.get_data_for_user(account),
- }
+ view["namespace"] = {
+ "kind": "org",
+ "name": self.account_username,
+ "avatar": avatar.get_data_for_org(account),
+ }
+ else:
+ account.robot = None
+ if self.account_robot:
+ account.robot = self.account_robot
+ view["namespace"] = {
+ "kind": "user",
+ "name": self.account_username,
+ "avatar": avatar.get_data_for_user(account),
+ }
- return view
+ return view
-class AggregatedLogCount(namedtuple('AggregatedLogCount', ['kind_id', 'count', 'datetime'])):
- """ Represents the aggregated count of the number of logs, of a particular kind, on a day. """
- def to_dict(self):
- view = {
- 'kind': _kinds()[self.kind_id],
- 'count': self.count,
- 'datetime': _format_date(self.datetime),
- }
+class AggregatedLogCount(
+ namedtuple("AggregatedLogCount", ["kind_id", "count", "datetime"])
+):
+ """ Represents the aggregated count of the number of logs, of a particular kind, on a day. """
- return view
+ def to_dict(self):
+ view = {
+ "kind": _kinds()[self.kind_id],
+ "count": self.count,
+ "datetime": _format_date(self.datetime),
+ }
+
+ return view
diff --git a/data/logs_model/document_logs_model.py b/data/logs_model/document_logs_model.py
index e93cd2062..be1257284 100644
--- a/data/logs_model/document_logs_model.py
+++ b/data/logs_model/document_logs_model.py
@@ -16,18 +16,28 @@ from elasticsearch.exceptions import ConnectionTimeout, NotFoundError
from data import model
from data.database import CloseForLongOperation
from data.model import config
-from data.model.log import (_json_serialize, ACTIONS_ALLOWED_WITHOUT_AUDIT_LOGGING,
- DataModelException)
+from data.model.log import (
+ _json_serialize,
+ ACTIONS_ALLOWED_WITHOUT_AUDIT_LOGGING,
+ DataModelException,
+)
from data.logs_model.elastic_logs import LogEntry, configure_es
from data.logs_model.datatypes import Log, AggregatedLogCount, LogEntriesPage
-from data.logs_model.interface import (ActionLogsDataInterface, LogRotationContextInterface,
- LogsIterationTimeout)
+from data.logs_model.interface import (
+ ActionLogsDataInterface,
+ LogRotationContextInterface,
+ LogsIterationTimeout,
+)
from data.logs_model.shared import SharedModel, epoch_ms
from data.logs_model.logs_producer import LogProducerProxy, LogSendException
from data.logs_model.logs_producer.kafka_logs_producer import KafkaLogsProducer
-from data.logs_model.logs_producer.elasticsearch_logs_producer import ElasticsearchLogsProducer
-from data.logs_model.logs_producer.kinesis_stream_logs_producer import KinesisStreamLogsProducer
+from data.logs_model.logs_producer.elasticsearch_logs_producer import (
+ ElasticsearchLogsProducer,
+)
+from data.logs_model.logs_producer.kinesis_stream_logs_producer import (
+ KinesisStreamLogsProducer,
+)
logger = logging.getLogger(__name__)
@@ -43,490 +53,643 @@ DATE_RANGE_LIMIT = 32
COUNT_REPOSITORY_ACTION_TIMEOUT = 30
-
def _date_range_descending(start_datetime, end_datetime, includes_end_datetime=False):
- """ Generate the dates between `end_datetime` and `start_datetime`.
+ """ Generate the dates between `end_datetime` and `start_datetime`.
If `includes_end_datetime` is set, the generator starts at `end_datetime`,
otherwise, starts the generator at `end_datetime` minus 1 second.
"""
- assert end_datetime >= start_datetime
- start_date = start_datetime.date()
+ assert end_datetime >= start_datetime
+ start_date = start_datetime.date()
- if includes_end_datetime:
- current_date = end_datetime.date()
- else:
- current_date = (end_datetime - timedelta(seconds=1)).date()
+ if includes_end_datetime:
+ current_date = end_datetime.date()
+ else:
+ current_date = (end_datetime - timedelta(seconds=1)).date()
- while current_date >= start_date:
- yield current_date
- current_date = current_date - timedelta(days=1)
+ while current_date >= start_date:
+ yield current_date
+ current_date = current_date - timedelta(days=1)
def _date_range_in_single_index(dt1, dt2):
- """ Determine whether a single index can be searched given a range
+ """ Determine whether a single index can be searched given a range
of dates or datetimes. If date instances are given, difference should be 1 day.
NOTE: dt2 is exclusive to the search result set.
i.e. The date range is larger or equal to dt1 and strictly smaller than dt2
"""
- assert isinstance(dt1, date) and isinstance(dt2, date)
+ assert isinstance(dt1, date) and isinstance(dt2, date)
- dt = dt2 - dt1
+ dt = dt2 - dt1
- # Check if date or datetime
- if not isinstance(dt1, datetime) and not isinstance(dt2, datetime):
- return dt == timedelta(days=1)
+ # Check if date or datetime
+ if not isinstance(dt1, datetime) and not isinstance(dt2, datetime):
+ return dt == timedelta(days=1)
- if dt < timedelta(days=1) and dt >= timedelta(days=0):
- return dt2.day == dt1.day
+ if dt < timedelta(days=1) and dt >= timedelta(days=0):
+ return dt2.day == dt1.day
- # Check if datetime can be interpreted as a date: hour, minutes, seconds or microseconds set to 0
- if dt == timedelta(days=1):
- return dt1.hour == 0 and dt1.minute == 0 and dt1.second == 0 and dt1.microsecond == 0
+ # Check if datetime can be interpreted as a date: hour, minutes, seconds or microseconds set to 0
+ if dt == timedelta(days=1):
+ return (
+ dt1.hour == 0
+ and dt1.minute == 0
+ and dt1.second == 0
+ and dt1.microsecond == 0
+ )
- return False
+ return False
def _for_elasticsearch_logs(logs, repository_id=None, namespace_id=None):
- namespace_ids = set()
- for log in logs:
- namespace_ids.add(log.account_id)
- namespace_ids.add(log.performer_id)
- assert namespace_id is None or log.account_id == namespace_id
- assert repository_id is None or log.repository_id == repository_id
+ namespace_ids = set()
+ for log in logs:
+ namespace_ids.add(log.account_id)
+ namespace_ids.add(log.performer_id)
+ assert namespace_id is None or log.account_id == namespace_id
+ assert repository_id is None or log.repository_id == repository_id
- id_user_map = model.user.get_user_map_by_ids(namespace_ids)
- return [Log.for_elasticsearch_log(log, id_user_map) for log in logs]
+ id_user_map = model.user.get_user_map_by_ids(namespace_ids)
+ return [Log.for_elasticsearch_log(log, id_user_map) for log in logs]
def _random_id():
- """ Generates a unique uuid4 string for the random_id field in LogEntry.
+ """ Generates a unique uuid4 string for the random_id field in LogEntry.
It is used as tie-breaker for sorting logs based on datetime:
https://www.elastic.co/guide/en/elasticsearch/reference/current/search-request-search-after.html
"""
- return str(uuid.uuid4())
+ return str(uuid.uuid4())
@add_metaclass(ABCMeta)
class ElasticsearchLogsModelInterface(object):
- """
+ """
Interface for Elasticsearch specific operations with the logs model.
These operations are usually index based.
"""
- @abstractmethod
- def can_delete_index(self, index, cutoff_date):
- """ Return whether the given index is older than the given cutoff date. """
+ @abstractmethod
+ def can_delete_index(self, index, cutoff_date):
+ """ Return whether the given index is older than the given cutoff date. """
- @abstractmethod
- def list_indices(self):
- """ List the logs model's indices. """
+ @abstractmethod
+ def list_indices(self):
+ """ List the logs model's indices. """
-class DocumentLogsModel(SharedModel, ActionLogsDataInterface, ElasticsearchLogsModelInterface):
- """
+class DocumentLogsModel(
+ SharedModel, ActionLogsDataInterface, ElasticsearchLogsModelInterface
+):
+ """
DocumentLogsModel implements the data model for the logs API backed by an
elasticsearch service.
"""
- def __init__(self, should_skip_logging=None, elasticsearch_config=None, producer=None, **kwargs):
- self._should_skip_logging = should_skip_logging
- self._logs_producer = LogProducerProxy()
- self._es_client = configure_es(**elasticsearch_config)
- if producer == 'kafka':
- kafka_config = kwargs['kafka_config']
- self._logs_producer.initialize(KafkaLogsProducer(**kafka_config))
- elif producer == 'elasticsearch':
- self._logs_producer.initialize(ElasticsearchLogsProducer())
- elif producer == 'kinesis_stream':
- kinesis_stream_config = kwargs['kinesis_stream_config']
- self._logs_producer.initialize(KinesisStreamLogsProducer(**kinesis_stream_config))
- else:
- raise Exception('Invalid log producer: %s' % producer)
+ def __init__(
+ self,
+ should_skip_logging=None,
+ elasticsearch_config=None,
+ producer=None,
+ **kwargs
+ ):
+ self._should_skip_logging = should_skip_logging
+ self._logs_producer = LogProducerProxy()
+ self._es_client = configure_es(**elasticsearch_config)
- @staticmethod
- def _get_ids_by_names(repository_name, namespace_name, performer_name):
- """ Retrieve repository/namespace/performer ids based on their names.
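+ # Select the log producer backend (Kafka, Elasticsearch, or Kinesis) from the configured producer name.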
+ if producer == "kafka":
+ kafka_config = kwargs["kafka_config"]
+ self._logs_producer.initialize(KafkaLogsProducer(**kafka_config))
+ elif producer == "elasticsearch":
+ self._logs_producer.initialize(ElasticsearchLogsProducer())
+ elif producer == "kinesis_stream":
+ kinesis_stream_config = kwargs["kinesis_stream_config"]
+ self._logs_producer.initialize(
+ KinesisStreamLogsProducer(**kinesis_stream_config)
+ )
+ else:
+ raise Exception("Invalid log producer: %s" % producer)
+
+ @staticmethod
+ def _get_ids_by_names(repository_name, namespace_name, performer_name):
+ """ Retrieve repository/namespace/performer ids based on their names.
throws DataModelException when the namespace_name does not match any
user in the database.
returns database ID or None if not exists.
"""
- repository_id = None
- account_id = None
- performer_id = None
+ repository_id = None
+ account_id = None
+ performer_id = None
- if repository_name and namespace_name:
- repository = model.repository.get_repository(namespace_name, repository_name)
- if repository:
- repository_id = repository.id
- account_id = repository.namespace_user.id
+ if repository_name and namespace_name:
+ repository = model.repository.get_repository(
+ namespace_name, repository_name
+ )
+ if repository:
+ repository_id = repository.id
+ account_id = repository.namespace_user.id
- if namespace_name and account_id is None:
- account = model.user.get_user_or_org(namespace_name)
- if account is None:
- raise DataModelException('Invalid namespace requested')
+ if namespace_name and account_id is None:
+ account = model.user.get_user_or_org(namespace_name)
+ if account is None:
+ raise DataModelException("Invalid namespace requested")
- account_id = account.id
+ account_id = account.id
- if performer_name:
- performer = model.user.get_user(performer_name)
- if performer:
- performer_id = performer.id
+ if performer_name:
+ performer = model.user.get_user(performer_name)
+ if performer:
+ performer_id = performer.id
- return repository_id, account_id, performer_id
+ return repository_id, account_id, performer_id
- def _base_query(self, performer_id=None, repository_id=None, account_id=None, filter_kinds=None,
- index=None):
- if filter_kinds is not None:
- assert all(isinstance(kind_name, str) for kind_name in filter_kinds)
+ def _base_query(
+ self,
+ performer_id=None,
+ repository_id=None,
+ account_id=None,
+ filter_kinds=None,
+ index=None,
+ ):
+ if filter_kinds is not None:
+ assert all(isinstance(kind_name, str) for kind_name in filter_kinds)
- if index is not None:
- search = LogEntry.search(index=index)
- else:
- search = LogEntry.search()
+ if index is not None:
+ search = LogEntry.search(index=index)
+ else:
+ search = LogEntry.search()
- if performer_id is not None:
- assert isinstance(performer_id, int)
- search = search.filter('term', performer_id=performer_id)
+ if performer_id is not None:
+ assert isinstance(performer_id, int)
+ search = search.filter("term", performer_id=performer_id)
- if repository_id is not None:
- assert isinstance(repository_id, int)
- search = search.filter('term', repository_id=repository_id)
+ if repository_id is not None:
+ assert isinstance(repository_id, int)
+ search = search.filter("term", repository_id=repository_id)
- if account_id is not None and repository_id is None:
- assert isinstance(account_id, int)
- search = search.filter('term', account_id=account_id)
+ if account_id is not None and repository_id is None:
+ assert isinstance(account_id, int)
+ search = search.filter("term", account_id=account_id)
- if filter_kinds is not None:
- kind_map = model.log.get_log_entry_kinds()
- ignore_ids = [kind_map[kind_name] for kind_name in filter_kinds]
- search = search.exclude('terms', kind_id=ignore_ids)
+ if filter_kinds is not None:
+ kind_map = model.log.get_log_entry_kinds()
+ ignore_ids = [kind_map[kind_name] for kind_name in filter_kinds]
+ search = search.exclude("terms", kind_id=ignore_ids)
- return search
+ return search
- def _base_query_date_range(self, start_datetime, end_datetime, performer_id, repository_id,
- account_id, filter_kinds, index=None):
- skip_datetime_check = False
- if _date_range_in_single_index(start_datetime, end_datetime):
- index = self._es_client.index_name(start_datetime)
- skip_datetime_check = self._es_client.index_exists(index)
+ def _base_query_date_range(
+ self,
+ start_datetime,
+ end_datetime,
+ performer_id,
+ repository_id,
+ account_id,
+ filter_kinds,
+ index=None,
+ ):
+ skip_datetime_check = False
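+ # When the range fits within a single daily index, search only that index and skip the datetime range filter.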
+ if _date_range_in_single_index(start_datetime, end_datetime):
+ index = self._es_client.index_name(start_datetime)
+ skip_datetime_check = self._es_client.index_exists(index)
- if index and (skip_datetime_check or self._es_client.index_exists(index)):
- search = self._base_query(performer_id, repository_id, account_id, filter_kinds,
- index=index)
- else:
- search = self._base_query(performer_id, repository_id, account_id, filter_kinds)
+ if index and (skip_datetime_check or self._es_client.index_exists(index)):
+ search = self._base_query(
+ performer_id, repository_id, account_id, filter_kinds, index=index
+ )
+ else:
+ search = self._base_query(
+ performer_id, repository_id, account_id, filter_kinds
+ )
- if not skip_datetime_check:
- search = search.query('range', datetime={'gte': start_datetime, 'lt': end_datetime})
+ if not skip_datetime_check:
+ search = search.query(
+ "range", datetime={"gte": start_datetime, "lt": end_datetime}
+ )
- return search
+ return search
- def _load_logs_for_day(self, logs_date, performer_id, repository_id, account_id, filter_kinds,
- after_datetime=None, after_random_id=None, size=PAGE_SIZE):
- index = self._es_client.index_name(logs_date)
- if not self._es_client.index_exists(index):
- return []
+ def _load_logs_for_day(
+ self,
+ logs_date,
+ performer_id,
+ repository_id,
+ account_id,
+ filter_kinds,
+ after_datetime=None,
+ after_random_id=None,
+ size=PAGE_SIZE,
+ ):
+ index = self._es_client.index_name(logs_date)
+ if not self._es_client.index_exists(index):
+ return []
- search = self._base_query(performer_id, repository_id, account_id, filter_kinds,
- index=index)
- search = search.sort({'datetime': 'desc'}, {'random_id.keyword': 'desc'})
- search = search.extra(size=size)
+ search = self._base_query(
+ performer_id, repository_id, account_id, filter_kinds, index=index
+ )
+ search = search.sort({"datetime": "desc"}, {"random_id.keyword": "desc"})
+ search = search.extra(size=size)
- if after_datetime is not None and after_random_id is not None:
- after_datetime_epoch_ms = epoch_ms(after_datetime)
- search = search.extra(search_after=[after_datetime_epoch_ms, after_random_id])
+ if after_datetime is not None and after_random_id is not None:
+ after_datetime_epoch_ms = epoch_ms(after_datetime)
+ search = search.extra(
+ search_after=[after_datetime_epoch_ms, after_random_id]
+ )
- return search.execute()
+ return search.execute()
- def _load_latest_logs(self, performer_id, repository_id, account_id, filter_kinds, size):
- """ Return the latest logs from Elasticsearch.
+ def _load_latest_logs(
+ self, performer_id, repository_id, account_id, filter_kinds, size
+ ):
+ """ Return the latest logs from Elasticsearch.
    Look at indices up to the configured logrotateworker threshold, or up to 30 days if not defined.
"""
- # Set the last index to check to be the logrotateworker threshold, or 30 days
- end_datetime = datetime.now()
- start_datetime = end_datetime - timedelta(days=DATE_RANGE_LIMIT)
+ # Set the last index to check to be the logrotateworker threshold, or 30 days
+ end_datetime = datetime.now()
+ start_datetime = end_datetime - timedelta(days=DATE_RANGE_LIMIT)
- latest_logs = []
- for day in _date_range_descending(start_datetime, end_datetime, includes_end_datetime=True):
- try:
- logs = self._load_logs_for_day(day, performer_id, repository_id, account_id, filter_kinds,
- size=size)
- latest_logs.extend(logs)
- except NotFoundError:
- continue
+ latest_logs = []
+ for day in _date_range_descending(
+ start_datetime, end_datetime, includes_end_datetime=True
+ ):
+ try:
+ logs = self._load_logs_for_day(
+ day,
+ performer_id,
+ repository_id,
+ account_id,
+ filter_kinds,
+ size=size,
+ )
+ latest_logs.extend(logs)
+ except NotFoundError:
+ continue
- if len(latest_logs) >= size:
- break
+ if len(latest_logs) >= size:
+ break
- return _for_elasticsearch_logs(latest_logs[:size], repository_id, account_id)
+ return _for_elasticsearch_logs(latest_logs[:size], repository_id, account_id)
- def lookup_logs(self, start_datetime, end_datetime, performer_name=None, repository_name=None,
- namespace_name=None, filter_kinds=None, page_token=None, max_page_count=None):
- assert start_datetime is not None and end_datetime is not None
+ def lookup_logs(
+ self,
+ start_datetime,
+ end_datetime,
+ performer_name=None,
+ repository_name=None,
+ namespace_name=None,
+ filter_kinds=None,
+ page_token=None,
+ max_page_count=None,
+ ):
+ assert start_datetime is not None and end_datetime is not None
- # Check for a valid combined model token when migrating online from a combined model
- if page_token is not None and page_token.get('readwrite_page_token') is not None:
- page_token = page_token.get('readwrite_page_token')
+ # Check for a valid combined model token when migrating online from a combined model
+ if (
+ page_token is not None
+ and page_token.get("readwrite_page_token") is not None
+ ):
+ page_token = page_token.get("readwrite_page_token")
- if page_token is not None and max_page_count is not None:
- page_number = page_token.get('page_number')
- if page_number is not None and page_number + 1 > max_page_count:
- return LogEntriesPage([], None)
+ if page_token is not None and max_page_count is not None:
+ page_number = page_token.get("page_number")
+ if page_number is not None and page_number + 1 > max_page_count:
+ return LogEntriesPage([], None)
- repository_id, account_id, performer_id = DocumentLogsModel._get_ids_by_names(
- repository_name, namespace_name, performer_name)
+ repository_id, account_id, performer_id = DocumentLogsModel._get_ids_by_names(
+ repository_name, namespace_name, performer_name
+ )
- after_datetime = None
- after_random_id = None
- if page_token is not None:
- after_datetime = parse_datetime(page_token['datetime'])
- after_random_id = page_token['random_id']
+ after_datetime = None
+ after_random_id = None
+ if page_token is not None:
+ after_datetime = parse_datetime(page_token["datetime"])
+ after_random_id = page_token["random_id"]
- if after_datetime is not None:
- end_datetime = min(end_datetime, after_datetime)
+ if after_datetime is not None:
+ end_datetime = min(end_datetime, after_datetime)
- all_logs = []
+ all_logs = []
+
+ with CloseForLongOperation(config.app_config):
+ for current_date in _date_range_descending(start_datetime, end_datetime):
+ try:
+ logs = self._load_logs_for_day(
+ current_date,
+ performer_id,
+ repository_id,
+ account_id,
+ filter_kinds,
+ after_datetime,
+ after_random_id,
+ size=PAGE_SIZE + 1,
+ )
+
+ all_logs.extend(logs)
+ except NotFoundError:
+ continue
+
+ if len(all_logs) > PAGE_SIZE:
+ break
+
+ next_page_token = None
+ all_logs = all_logs[0 : PAGE_SIZE + 1]
+
+ if len(all_logs) == PAGE_SIZE + 1:
+            # The last element in the response is only used to check whether there are more elements.
+            # The second-to-last element (the last log actually returned on this page) is used as the
+            # pagination token, because search_after does not include the exact match and so the next
+            # page will start with the element that follows it.
+            # This keeps the behavior exactly the same as table_logs_model, so that the caller can
+            # expect that whenever a pagination token is non-empty, there is at least one more log
+            # to be retrieved.
+ next_page_token = {
+ "datetime": all_logs[-2].datetime.isoformat(),
+ "random_id": all_logs[-2].random_id,
+ "page_number": page_token["page_number"] + 1 if page_token else 1,
+ }
+
+ return LogEntriesPage(
+ _for_elasticsearch_logs(all_logs[:PAGE_SIZE], repository_id, account_id),
+ next_page_token,
+ )
+
+ def lookup_latest_logs(
+ self,
+ performer_name=None,
+ repository_name=None,
+ namespace_name=None,
+ filter_kinds=None,
+ size=20,
+ ):
+ repository_id, account_id, performer_id = DocumentLogsModel._get_ids_by_names(
+ repository_name, namespace_name, performer_name
+ )
+
+ with CloseForLongOperation(config.app_config):
+ latest_logs = self._load_latest_logs(
+ performer_id, repository_id, account_id, filter_kinds, size
+ )
+
+ return latest_logs
+
+ def get_aggregated_log_counts(
+ self,
+ start_datetime,
+ end_datetime,
+ performer_name=None,
+ repository_name=None,
+ namespace_name=None,
+ filter_kinds=None,
+ ):
+ if end_datetime - start_datetime >= timedelta(days=DATE_RANGE_LIMIT):
+ raise Exception(
+ "Cannot lookup aggregated logs over a period longer than a month"
+ )
+
+ repository_id, account_id, performer_id = DocumentLogsModel._get_ids_by_names(
+ repository_name, namespace_name, performer_name
+ )
+
+ with CloseForLongOperation(config.app_config):
+ search = self._base_query_date_range(
+ start_datetime,
+ end_datetime,
+ performer_id,
+ repository_id,
+ account_id,
+ filter_kinds,
+ )
+ search.aggs.bucket("by_id", "terms", field="kind_id").bucket(
+ "by_date", "date_histogram", field="datetime", interval="day"
+ )
+ # es returns all buckets when size=0
+ search = search.extra(size=0)
+ resp = search.execute()
+
+ if not resp.aggregations:
+ return []
+
+ counts = []
+ by_id = resp.aggregations["by_id"]
+
+ for id_bucket in by_id.buckets:
+ for date_bucket in id_bucket.by_date.buckets:
+ if date_bucket.doc_count > 0:
+ counts.append(
+ AggregatedLogCount(
+ id_bucket.key, date_bucket.doc_count, date_bucket.key
+ )
+ )
+
+ return counts
+
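(Aside, not part of this change: a minimal sketch of how a caller might collapse the buckets returned above, assuming AggregatedLogCount is a namedtuple exposing kind_id, count and datetime as used in this module; total_actions_by_kind is a hypothetical helper.)

def total_actions_by_kind(counts):
    # Collapse the per-day buckets returned by get_aggregated_log_counts()
    # into a single total per log kind.
    totals = {}
    for aggregated in counts:
        totals[aggregated.kind_id] = totals.get(aggregated.kind_id, 0) + aggregated.count
    return totals
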
+ def count_repository_actions(self, repository, day):
+ index = self._es_client.index_name(day)
+ search = self._base_query_date_range(
+ day, day + timedelta(days=1), None, repository.id, None, None, index=index
+ )
+ search = search.params(request_timeout=COUNT_REPOSITORY_ACTION_TIMEOUT)
- with CloseForLongOperation(config.app_config):
- for current_date in _date_range_descending(start_datetime, end_datetime):
try:
- logs = self._load_logs_for_day(current_date, performer_id, repository_id, account_id,
- filter_kinds, after_datetime, after_random_id,
- size=PAGE_SIZE+1)
-
- all_logs.extend(logs)
+ return search.count()
except NotFoundError:
- continue
+ return 0
- if len(all_logs) > PAGE_SIZE:
- break
+ def log_action(
+ self,
+ kind_name,
+ namespace_name=None,
+ performer=None,
+ ip=None,
+ metadata=None,
+ repository=None,
+ repository_name=None,
+ timestamp=None,
+ is_free_namespace=False,
+ ):
+ if self._should_skip_logging and self._should_skip_logging(
+ kind_name, namespace_name, is_free_namespace
+ ):
+ return
- next_page_token = None
- all_logs = all_logs[0:PAGE_SIZE+1]
+ if repository_name is not None:
+ assert repository is None
+ assert namespace_name is not None
+ repository = model.repository.get_repository(
+ namespace_name, repository_name
+ )
- if len(all_logs) == PAGE_SIZE + 1:
- # The last element in the response is used to check if there's more elements.
- # The second element in the response is used as the pagination token because search_after does
- # not include the exact match, and so the next page will start with the last element.
- # This keeps the behavior exactly the same as table_logs_model, so that
- # the caller can expect when a pagination token is non-empty, there must be
- # at least 1 log to be retrieved.
- next_page_token = {
- 'datetime': all_logs[-2].datetime.isoformat(),
- 'random_id': all_logs[-2].random_id,
- 'page_number': page_token['page_number'] + 1 if page_token else 1,
- }
+ if timestamp is None:
+ timestamp = datetime.today()
- return LogEntriesPage(_for_elasticsearch_logs(all_logs[:PAGE_SIZE], repository_id, account_id),
- next_page_token)
+ account_id = None
+ performer_id = None
+ repository_id = None
- def lookup_latest_logs(self, performer_name=None, repository_name=None, namespace_name=None,
- filter_kinds=None, size=20):
- repository_id, account_id, performer_id = DocumentLogsModel._get_ids_by_names(
- repository_name, namespace_name, performer_name)
+ if namespace_name is not None:
+ account_id = model.user.get_namespace_user(namespace_name).id
- with CloseForLongOperation(config.app_config):
- latest_logs = self._load_latest_logs(performer_id, repository_id, account_id, filter_kinds,
- size)
+ if performer is not None:
+ performer_id = performer.id
- return latest_logs
+ if repository is not None:
+ repository_id = repository.id
+ metadata_json = json.dumps(metadata or {}, default=_json_serialize)
+ kind_id = model.log._get_log_entry_kind(kind_name)
+ log = LogEntry(
+ random_id=_random_id(),
+ kind_id=kind_id,
+ account_id=account_id,
+ performer_id=performer_id,
+ ip=ip,
+ metadata_json=metadata_json,
+ repository_id=repository_id,
+ datetime=timestamp,
+ )
- def get_aggregated_log_counts(self, start_datetime, end_datetime, performer_name=None,
- repository_name=None, namespace_name=None, filter_kinds=None):
- if end_datetime - start_datetime >= timedelta(days=DATE_RANGE_LIMIT):
- raise Exception('Cannot lookup aggregated logs over a period longer than a month')
+ try:
+ self._logs_producer.send(log)
+ except LogSendException as lse:
+ strict_logging_disabled = config.app_config.get(
+ "ALLOW_PULLS_WITHOUT_STRICT_LOGGING"
+ )
+            logger.exception(
+                "log_action failed", extra=dict(log.to_dict(), exception=lse)
+            )
+ if not (
+ strict_logging_disabled
+ and kind_name in ACTIONS_ALLOWED_WITHOUT_AUDIT_LOGGING
+ ):
+ raise
- repository_id, account_id, performer_id = DocumentLogsModel._get_ids_by_names(
- repository_name, namespace_name, performer_name)
+ def yield_logs_for_export(
+ self,
+ start_datetime,
+ end_datetime,
+ repository_id=None,
+ namespace_id=None,
+ max_query_time=None,
+ ):
+ max_query_time = (
+ max_query_time.total_seconds() if max_query_time is not None else 300
+ )
+ search = self._base_query_date_range(
+ start_datetime, end_datetime, None, repository_id, namespace_id, None
+ )
- with CloseForLongOperation(config.app_config):
- search = self._base_query_date_range(start_datetime, end_datetime, performer_id,
- repository_id, account_id, filter_kinds)
- search.aggs.bucket('by_id', 'terms', field='kind_id').bucket('by_date', 'date_histogram',
- field='datetime', interval='day')
- # es returns all buckets when size=0
- search = search.extra(size=0)
- resp = search.execute()
+ def raise_on_timeout(batch_generator):
+ start = time()
+ for batch in batch_generator:
+ elapsed = time() - start
+ if elapsed > max_query_time:
+ logger.error(
+ "Retrieval of logs `%s/%s` timed out with time of `%s`",
+ namespace_id,
+ repository_id,
+ elapsed,
+ )
+ raise LogsIterationTimeout()
- if not resp.aggregations:
- return []
+ yield batch
+ start = time()
- counts = []
- by_id = resp.aggregations['by_id']
+ def read_batch(scroll):
+ batch = []
+ for log in scroll:
+ batch.append(log)
+ if len(batch) == DEFAULT_RESULT_WINDOW:
+ yield _for_elasticsearch_logs(
+ batch, repository_id=repository_id, namespace_id=namespace_id
+ )
+ batch = []
- for id_bucket in by_id.buckets:
- for date_bucket in id_bucket.by_date.buckets:
- if date_bucket.doc_count > 0:
- counts.append(AggregatedLogCount(id_bucket.key, date_bucket.doc_count, date_bucket.key))
+ if batch:
+ yield _for_elasticsearch_logs(
+ batch, repository_id=repository_id, namespace_id=namespace_id
+ )
- return counts
+ search = search.params(
+ size=DEFAULT_RESULT_WINDOW, request_timeout=max_query_time
+ )
- def count_repository_actions(self, repository, day):
- index = self._es_client.index_name(day)
- search = self._base_query_date_range(day, day + timedelta(days=1),
- None,
- repository.id,
- None,
- None,
- index=index)
- search = search.params(request_timeout=COUNT_REPOSITORY_ACTION_TIMEOUT)
+ try:
+ with CloseForLongOperation(config.app_config):
+ for batch in raise_on_timeout(read_batch(search.scan())):
+ yield batch
+ except ConnectionTimeout:
+ raise LogsIterationTimeout()
- try:
- return search.count()
- except NotFoundError:
- return 0
+ def can_delete_index(self, index, cutoff_date):
+ return self._es_client.can_delete_index(index, cutoff_date)
- def log_action(self, kind_name, namespace_name=None, performer=None, ip=None, metadata=None,
- repository=None, repository_name=None, timestamp=None, is_free_namespace=False):
- if self._should_skip_logging and self._should_skip_logging(kind_name, namespace_name,
- is_free_namespace):
- return
+ def list_indices(self):
+ return self._es_client.list_indices()
- if repository_name is not None:
- assert repository is None
- assert namespace_name is not None
- repository = model.repository.get_repository(namespace_name, repository_name)
+ def yield_log_rotation_context(self, cutoff_date, min_logs_per_rotation):
+ """ Yield a context manager for a group of outdated logs. """
+ all_indices = self.list_indices()
+ for index in all_indices:
+ if not self.can_delete_index(index, cutoff_date):
+ continue
- if timestamp is None:
- timestamp = datetime.today()
-
- account_id = None
- performer_id = None
- repository_id = None
-
- if namespace_name is not None:
- account_id = model.user.get_namespace_user(namespace_name).id
-
- if performer is not None:
- performer_id = performer.id
-
- if repository is not None:
- repository_id = repository.id
-
- metadata_json = json.dumps(metadata or {}, default=_json_serialize)
- kind_id = model.log._get_log_entry_kind(kind_name)
- log = LogEntry(random_id=_random_id(), kind_id=kind_id, account_id=account_id,
- performer_id=performer_id, ip=ip, metadata_json=metadata_json,
- repository_id=repository_id, datetime=timestamp)
-
- try:
- self._logs_producer.send(log)
- except LogSendException as lse:
- strict_logging_disabled = config.app_config.get('ALLOW_PULLS_WITHOUT_STRICT_LOGGING')
- logger.exception('log_action failed', extra=({'exception': lse}).update(log.to_dict()))
- if not (strict_logging_disabled and kind_name in ACTIONS_ALLOWED_WITHOUT_AUDIT_LOGGING):
- raise
-
- def yield_logs_for_export(self, start_datetime, end_datetime, repository_id=None,
- namespace_id=None, max_query_time=None):
- max_query_time = max_query_time.total_seconds() if max_query_time is not None else 300
- search = self._base_query_date_range(start_datetime, end_datetime, None, repository_id,
- namespace_id, None)
-
- def raise_on_timeout(batch_generator):
- start = time()
- for batch in batch_generator:
- elapsed = time() - start
- if elapsed > max_query_time:
- logger.error('Retrieval of logs `%s/%s` timed out with time of `%s`', namespace_id,
- repository_id, elapsed)
- raise LogsIterationTimeout()
-
- yield batch
- start = time()
-
- def read_batch(scroll):
- batch = []
- for log in scroll:
- batch.append(log)
- if len(batch) == DEFAULT_RESULT_WINDOW:
- yield _for_elasticsearch_logs(batch, repository_id=repository_id,
- namespace_id=namespace_id)
- batch = []
-
- if batch:
- yield _for_elasticsearch_logs(batch, repository_id=repository_id, namespace_id=namespace_id)
-
- search = search.params(size=DEFAULT_RESULT_WINDOW, request_timeout=max_query_time)
-
- try:
- with CloseForLongOperation(config.app_config):
- for batch in raise_on_timeout(read_batch(search.scan())):
- yield batch
- except ConnectionTimeout:
- raise LogsIterationTimeout()
-
- def can_delete_index(self, index, cutoff_date):
- return self._es_client.can_delete_index(index, cutoff_date)
-
- def list_indices(self):
- return self._es_client.list_indices()
-
- def yield_log_rotation_context(self, cutoff_date, min_logs_per_rotation):
- """ Yield a context manager for a group of outdated logs. """
- all_indices = self.list_indices()
- for index in all_indices:
- if not self.can_delete_index(index, cutoff_date):
- continue
-
- context = ElasticsearchLogRotationContext(index, min_logs_per_rotation, self._es_client)
- yield context
+ context = ElasticsearchLogRotationContext(
+ index, min_logs_per_rotation, self._es_client
+ )
+ yield context
class ElasticsearchLogRotationContext(LogRotationContextInterface):
- """
+ """
    ElasticsearchLogRotationContext yields batches of logs from an index.
When completed without exceptions, this context will delete its associated
Elasticsearch index.
"""
- def __init__(self, index, min_logs_per_rotation, es_client):
- self._es_client = es_client
- self.min_logs_per_rotation = min_logs_per_rotation
- self.index = index
- self.start_pos = 0
- self.end_pos = 0
+ def __init__(self, index, min_logs_per_rotation, es_client):
+ self._es_client = es_client
+ self.min_logs_per_rotation = min_logs_per_rotation
+ self.index = index
- self.scroll = None
+ self.start_pos = 0
+ self.end_pos = 0
- def __enter__(self):
- search = self._base_query()
- self.scroll = search.scan()
- return self
+ self.scroll = None
- def __exit__(self, ex_type, ex_value, ex_traceback):
- if ex_type is None and ex_value is None and ex_traceback is None:
- logger.debug('Deleting index %s', self.index)
- self._es_client.delete_index(self.index)
+ def __enter__(self):
+ search = self._base_query()
+ self.scroll = search.scan()
+ return self
- def yield_logs_batch(self):
- def batched_logs(gen, size):
- batch = []
- for log in gen:
- batch.append(log)
- if len(batch) == size:
- yield batch
- batch = []
+ def __exit__(self, ex_type, ex_value, ex_traceback):
+ if ex_type is None and ex_value is None and ex_traceback is None:
+ logger.debug("Deleting index %s", self.index)
+ self._es_client.delete_index(self.index)
- if batch:
- yield batch
+ def yield_logs_batch(self):
+ def batched_logs(gen, size):
+ batch = []
+ for log in gen:
+ batch.append(log)
+ if len(batch) == size:
+ yield batch
+ batch = []
- for batch in batched_logs(self.scroll, self.min_logs_per_rotation):
- self.end_pos = self.start_pos + len(batch) - 1
- yield batch, self._generate_filename()
- self.start_pos = self.end_pos + 1
+ if batch:
+ yield batch
- def _base_query(self):
- search = LogEntry.search(index=self.index)
- return search
+ for batch in batched_logs(self.scroll, self.min_logs_per_rotation):
+ self.end_pos = self.start_pos + len(batch) - 1
+ yield batch, self._generate_filename()
+ self.start_pos = self.end_pos + 1
- def _generate_filename(self):
- """ Generate the filenames used to archive the action logs. """
- filename = '%s_%d-%d' % (self.index, self.start_pos, self.end_pos)
- filename = '.'.join((filename, 'txt.gz'))
- return filename
+ def _base_query(self):
+ search = LogEntry.search(index=self.index)
+ return search
+
+ def _generate_filename(self):
+ """ Generate the filenames used to archive the action logs. """
+ filename = "%s_%d-%d" % (self.index, self.start_pos, self.end_pos)
+ filename = ".".join((filename, "txt.gz"))
+ return filename
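Before moving on to the Elasticsearch client itself, a minimal sketch of how the paginated lookup_logs() API reformatted above is meant to be consumed; it assumes LogEntriesPage exposes logs and next_page_token attributes (as used in the code), and iter_all_logs and its arguments are placeholders, not part of this change.

from datetime import datetime, timedelta

def iter_all_logs(logs_model, namespace_name):
    # Walk every page for the past week, following the pagination token until
    # the model reports that no further pages exist.
    end = datetime.utcnow()
    start = end - timedelta(days=7)
    page_token = None
    while True:
        page = logs_model.lookup_logs(
            start, end, namespace_name=namespace_name, page_token=page_token
        )
        for log in page.logs:
            yield log
        page_token = page.next_page_token
        if page_token is None:
            break
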
diff --git a/data/logs_model/elastic_logs.py b/data/logs_model/elastic_logs.py
index cd3ff675d..e6147aa68 100644
--- a/data/logs_model/elastic_logs.py
+++ b/data/logs_model/elastic_logs.py
@@ -14,13 +14,13 @@ from elasticsearch_dsl.connections import connections
logger = logging.getLogger(__name__)
# Name of the connection used for Elasticsearch's template API
-ELASTICSEARCH_TEMPLATE_CONNECTION_ALIAS = 'logentry_template'
+ELASTICSEARCH_TEMPLATE_CONNECTION_ALIAS = "logentry_template"
# Prefix of autogenerated indices
-INDEX_NAME_PREFIX = 'logentry_'
+INDEX_NAME_PREFIX = "logentry_"
# Time-based index date format
-INDEX_DATE_FORMAT = '%Y-%m-%d'
+INDEX_DATE_FORMAT = "%Y-%m-%d"
# Timeout for default connection
ELASTICSEARCH_DEFAULT_CONNECTION_TIMEOUT = 15
@@ -29,227 +29,265 @@ ELASTICSEARCH_DEFAULT_CONNECTION_TIMEOUT = 15
ELASTICSEARCH_TEMPLATE_CONNECTION_TIMEOUT = 60
# Force an index template update
-ELASTICSEARCH_FORCE_INDEX_TEMPLATE_UPDATE = os.environ.get('FORCE_INDEX_TEMPLATE_UPDATE', '')
+ELASTICSEARCH_FORCE_INDEX_TEMPLATE_UPDATE = os.environ.get(
+ "FORCE_INDEX_TEMPLATE_UPDATE", ""
+)
# Valid index prefix pattern
-VALID_INDEX_PATTERN = r'^((?!\.$|\.\.$|[-_+])([^A-Z:\/*?\"<>|,# ]){1,255})$'
+VALID_INDEX_PATTERN = r"^((?!\.$|\.\.$|[-_+])([^A-Z:\/*?\"<>|,# ]){1,255})$"
class LogEntry(Document):
- # random_id is the tie-breaker for sorting in pagination.
- # random_id is also used for deduplication of records when using a "at-least-once" delivery stream.
- # Reference: https://www.elastic.co/guide/en/elasticsearch/reference/current/search-request-search-after.html
- #
- # We use don't use the _id of a document since a `doc_values` is not build for this field:
- # An on-disk data structure that stores the same data in a columnar format
- # for optimized sorting and aggregations.
- # Reference: https://github.com/elastic/elasticsearch/issues/35369
- random_id = Text(fields={'keyword': Keyword()})
- kind_id = Integer()
- account_id = Integer()
- performer_id = Integer()
- repository_id = Integer()
- ip = Ip()
- metadata_json = Text()
- datetime = Date()
+ # random_id is the tie-breaker for sorting in pagination.
+    # random_id is also used for deduplication of records when using an "at-least-once" delivery stream.
+ # Reference: https://www.elastic.co/guide/en/elasticsearch/reference/current/search-request-search-after.html
+ #
+    # We don't use the _id of a document since `doc_values` is not built for this field:
+ # An on-disk data structure that stores the same data in a columnar format
+ # for optimized sorting and aggregations.
+ # Reference: https://github.com/elastic/elasticsearch/issues/35369
+ random_id = Text(fields={"keyword": Keyword()})
+ kind_id = Integer()
+ account_id = Integer()
+ performer_id = Integer()
+ repository_id = Integer()
+ ip = Ip()
+ metadata_json = Text()
+ datetime = Date()
- _initialized = False
+ _initialized = False
- @classmethod
- def init(cls, index_prefix, index_settings=None, skip_template_init=False):
- """
+ @classmethod
+ def init(cls, index_prefix, index_settings=None, skip_template_init=False):
+ """
Create the index template, and populate LogEntry's mapping and index settings.
"""
- wildcard_index = Index(name=index_prefix + '*')
- wildcard_index.settings(**(index_settings or {}))
- wildcard_index.document(cls)
- cls._index = wildcard_index
- cls._index_prefix = index_prefix
+ wildcard_index = Index(name=index_prefix + "*")
+ wildcard_index.settings(**(index_settings or {}))
+ wildcard_index.document(cls)
+ cls._index = wildcard_index
+ cls._index_prefix = index_prefix
- if not skip_template_init:
- cls.create_or_update_template()
+ if not skip_template_init:
+ cls.create_or_update_template()
- # Since the elasticsearch-dsl API requires the document's index being defined as an inner class at the class level,
- # this function needs to be called first before being able to call `save`.
- cls._initialized = True
+        # Since the elasticsearch-dsl API requires the document's index to be defined as an inner class at the class level,
+        # this function must be called before `save` can be used.
+ cls._initialized = True
- @classmethod
- def create_or_update_template(cls):
- assert cls._index and cls._index_prefix
- index_template = cls._index.as_template(cls._index_prefix)
- index_template.save(using=ELASTICSEARCH_TEMPLATE_CONNECTION_ALIAS)
+ @classmethod
+ def create_or_update_template(cls):
+ assert cls._index and cls._index_prefix
+ index_template = cls._index.as_template(cls._index_prefix)
+ index_template.save(using=ELASTICSEARCH_TEMPLATE_CONNECTION_ALIAS)
- def save(self, **kwargs):
- # We group the logs based on year, month and day as different indexes, so that
- # dropping those indexes based on retention range is easy.
- #
- # NOTE: This is only used if logging directly to Elasticsearch
- # When using Kinesis or Kafka, the consumer of these streams
- # will be responsible for the management of the indices' lifecycle.
- assert LogEntry._initialized
- kwargs['index'] = self.datetime.strftime(self._index_prefix + INDEX_DATE_FORMAT)
- return super(LogEntry, self).save(**kwargs)
+ def save(self, **kwargs):
+ # We group the logs based on year, month and day as different indexes, so that
+ # dropping those indexes based on retention range is easy.
+ #
+        # NOTE: This is only used if logging directly to Elasticsearch.
+ # When using Kinesis or Kafka, the consumer of these streams
+ # will be responsible for the management of the indices' lifecycle.
+ assert LogEntry._initialized
+ kwargs["index"] = self.datetime.strftime(self._index_prefix + INDEX_DATE_FORMAT)
+ return super(LogEntry, self).save(**kwargs)
class ElasticsearchLogs(object):
- """
+ """
Model for logs operations stored in an Elasticsearch cluster.
"""
- def __init__(self, host=None, port=None, access_key=None, secret_key=None, aws_region=None,
- index_settings=None, use_ssl=True, index_prefix=INDEX_NAME_PREFIX):
- # For options in index_settings, refer to:
- # https://www.elastic.co/guide/en/elasticsearch/guide/master/_index_settings.html
- # some index settings are set at index creation time, and therefore, you should NOT
- # change those settings once the index is set.
- self._host = host
- self._port = port
- self._access_key = access_key
- self._secret_key = secret_key
- self._aws_region = aws_region
- self._index_prefix = index_prefix
- self._index_settings = index_settings
- self._use_ssl = use_ssl
+ def __init__(
+ self,
+ host=None,
+ port=None,
+ access_key=None,
+ secret_key=None,
+ aws_region=None,
+ index_settings=None,
+ use_ssl=True,
+ index_prefix=INDEX_NAME_PREFIX,
+ ):
+ # For options in index_settings, refer to:
+ # https://www.elastic.co/guide/en/elasticsearch/guide/master/_index_settings.html
+        # some index settings are set at index creation time and therefore should NOT
+        # be changed once the index has been created.
+ self._host = host
+ self._port = port
+ self._access_key = access_key
+ self._secret_key = secret_key
+ self._aws_region = aws_region
+ self._index_prefix = index_prefix
+ self._index_settings = index_settings
+ self._use_ssl = use_ssl
- self._client = None
- self._initialized = False
+ self._client = None
+ self._initialized = False
- def _initialize(self):
- """
+ def _initialize(self):
+ """
    Initializes a connection to an ES cluster and
    creates an index template if it does not exist.
"""
- if not self._initialized:
- http_auth = None
- if self._access_key and self._secret_key and self._aws_region:
- http_auth = AWS4Auth(self._access_key, self._secret_key, self._aws_region, 'es')
- elif self._access_key and self._secret_key:
- http_auth = (self._access_key, self._secret_key)
- else:
- logger.warn("Connecting to Elasticsearch without HTTP auth")
+ if not self._initialized:
+ http_auth = None
+ if self._access_key and self._secret_key and self._aws_region:
+ http_auth = AWS4Auth(
+ self._access_key, self._secret_key, self._aws_region, "es"
+ )
+ elif self._access_key and self._secret_key:
+ http_auth = (self._access_key, self._secret_key)
+ else:
+                logger.warning("Connecting to Elasticsearch without HTTP auth")
- self._client = connections.create_connection(
- hosts=[{
- 'host': self._host,
- 'port': self._port
- }],
- http_auth=http_auth,
- use_ssl=self._use_ssl,
- verify_certs=True,
- connection_class=RequestsHttpConnection,
- timeout=ELASTICSEARCH_DEFAULT_CONNECTION_TIMEOUT,
- )
+ self._client = connections.create_connection(
+ hosts=[{"host": self._host, "port": self._port}],
+ http_auth=http_auth,
+ use_ssl=self._use_ssl,
+ verify_certs=True,
+ connection_class=RequestsHttpConnection,
+ timeout=ELASTICSEARCH_DEFAULT_CONNECTION_TIMEOUT,
+ )
- # Create a second connection with a timeout of 60s vs 10s.
- # For some reason the PUT template API can take anywhere between
- # 10s and 30s on the test cluster.
- # This only needs to be done once to initialize the index template
- connections.create_connection(
- alias=ELASTICSEARCH_TEMPLATE_CONNECTION_ALIAS,
- hosts=[{
- 'host': self._host,
- 'port': self._port
- }],
- http_auth=http_auth,
- use_ssl=self._use_ssl,
- verify_certs=True,
- connection_class=RequestsHttpConnection,
- timeout=ELASTICSEARCH_TEMPLATE_CONNECTION_TIMEOUT,
- )
+            # Create a second connection with a timeout of 60s vs the default 15s.
+ # For some reason the PUT template API can take anywhere between
+ # 10s and 30s on the test cluster.
+ # This only needs to be done once to initialize the index template
+ connections.create_connection(
+ alias=ELASTICSEARCH_TEMPLATE_CONNECTION_ALIAS,
+ hosts=[{"host": self._host, "port": self._port}],
+ http_auth=http_auth,
+ use_ssl=self._use_ssl,
+ verify_certs=True,
+ connection_class=RequestsHttpConnection,
+ timeout=ELASTICSEARCH_TEMPLATE_CONNECTION_TIMEOUT,
+ )
- try:
- force_template_update = ELASTICSEARCH_FORCE_INDEX_TEMPLATE_UPDATE.lower() == 'true'
- self._client.indices.get_template(self._index_prefix)
- LogEntry.init(self._index_prefix, self._index_settings,
- skip_template_init=not force_template_update)
- except NotFoundError:
- LogEntry.init(self._index_prefix, self._index_settings, skip_template_init=False)
- finally:
+ try:
+ force_template_update = (
+ ELASTICSEARCH_FORCE_INDEX_TEMPLATE_UPDATE.lower() == "true"
+ )
+ self._client.indices.get_template(self._index_prefix)
+ LogEntry.init(
+ self._index_prefix,
+ self._index_settings,
+ skip_template_init=not force_template_update,
+ )
+ except NotFoundError:
+ LogEntry.init(
+ self._index_prefix, self._index_settings, skip_template_init=False
+ )
+ finally:
+ try:
+ connections.remove_connection(
+ ELASTICSEARCH_TEMPLATE_CONNECTION_ALIAS
+ )
+ except KeyError as ke:
+ logger.exception(
+ "Elasticsearch connection not found to remove %s: %s",
+ ELASTICSEARCH_TEMPLATE_CONNECTION_ALIAS,
+ ke,
+ )
+
+ self._initialized = True
+
+ def index_name(self, day):
+ """ Return an index name for the given day. """
+ return self._index_prefix + day.strftime(INDEX_DATE_FORMAT)
+
+ def index_exists(self, index):
try:
- connections.remove_connection(ELASTICSEARCH_TEMPLATE_CONNECTION_ALIAS)
- except KeyError as ke:
- logger.exception('Elasticsearch connection not found to remove %s: %s',
- ELASTICSEARCH_TEMPLATE_CONNECTION_ALIAS, ke)
+ return index in self._client.indices.get(index)
+ except NotFoundError:
+ return False
- self._initialized = True
-
- def index_name(self, day):
- """ Return an index name for the given day. """
- return self._index_prefix + day.strftime(INDEX_DATE_FORMAT)
-
- def index_exists(self, index):
- try:
- return index in self._client.indices.get(index)
- except NotFoundError:
- return False
-
- @staticmethod
- def _valid_index_prefix(prefix):
- """ Check that the given index prefix is valid with the set of
+ @staticmethod
+ def _valid_index_prefix(prefix):
+ """ Check that the given index prefix is valid with the set of
indices used by this class.
"""
- return re.match(VALID_INDEX_PATTERN, prefix) is not None
+ return re.match(VALID_INDEX_PATTERN, prefix) is not None
- def _valid_index_name(self, index):
- """ Check that the given index name is valid and follows the format:
+ def _valid_index_name(self, index):
+ """ Check that the given index name is valid and follows the format:
YYYY-MM-DD
"""
- if not ElasticsearchLogs._valid_index_prefix(index):
- return False
+ if not ElasticsearchLogs._valid_index_prefix(index):
+ return False
- if not index.startswith(self._index_prefix) or len(index) > 255:
- return False
+ if not index.startswith(self._index_prefix) or len(index) > 255:
+ return False
- index_dt_str = index.split(self._index_prefix, 1)[-1]
- try:
- datetime.strptime(index_dt_str, INDEX_DATE_FORMAT)
- return True
- except ValueError:
- logger.exception('Invalid date format (YYYY-MM-DD) for index: %s', index)
- return False
+ index_dt_str = index.split(self._index_prefix, 1)[-1]
+ try:
+ datetime.strptime(index_dt_str, INDEX_DATE_FORMAT)
+ return True
+ except ValueError:
+ logger.exception("Invalid date format (YYYY-MM-DD) for index: %s", index)
+ return False
- def can_delete_index(self, index, cutoff_date):
- """ Check if the given index can be deleted based on the given index's date and cutoff date. """
- assert self._valid_index_name(index)
- index_dt = datetime.strptime(index[len(self._index_prefix):], INDEX_DATE_FORMAT)
- return index_dt < cutoff_date and cutoff_date - index_dt >= timedelta(days=1)
+ def can_delete_index(self, index, cutoff_date):
+ """ Check if the given index can be deleted based on the given index's date and cutoff date. """
+ assert self._valid_index_name(index)
+ index_dt = datetime.strptime(
+ index[len(self._index_prefix) :], INDEX_DATE_FORMAT
+ )
+ return index_dt < cutoff_date and cutoff_date - index_dt >= timedelta(days=1)
- def list_indices(self):
- self._initialize()
- try:
- return self._client.indices.get(self._index_prefix + '*').keys()
- except NotFoundError as nfe:
- logger.exception('`%s` indices not found: %s', self._index_prefix, nfe.info)
- return []
- except AuthorizationException as ae:
- logger.exception('Unauthorized for indices `%s`: %s', self._index_prefix, ae.info)
- return None
+ def list_indices(self):
+ self._initialize()
+ try:
+ return self._client.indices.get(self._index_prefix + "*").keys()
+ except NotFoundError as nfe:
+ logger.exception("`%s` indices not found: %s", self._index_prefix, nfe.info)
+ return []
+ except AuthorizationException as ae:
+ logger.exception(
+ "Unauthorized for indices `%s`: %s", self._index_prefix, ae.info
+ )
+ return None
- def delete_index(self, index):
- self._initialize()
- assert self._valid_index_name(index)
+ def delete_index(self, index):
+ self._initialize()
+ assert self._valid_index_name(index)
- try:
- self._client.indices.delete(index)
- return index
- except NotFoundError as nfe:
- logger.exception('`%s` indices not found: %s', index, nfe.info)
- return None
- except AuthorizationException as ae:
- logger.exception('Unauthorized to delete index `%s`: %s', index, ae.info)
- return None
+ try:
+ self._client.indices.delete(index)
+ return index
+ except NotFoundError as nfe:
+ logger.exception("`%s` indices not found: %s", index, nfe.info)
+ return None
+ except AuthorizationException as ae:
+ logger.exception("Unauthorized to delete index `%s`: %s", index, ae.info)
+ return None
-def configure_es(host, port, access_key=None, secret_key=None, aws_region=None,
- index_prefix=None, use_ssl=True, index_settings=None):
- """
+def configure_es(
+ host,
+ port,
+ access_key=None,
+ secret_key=None,
+ aws_region=None,
+ index_prefix=None,
+ use_ssl=True,
+ index_settings=None,
+):
+ """
For options in index_settings, refer to:
https://www.elastic.co/guide/en/elasticsearch/guide/master/_index_settings.html
    some index settings are set at index creation time and therefore should NOT
    be changed once the index has been created.
"""
- es_client = ElasticsearchLogs(host=host, port=port, access_key=access_key, secret_key=secret_key,
- aws_region=aws_region, index_prefix=index_prefix or INDEX_NAME_PREFIX,
- use_ssl=use_ssl, index_settings=index_settings)
- es_client._initialize()
- return es_client
+ es_client = ElasticsearchLogs(
+ host=host,
+ port=port,
+ access_key=access_key,
+ secret_key=secret_key,
+ aws_region=aws_region,
+ index_prefix=index_prefix or INDEX_NAME_PREFIX,
+ use_ssl=use_ssl,
+ index_settings=index_settings,
+ )
+ es_client._initialize()
+ return es_client
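For orientation, a hypothetical usage sketch of the client configured above; the connection details are placeholders, a reachable Elasticsearch cluster is assumed, and the 30-day cutoff is arbitrary.

from datetime import datetime, timedelta

from data.logs_model.elastic_logs import configure_es

# Connect to a local cluster without HTTP auth (placeholder host/port).
es_client = configure_es("localhost", 9200, use_ssl=False)

# Daily indices are named <index_prefix>YYYY-MM-DD, e.g. "logentry_2019-01-01".
index = es_client.index_name(datetime(2019, 1, 1))

# An index at least one full day older than the cutoff is eligible for rotation.
cutoff = datetime.utcnow() - timedelta(days=30)
if es_client.index_exists(index) and es_client.can_delete_index(index, cutoff):
    es_client.delete_index(index)
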
diff --git a/data/logs_model/inmemory_model.py b/data/logs_model/inmemory_model.py
index f9a219f51..2a2da020c 100644
--- a/data/logs_model/inmemory_model.py
+++ b/data/logs_model/inmemory_model.py
@@ -8,237 +8,345 @@ from dateutil.relativedelta import relativedelta
from data import model
from data.logs_model.datatypes import AggregatedLogCount, LogEntriesPage, Log
-from data.logs_model.interface import (ActionLogsDataInterface, LogRotationContextInterface,
- LogsIterationTimeout)
+from data.logs_model.interface import (
+ ActionLogsDataInterface,
+ LogRotationContextInterface,
+ LogsIterationTimeout,
+)
logger = logging.getLogger(__name__)
-LogAndRepository = namedtuple('LogAndRepository', ['log', 'stored_log', 'repository'])
+LogAndRepository = namedtuple("LogAndRepository", ["log", "stored_log", "repository"])
+
+StoredLog = namedtuple(
+ "StoredLog",
+ [
+ "kind_id",
+ "account_id",
+ "performer_id",
+ "ip",
+ "metadata_json",
+ "repository_id",
+ "datetime",
+ ],
+)
-StoredLog = namedtuple('StoredLog', ['kind_id',
- 'account_id',
- 'performer_id',
- 'ip',
- 'metadata_json',
- 'repository_id',
- 'datetime'])
class InMemoryModel(ActionLogsDataInterface):
- """
+ """
InMemoryModel implements the data model for logs in-memory. FOR TESTING ONLY.
"""
- def __init__(self):
- self.logs = []
- def _filter_logs(self, start_datetime, end_datetime, performer_name=None, repository_name=None,
- namespace_name=None, filter_kinds=None):
- if filter_kinds is not None:
- assert all(isinstance(kind_name, str) for kind_name in filter_kinds)
+ def __init__(self):
+ self.logs = []
- for log_and_repo in self.logs:
- if log_and_repo.log.datetime < start_datetime or log_and_repo.log.datetime > end_datetime:
- continue
+ def _filter_logs(
+ self,
+ start_datetime,
+ end_datetime,
+ performer_name=None,
+ repository_name=None,
+ namespace_name=None,
+ filter_kinds=None,
+ ):
+ if filter_kinds is not None:
+ assert all(isinstance(kind_name, str) for kind_name in filter_kinds)
- if performer_name and log_and_repo.log.performer_username != performer_name:
- continue
+ for log_and_repo in self.logs:
+ if (
+ log_and_repo.log.datetime < start_datetime
+ or log_and_repo.log.datetime > end_datetime
+ ):
+ continue
- if (repository_name and
- (not log_and_repo.repository or log_and_repo.repository.name != repository_name)):
- continue
+ if performer_name and log_and_repo.log.performer_username != performer_name:
+ continue
- if namespace_name and log_and_repo.log.account_username != namespace_name:
- continue
+ if repository_name and (
+ not log_and_repo.repository
+ or log_and_repo.repository.name != repository_name
+ ):
+ continue
- if filter_kinds:
- kind_map = model.log.get_log_entry_kinds()
- ignore_ids = [kind_map[kind_name] for kind_name in filter_kinds]
- if log_and_repo.log.kind_id in ignore_ids:
- continue
+ if namespace_name and log_and_repo.log.account_username != namespace_name:
+ continue
- yield log_and_repo
+ if filter_kinds:
+ kind_map = model.log.get_log_entry_kinds()
+ ignore_ids = [kind_map[kind_name] for kind_name in filter_kinds]
+ if log_and_repo.log.kind_id in ignore_ids:
+ continue
- def _filter_latest_logs(self, performer_name=None, repository_name=None,
- namespace_name=None, filter_kinds=None):
- if filter_kinds is not None:
- assert all(isinstance(kind_name, str) for kind_name in filter_kinds)
+ yield log_and_repo
- for log_and_repo in sorted(self.logs, key=lambda t: t.log.datetime, reverse=True):
- if performer_name and log_and_repo.log.performer_username != performer_name:
- continue
+ def _filter_latest_logs(
+ self,
+ performer_name=None,
+ repository_name=None,
+ namespace_name=None,
+ filter_kinds=None,
+ ):
+ if filter_kinds is not None:
+ assert all(isinstance(kind_name, str) for kind_name in filter_kinds)
- if (repository_name and
- (not log_and_repo.repository or log_and_repo.repository.name != repository_name)):
- continue
+ for log_and_repo in sorted(
+ self.logs, key=lambda t: t.log.datetime, reverse=True
+ ):
+ if performer_name and log_and_repo.log.performer_username != performer_name:
+ continue
- if namespace_name and log_and_repo.log.account_username != namespace_name:
- continue
+ if repository_name and (
+ not log_and_repo.repository
+ or log_and_repo.repository.name != repository_name
+ ):
+ continue
- if filter_kinds:
- kind_map = model.log.get_log_entry_kinds()
- ignore_ids = [kind_map[kind_name] for kind_name in filter_kinds]
- if log_and_repo.log.kind_id in ignore_ids:
- continue
+ if namespace_name and log_and_repo.log.account_username != namespace_name:
+ continue
- yield log_and_repo
+ if filter_kinds:
+ kind_map = model.log.get_log_entry_kinds()
+ ignore_ids = [kind_map[kind_name] for kind_name in filter_kinds]
+ if log_and_repo.log.kind_id in ignore_ids:
+ continue
- def lookup_logs(self, start_datetime, end_datetime, performer_name=None, repository_name=None,
- namespace_name=None, filter_kinds=None, page_token=None, max_page_count=None):
- logs = []
- for log_and_repo in self._filter_logs(start_datetime, end_datetime, performer_name,
- repository_name, namespace_name, filter_kinds):
- logs.append(log_and_repo.log)
- return LogEntriesPage(logs, None)
+ yield log_and_repo
- def lookup_latest_logs(self, performer_name=None, repository_name=None, namespace_name=None,
- filter_kinds=None, size=20):
- latest_logs = []
- for log_and_repo in self._filter_latest_logs(performer_name, repository_name, namespace_name,
- filter_kinds):
- if size is not None and len(latest_logs) == size:
- break
+ def lookup_logs(
+ self,
+ start_datetime,
+ end_datetime,
+ performer_name=None,
+ repository_name=None,
+ namespace_name=None,
+ filter_kinds=None,
+ page_token=None,
+ max_page_count=None,
+ ):
+ logs = []
+ for log_and_repo in self._filter_logs(
+ start_datetime,
+ end_datetime,
+ performer_name,
+ repository_name,
+ namespace_name,
+ filter_kinds,
+ ):
+ logs.append(log_and_repo.log)
+ return LogEntriesPage(logs, None)
- latest_logs.append(log_and_repo.log)
+ def lookup_latest_logs(
+ self,
+ performer_name=None,
+ repository_name=None,
+ namespace_name=None,
+ filter_kinds=None,
+ size=20,
+ ):
+ latest_logs = []
+ for log_and_repo in self._filter_latest_logs(
+ performer_name, repository_name, namespace_name, filter_kinds
+ ):
+ if size is not None and len(latest_logs) == size:
+ break
- return latest_logs
+ latest_logs.append(log_and_repo.log)
- def get_aggregated_log_counts(self, start_datetime, end_datetime, performer_name=None,
- repository_name=None, namespace_name=None, filter_kinds=None):
- entries = {}
- for log_and_repo in self._filter_logs(start_datetime, end_datetime, performer_name,
- repository_name, namespace_name, filter_kinds):
- entry = log_and_repo.log
- synthetic_date = datetime(start_datetime.year, start_datetime.month, int(entry.datetime.day),
- tzinfo=get_localzone())
- if synthetic_date.day < start_datetime.day:
- synthetic_date = synthetic_date + relativedelta(months=1)
+ return latest_logs
- key = '%s-%s' % (entry.kind_id, entry.datetime.day)
+ def get_aggregated_log_counts(
+ self,
+ start_datetime,
+ end_datetime,
+ performer_name=None,
+ repository_name=None,
+ namespace_name=None,
+ filter_kinds=None,
+ ):
+ entries = {}
+ for log_and_repo in self._filter_logs(
+ start_datetime,
+ end_datetime,
+ performer_name,
+ repository_name,
+ namespace_name,
+ filter_kinds,
+ ):
+ entry = log_and_repo.log
+ synthetic_date = datetime(
+ start_datetime.year,
+ start_datetime.month,
+ int(entry.datetime.day),
+ tzinfo=get_localzone(),
+ )
+ if synthetic_date.day < start_datetime.day:
+ synthetic_date = synthetic_date + relativedelta(months=1)
- if key in entries:
- entries[key] = AggregatedLogCount(entry.kind_id, entries[key].count + 1,
- synthetic_date)
- else:
- entries[key] = AggregatedLogCount(entry.kind_id, 1, synthetic_date)
+ key = "%s-%s" % (entry.kind_id, entry.datetime.day)
- return entries.values()
+ if key in entries:
+ entries[key] = AggregatedLogCount(
+ entry.kind_id, entries[key].count + 1, synthetic_date
+ )
+ else:
+ entries[key] = AggregatedLogCount(entry.kind_id, 1, synthetic_date)
- def count_repository_actions(self, repository, day):
- count = 0
- for log_and_repo in self.logs:
- if log_and_repo.repository != repository:
- continue
+ return entries.values()
- if log_and_repo.log.datetime.day != day.day:
- continue
+ def count_repository_actions(self, repository, day):
+ count = 0
+ for log_and_repo in self.logs:
+ if log_and_repo.repository != repository:
+ continue
- count += 1
+ if log_and_repo.log.datetime.day != day.day:
+ continue
- return count
+ count += 1
- def queue_logs_export(self, start_datetime, end_datetime, export_action_logs_queue,
- namespace_name=None, repository_name=None, callback_url=None,
- callback_email=None, filter_kinds=None):
- raise NotImplementedError
+ return count
- def log_action(self, kind_name, namespace_name=None, performer=None, ip=None, metadata=None,
- repository=None, repository_name=None, timestamp=None, is_free_namespace=False):
- timestamp = timestamp or datetime.today()
+ def queue_logs_export(
+ self,
+ start_datetime,
+ end_datetime,
+ export_action_logs_queue,
+ namespace_name=None,
+ repository_name=None,
+ callback_url=None,
+ callback_email=None,
+ filter_kinds=None,
+ ):
+ raise NotImplementedError
- if not repository and repository_name and namespace_name:
- repository = model.repository.get_repository(namespace_name, repository_name)
+ def log_action(
+ self,
+ kind_name,
+ namespace_name=None,
+ performer=None,
+ ip=None,
+ metadata=None,
+ repository=None,
+ repository_name=None,
+ timestamp=None,
+ is_free_namespace=False,
+ ):
+ timestamp = timestamp or datetime.today()
- account = None
- account_id = None
- performer_id = None
- repository_id = None
+ if not repository and repository_name and namespace_name:
+ repository = model.repository.get_repository(
+ namespace_name, repository_name
+ )
- if namespace_name is not None:
- account = model.user.get_namespace_user(namespace_name)
- account_id = account.id
+ account = None
+ account_id = None
+ performer_id = None
+ repository_id = None
- if performer is not None:
- performer_id = performer.id
+ if namespace_name is not None:
+ account = model.user.get_namespace_user(namespace_name)
+ account_id = account.id
- if repository is not None:
- repository_id = repository.id
+ if performer is not None:
+ performer_id = performer.id
- metadata_json = json.dumps(metadata or {})
- kind_id = model.log.get_log_entry_kinds()[kind_name]
+ if repository is not None:
+ repository_id = repository.id
- stored_log = StoredLog(
- kind_id,
- account_id,
- performer_id,
- ip,
- metadata_json,
- repository_id,
- timestamp
- )
+ metadata_json = json.dumps(metadata or {})
+ kind_id = model.log.get_log_entry_kinds()[kind_name]
- log = Log(metadata_json=metadata,
- ip=ip,
- datetime=timestamp,
- performer_email=performer.email if performer else None,
- performer_username=performer.username if performer else None,
- performer_robot=performer.robot if performer else None,
- account_organization=account.organization if account else None,
- account_username=account.username if account else None,
- account_email=account.email if account else None,
- account_robot=account.robot if account else None,
- kind_id=kind_id)
+ stored_log = StoredLog(
+ kind_id,
+ account_id,
+ performer_id,
+ ip,
+ metadata_json,
+ repository_id,
+ timestamp,
+ )
- self.logs.append(LogAndRepository(log, stored_log, repository))
+ log = Log(
+ metadata_json=metadata,
+ ip=ip,
+ datetime=timestamp,
+ performer_email=performer.email if performer else None,
+ performer_username=performer.username if performer else None,
+ performer_robot=performer.robot if performer else None,
+ account_organization=account.organization if account else None,
+ account_username=account.username if account else None,
+ account_email=account.email if account else None,
+ account_robot=account.robot if account else None,
+ kind_id=kind_id,
+ )
- def yield_logs_for_export(self, start_datetime, end_datetime, repository_id=None,
- namespace_id=None, max_query_time=None):
- # Just for testing.
- if max_query_time is not None:
- raise LogsIterationTimeout()
+ self.logs.append(LogAndRepository(log, stored_log, repository))
- logs = []
- for log_and_repo in self._filter_logs(start_datetime, end_datetime):
- if (repository_id and
- (not log_and_repo.repository or log_and_repo.repository.id != repository_id)):
- continue
+ def yield_logs_for_export(
+ self,
+ start_datetime,
+ end_datetime,
+ repository_id=None,
+ namespace_id=None,
+ max_query_time=None,
+ ):
+ # Just for testing.
+ if max_query_time is not None:
+ raise LogsIterationTimeout()
- if namespace_id:
- if log_and_repo.log.account_username is None:
- continue
+ logs = []
+ for log_and_repo in self._filter_logs(start_datetime, end_datetime):
+ if repository_id and (
+ not log_and_repo.repository
+ or log_and_repo.repository.id != repository_id
+ ):
+ continue
- namespace = model.user.get_namespace_user(log_and_repo.log.account_username)
- if namespace.id != namespace_id:
- continue
+ if namespace_id:
+ if log_and_repo.log.account_username is None:
+ continue
- logs.append(log_and_repo.log)
+ namespace = model.user.get_namespace_user(
+ log_and_repo.log.account_username
+ )
+ if namespace.id != namespace_id:
+ continue
- yield logs
+ logs.append(log_and_repo.log)
- def yield_log_rotation_context(self, cutoff_date, min_logs_per_rotation):
- expired_logs = [log_and_repo for log_and_repo in self.logs
- if log_and_repo.log.datetime <= cutoff_date]
- while True:
- if not expired_logs:
- break
- context = InMemoryLogRotationContext(expired_logs[:min_logs_per_rotation], self.logs)
- expired_logs = expired_logs[min_logs_per_rotation:]
- yield context
+ yield logs
+
+ def yield_log_rotation_context(self, cutoff_date, min_logs_per_rotation):
+ expired_logs = [
+ log_and_repo
+ for log_and_repo in self.logs
+ if log_and_repo.log.datetime <= cutoff_date
+ ]
+ while True:
+ if not expired_logs:
+ break
+ context = InMemoryLogRotationContext(
+ expired_logs[:min_logs_per_rotation], self.logs
+ )
+ expired_logs = expired_logs[min_logs_per_rotation:]
+ yield context
class InMemoryLogRotationContext(LogRotationContextInterface):
- def __init__(self, expired_logs, all_logs):
- self.expired_logs = expired_logs
- self.all_logs = all_logs
+ def __init__(self, expired_logs, all_logs):
+ self.expired_logs = expired_logs
+ self.all_logs = all_logs
- def __enter__(self):
- return self
+ def __enter__(self):
+ return self
- def __exit__(self, ex_type, ex_value, ex_traceback):
- if ex_type is None and ex_value is None and ex_traceback is None:
- for log in self.expired_logs:
- self.all_logs.remove(log)
+ def __exit__(self, ex_type, ex_value, ex_traceback):
+ if ex_type is None and ex_value is None and ex_traceback is None:
+ for log in self.expired_logs:
+ self.all_logs.remove(log)
- def yield_logs_batch(self):
- """ Yield a batch of logs and a filename for that batch. """
- filename = 'inmemory_model_filename_placeholder'
- filename = '.'.join((filename, 'txt.gz'))
- yield [log_and_repo.stored_log for log_and_repo in self.expired_logs], filename
+ def yield_logs_batch(self):
+ """ Yield a batch of logs and a filename for that batch. """
+ filename = "inmemory_model_filename_placeholder"
+ filename = ".".join((filename, "txt.gz"))
+ yield [log_and_repo.stored_log for log_and_repo in self.expired_logs], filename
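A rough sketch of how the in-memory model above is exercised in tests; it assumes an initialized test database (so the model.user, model.repository and model.log lookups resolve), that LogEntriesPage exposes a logs attribute, and that the kind/namespace/repository names below are placeholder fixtures.

from datetime import datetime, timedelta

from data.logs_model.inmemory_model import InMemoryModel

logs_model = InMemoryModel()

# Record a single action through the shared interface.
logs_model.log_action(
    "push_repo",
    namespace_name="devtable",
    repository_name="simple",
    ip="127.0.0.1",
    metadata={"tag": "latest"},
)

# Read it back the same way the production models are queried.
now = datetime.today()
page = logs_model.lookup_logs(now - timedelta(hours=1), now + timedelta(hours=1))
assert len(page.logs) == 1
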
diff --git a/data/logs_model/interface.py b/data/logs_model/interface.py
index 705d46cc0..b7c254259 100644
--- a/data/logs_model/interface.py
+++ b/data/logs_model/interface.py
@@ -1,64 +1,112 @@
from abc import ABCMeta, abstractmethod
from six import add_metaclass
+
class LogsIterationTimeout(Exception):
- """ Exception raised if logs iteration times out. """
+ """ Exception raised if logs iteration times out. """
@add_metaclass(ABCMeta)
class ActionLogsDataInterface(object):
- """ Interface for code to work with the logs data model. The logs data model consists
+ """ Interface for code to work with the logs data model. The logs data model consists
of all access for reading and writing action logs.
"""
- @abstractmethod
- def lookup_logs(self, start_datetime, end_datetime, performer_name=None, repository_name=None,
- namespace_name=None, filter_kinds=None, page_token=None, max_page_count=None):
- """ Looks up all logs between the start_datetime and end_datetime, filtered
+
+ @abstractmethod
+ def lookup_logs(
+ self,
+ start_datetime,
+ end_datetime,
+ performer_name=None,
+ repository_name=None,
+ namespace_name=None,
+ filter_kinds=None,
+ page_token=None,
+ max_page_count=None,
+ ):
+ """ Looks up all logs between the start_datetime and end_datetime, filtered
by performer (a user), repository or namespace. Note that one (and only one) of the three
can be specified. Returns a LogEntriesPage. `filter_kinds`, if specified, is a set/list
of the kinds of logs to filter out.
"""
- @abstractmethod
- def lookup_latest_logs(self, performer_name=None, repository_name=None, namespace_name=None,
- filter_kinds=None, size=20):
- """ Looks up latest logs of a specific kind, filtered by performer (a user),
+ @abstractmethod
+ def lookup_latest_logs(
+ self,
+ performer_name=None,
+ repository_name=None,
+ namespace_name=None,
+ filter_kinds=None,
+ size=20,
+ ):
+ """ Looks up latest logs of a specific kind, filtered by performer (a user),
repository or namespace. Note that one (and only one) of the three can be specified.
Returns a list of `Log`.
"""
- @abstractmethod
- def get_aggregated_log_counts(self, start_datetime, end_datetime, performer_name=None,
- repository_name=None, namespace_name=None, filter_kinds=None):
- """ Returns the aggregated count of logs, by kind, between the start_datetime and end_datetime,
+ @abstractmethod
+ def get_aggregated_log_counts(
+ self,
+ start_datetime,
+ end_datetime,
+ performer_name=None,
+ repository_name=None,
+ namespace_name=None,
+ filter_kinds=None,
+ ):
+ """ Returns the aggregated count of logs, by kind, between the start_datetime and end_datetime,
filtered by performer (a user), repository or namespace. Note that one (and only one) of
the three can be specified. Returns a list of AggregatedLogCount.
"""
- @abstractmethod
- def count_repository_actions(self, repository, day):
- """ Returns the total number of repository actions over the given day, in the given repository
+ @abstractmethod
+ def count_repository_actions(self, repository, day):
+ """ Returns the total number of repository actions over the given day, in the given repository
or None on error.
"""
- @abstractmethod
- def queue_logs_export(self, start_datetime, end_datetime, export_action_logs_queue,
- namespace_name=None, repository_name=None, callback_url=None,
- callback_email=None, filter_kinds=None):
- """ Queues logs between the start_datetime and end_time, filtered by a repository or namespace,
+ @abstractmethod
+ def queue_logs_export(
+ self,
+ start_datetime,
+ end_datetime,
+ export_action_logs_queue,
+ namespace_name=None,
+ repository_name=None,
+ callback_url=None,
+ callback_email=None,
+ filter_kinds=None,
+ ):
+ """ Queues logs between the start_datetime and end_time, filtered by a repository or namespace,
for export to the specified URL and/or email address. Returns the ID of the export job
queued or None if error.
"""
- @abstractmethod
- def log_action(self, kind_name, namespace_name=None, performer=None, ip=None, metadata=None,
- repository=None, repository_name=None, timestamp=None, is_free_namespace=False):
- """ Logs a single action as having taken place. """
+ @abstractmethod
+ def log_action(
+ self,
+ kind_name,
+ namespace_name=None,
+ performer=None,
+ ip=None,
+ metadata=None,
+ repository=None,
+ repository_name=None,
+ timestamp=None,
+ is_free_namespace=False,
+ ):
+ """ Logs a single action as having taken place. """
- @abstractmethod
- def yield_logs_for_export(self, start_datetime, end_datetime, repository_id=None,
- namespace_id=None, max_query_time=None):
- """ Returns an iterator that yields bundles of all logs found between the start_datetime and
+ @abstractmethod
+ def yield_logs_for_export(
+ self,
+ start_datetime,
+ end_datetime,
+ repository_id=None,
+ namespace_id=None,
+ max_query_time=None,
+ ):
+ """ Returns an iterator that yields bundles of all logs found between the start_datetime and
end_datetime, optionally filtered by the repository or namespace. This function should be
used for any bulk lookup operations, and should be implemented by implementors to put
minimal strain on the backing storage for large operations. If there was an error in setting
@@ -69,9 +117,9 @@ class ActionLogsDataInterface(object):
LogsIterationTimeout will be raised instead of returning the logs bundle.
"""
- @abstractmethod
- def yield_log_rotation_context(self, cutoff_date, min_logs_per_rotation):
- """
+ @abstractmethod
+ def yield_log_rotation_context(self, cutoff_date, min_logs_per_rotation):
+ """
A generator that yields contexts implementing the LogRotationContextInterface.
Each context represents a set of logs to be archived and deleted once
the context completes without exceptions.
@@ -86,10 +134,11 @@ class ActionLogsDataInterface(object):
@add_metaclass(ABCMeta)
class LogRotationContextInterface(object):
- """ Interface for iterating over a set of logs to be archived. """
- @abstractmethod
- def yield_logs_batch(self):
- """
+ """ Interface for iterating over a set of logs to be archived. """
+
+ @abstractmethod
+ def yield_logs_batch(self):
+ """
Generator yielding batch of logs and a filename for that batch.
A batch is a subset of the logs part of the context.
"""
diff --git a/data/logs_model/logs_producer/__init__.py b/data/logs_model/logs_producer/__init__.py
index 17bd605ad..a843658e2 100644
--- a/data/logs_model/logs_producer/__init__.py
+++ b/data/logs_model/logs_producer/__init__.py
@@ -5,23 +5,24 @@ logger = logging.getLogger(__name__)
class LogSendException(Exception):
- """ A generic error when sending the logs to its destination.
+ """ A generic error when sending the logs to its destination.
e.g. Kinesis, Kafka, Elasticsearch, ...
"""
- pass
+
+ pass
class LogProducerProxy(object):
- def __init__(self):
- self._model = None
+ def __init__(self):
+ self._model = None
- def initialize(self, model):
- self._model = model
- logger.info('===============================')
- logger.info('Using producer `%s`', self._model)
- logger.info('===============================')
+ def initialize(self, model):
+ self._model = model
+ logger.info("===============================")
+ logger.info("Using producer `%s`", self._model)
+ logger.info("===============================")
- def __getattr__(self, attr):
- if not self._model:
- raise AttributeError("LogsModelProxy is not initialized")
- return getattr(self._model, attr)
+ def __getattr__(self, attr):
+ if not self._model:
+ raise AttributeError("LogsModelProxy is not initialized")
+ return getattr(self._model, attr)
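
A small, hypothetical illustration of the LogProducerProxy contract above: attribute access is only forwarded once initialize() has been called. It assumes the Quay source tree is on the path so the import resolves; _PrintProducer is a stand-in, not a real producer.

from data.logs_model.logs_producer import LogProducerProxy

class _PrintProducer(object):
    """ Toy producer used only for this sketch. """
    def send(self, logentry):
        print("would send: %r" % (logentry,))

proxy = LogProducerProxy()
try:
    proxy.send({"kind": "push_repo"})  # raises AttributeError: not yet initialized
except AttributeError as ae:
    print(ae)

proxy.initialize(_PrintProducer())
proxy.send({"kind": "push_repo"})  # now forwarded to the wrapped producer
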
diff --git a/data/logs_model/logs_producer/elasticsearch_logs_producer.py b/data/logs_model/logs_producer/elasticsearch_logs_producer.py
index 175fb4ac6..2b311dd4e 100644
--- a/data/logs_model/logs_producer/elasticsearch_logs_producer.py
+++ b/data/logs_model/logs_producer/elasticsearch_logs_producer.py
@@ -10,16 +10,27 @@ logger = logging.getLogger(__name__)
class ElasticsearchLogsProducer(LogProducerInterface):
- """ Log producer writing log entries to Elasticsearch.
+ """ Log producer writing log entries to Elasticsearch.
This implementation writes directly to Elasticsearch without a streaming/queueing service.
"""
- def send(self, logentry):
- try:
- logentry.save()
- except ElasticsearchException as ex:
- logger.exception('ElasticsearchLogsProducer error sending log to Elasticsearch: %s', ex)
- raise LogSendException('ElasticsearchLogsProducer error sending log to Elasticsearch: %s' % ex)
- except Exception as e:
- logger.exception('ElasticsearchLogsProducer exception sending log to Elasticsearch: %s', e)
- raise LogSendException('ElasticsearchLogsProducer exception sending log to Elasticsearch: %s' % e)
+
+ def send(self, logentry):
+ try:
+ logentry.save()
+ except ElasticsearchException as ex:
+ logger.exception(
+ "ElasticsearchLogsProducer error sending log to Elasticsearch: %s", ex
+ )
+ raise LogSendException(
+ "ElasticsearchLogsProducer error sending log to Elasticsearch: %s" % ex
+ )
+ except Exception as e:
+ logger.exception(
+ "ElasticsearchLogsProducer exception sending log to Elasticsearch: %s",
+ e,
+ )
+ raise LogSendException(
+ "ElasticsearchLogsProducer exception sending log to Elasticsearch: %s"
+ % e
+ )
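
Because every producer wraps backend failures in LogSendException, callers can treat it as the single failure signal regardless of whether Elasticsearch, Kafka, or Kinesis is configured. A hedged sketch (not from this patch; try_send is a hypothetical helper):

from data.logs_model.logs_producer import LogSendException

def try_send(producer, logentry):
    try:
        producer.send(logentry)
        return True
    except LogSendException as lse:
        # The producer has already logged the underlying backend error.
        print("log entry dropped: %s" % lse)
        return False
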
diff --git a/data/logs_model/logs_producer/interface.py b/data/logs_model/logs_producer/interface.py
index d0d9b71d4..c9693725a 100644
--- a/data/logs_model/logs_producer/interface.py
+++ b/data/logs_model/logs_producer/interface.py
@@ -1,8 +1,9 @@
from abc import ABCMeta, abstractmethod
from six import add_metaclass
+
@add_metaclass(ABCMeta)
class LogProducerInterface(object):
- @abstractmethod
- def send(self, logentry):
- """ Send a log entry to the configured log infrastructure. """
+ @abstractmethod
+ def send(self, logentry):
+ """ Send a log entry to the configured log infrastructure. """
diff --git a/data/logs_model/logs_producer/kafka_logs_producer.py b/data/logs_model/logs_producer/kafka_logs_producer.py
index 9c13a441d..69ccae961 100644
--- a/data/logs_model/logs_producer/kafka_logs_producer.py
+++ b/data/logs_model/logs_producer/kafka_logs_producer.py
@@ -15,31 +15,44 @@ DEFAULT_MAX_BLOCK_SECONDS = 5
class KafkaLogsProducer(LogProducerInterface):
- """ Log producer writing log entries to a Kafka stream. """
- def __init__(self, bootstrap_servers=None, topic=None, client_id=None, max_block_seconds=None):
- self.bootstrap_servers = bootstrap_servers
- self.topic = topic
- self.client_id = client_id
- self.max_block_ms = (max_block_seconds or DEFAULT_MAX_BLOCK_SECONDS) * 1000
+ """ Log producer writing log entries to a Kafka stream. """
- self._producer = KafkaProducer(bootstrap_servers=self.bootstrap_servers,
- client_id=self.client_id,
- max_block_ms=self.max_block_ms,
- value_serializer=logs_json_serializer)
+ def __init__(
+ self, bootstrap_servers=None, topic=None, client_id=None, max_block_seconds=None
+ ):
+ self.bootstrap_servers = bootstrap_servers
+ self.topic = topic
+ self.client_id = client_id
+ self.max_block_ms = (max_block_seconds or DEFAULT_MAX_BLOCK_SECONDS) * 1000
- def send(self, logentry):
- try:
- # send() has a (max_block_ms) timeout and get() has a (max_block_ms) timeout
- # for an upper bound of 2x(max_block_ms) before guaranteed delivery
- future = self._producer.send(self.topic, logentry.to_dict(), timestamp_ms=epoch_ms(logentry.datetime))
- record_metadata = future.get(timeout=self.max_block_ms)
- assert future.succeeded
- except KafkaTimeoutError as kte:
- logger.exception('KafkaLogsProducer timeout sending log to Kafka: %s', kte)
- raise LogSendException('KafkaLogsProducer timeout sending log to Kafka: %s' % kte)
- except KafkaError as ke:
- logger.exception('KafkaLogsProducer error sending log to Kafka: %s', ke)
- raise LogSendException('KafkaLogsProducer error sending log to Kafka: %s' % ke)
- except Exception as e:
- logger.exception('KafkaLogsProducer exception sending log to Kafka: %s', e)
- raise LogSendException('KafkaLogsProducer exception sending log to Kafka: %s' % e)
+ self._producer = KafkaProducer(
+ bootstrap_servers=self.bootstrap_servers,
+ client_id=self.client_id,
+ max_block_ms=self.max_block_ms,
+ value_serializer=logs_json_serializer,
+ )
+
+ def send(self, logentry):
+ try:
+ # send() has a (max_block_ms) timeout and get() has a (max_block_ms) timeout
+ # for an upper bound of 2x(max_block_ms) before guaranteed delivery
+ future = self._producer.send(
+ self.topic, logentry.to_dict(), timestamp_ms=epoch_ms(logentry.datetime)
+ )
+ record_metadata = future.get(timeout=self.max_block_ms)
+ assert future.succeeded
+ except KafkaTimeoutError as kte:
+ logger.exception("KafkaLogsProducer timeout sending log to Kafka: %s", kte)
+ raise LogSendException(
+ "KafkaLogsProducer timeout sending log to Kafka: %s" % kte
+ )
+ except KafkaError as ke:
+ logger.exception("KafkaLogsProducer error sending log to Kafka: %s", ke)
+ raise LogSendException(
+ "KafkaLogsProducer error sending log to Kafka: %s" % ke
+ )
+ except Exception as e:
+ logger.exception("KafkaLogsProducer exception sending log to Kafka: %s", e)
+ raise LogSendException(
+ "KafkaLogsProducer exception sending log to Kafka: %s" % e
+ )
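
To make the 2x(max_block_ms) comment above concrete: with the default of 5 seconds, send() and get() can each block for up to 5000 ms, so the worst case before delivery is confirmed is roughly 10000 ms. A hedged wiring sketch follows; the broker address and topic are placeholders and the constructor needs a reachable Kafka broker.

from kafka.errors import NoBrokersAvailable
from data.logs_model.logs_producer.kafka_logs_producer import KafkaLogsProducer

MAX_BLOCK_SECONDS = 5
print("worst case before confirmed delivery: %d ms" % (2 * MAX_BLOCK_SECONDS * 1000))

try:
    producer = KafkaLogsProducer(
        bootstrap_servers=["localhost:9092"],  # placeholder
        topic="quay-action-logs",              # placeholder
        client_id="quay-example",
        max_block_seconds=MAX_BLOCK_SECONDS,
    )
except NoBrokersAvailable:
    print("no Kafka broker reachable; this sketch only shows the wiring")
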
diff --git a/data/logs_model/logs_producer/kinesis_stream_logs_producer.py b/data/logs_model/logs_producer/kinesis_stream_logs_producer.py
index d4c03f711..1fb60f798 100644
--- a/data/logs_model/logs_producer/kinesis_stream_logs_producer.py
+++ b/data/logs_model/logs_producer/kinesis_stream_logs_producer.py
@@ -13,7 +13,7 @@ from data.logs_model.logs_producer import LogSendException
logger = logging.getLogger(__name__)
-KINESIS_PARTITION_KEY_PREFIX = 'logentry_partition_key_'
+KINESIS_PARTITION_KEY_PREFIX = "logentry_partition_key_"
DEFAULT_CONNECT_TIMEOUT = 5
DEFAULT_READ_TIMEOUT = 5
MAX_RETRY_ATTEMPTS = 5
@@ -21,55 +21,79 @@ DEFAULT_MAX_POOL_CONNECTIONS = 10
def _partition_key(number_of_shards=None):
- """ Generate a partition key for AWS Kinesis stream.
+ """ Generate a partition key for AWS Kinesis stream.
If the number of shards is specified, generate keys where the size of the key space is
the number of shards.
"""
- key = None
- if number_of_shards is not None:
- shard_number = random.randrange(0, number_of_shards)
- key = hashlib.sha1(KINESIS_PARTITION_KEY_PREFIX + str(shard_number)).hexdigest()
- else:
- key = hashlib.sha1(KINESIS_PARTITION_KEY_PREFIX + str(random.getrandbits(256))).hexdigest()
+ key = None
+ if number_of_shards is not None:
+ shard_number = random.randrange(0, number_of_shards)
+ key = hashlib.sha1(KINESIS_PARTITION_KEY_PREFIX + str(shard_number)).hexdigest()
+ else:
+ key = hashlib.sha1(
+ KINESIS_PARTITION_KEY_PREFIX + str(random.getrandbits(256))
+ ).hexdigest()
- return key
+ return key
class KinesisStreamLogsProducer(LogProducerInterface):
- """ Log producer writing log entries to an Amazon Kinesis Data Stream. """
- def __init__(self, stream_name, aws_region, aws_access_key=None, aws_secret_key=None,
- connect_timeout=None, read_timeout=None, max_retries=None,
- max_pool_connections=None):
- self._stream_name = stream_name
- self._aws_region = aws_region
- self._aws_access_key = aws_access_key
- self._aws_secret_key = aws_secret_key
- self._connect_timeout = connect_timeout or DEFAULT_CONNECT_TIMEOUT
- self._read_timeout = read_timeout or DEFAULT_READ_TIMEOUT
- self._max_retries = max_retries or MAX_RETRY_ATTEMPTS
- self._max_pool_connections=max_pool_connections or DEFAULT_MAX_POOL_CONNECTIONS
+ """ Log producer writing log entries to an Amazon Kinesis Data Stream. """
- client_config = Config(connect_timeout=self._connect_timeout,
- read_timeout=self._read_timeout ,
- retries={'max_attempts': self._max_retries},
- max_pool_connections=self._max_pool_connections)
- self._producer = boto3.client('kinesis', use_ssl=True,
- region_name=self._aws_region,
- aws_access_key_id=self._aws_access_key,
- aws_secret_access_key=self._aws_secret_key,
- config=client_config)
+ def __init__(
+ self,
+ stream_name,
+ aws_region,
+ aws_access_key=None,
+ aws_secret_key=None,
+ connect_timeout=None,
+ read_timeout=None,
+ max_retries=None,
+ max_pool_connections=None,
+ ):
+ self._stream_name = stream_name
+ self._aws_region = aws_region
+ self._aws_access_key = aws_access_key
+ self._aws_secret_key = aws_secret_key
+ self._connect_timeout = connect_timeout or DEFAULT_CONNECT_TIMEOUT
+ self._read_timeout = read_timeout or DEFAULT_READ_TIMEOUT
+ self._max_retries = max_retries or MAX_RETRY_ATTEMPTS
+ self._max_pool_connections = (
+ max_pool_connections or DEFAULT_MAX_POOL_CONNECTIONS
+ )
- def send(self, logentry):
- try:
- data = logs_json_serializer(logentry)
- self._producer.put_record(
- StreamName=self._stream_name,
- Data=data,
- PartitionKey=_partition_key()
- )
- except ClientError as ce:
- logger.exception('KinesisStreamLogsProducer client error sending log to Kinesis: %s', ce)
- raise LogSendException('KinesisStreamLogsProducer client error sending log to Kinesis: %s' % ce)
- except Exception as e:
- logger.exception('KinesisStreamLogsProducer exception sending log to Kinesis: %s', e)
- raise LogSendException('KinesisStreamLogsProducer exception sending log to Kinesis: %s' % e)
+ client_config = Config(
+ connect_timeout=self._connect_timeout,
+ read_timeout=self._read_timeout,
+ retries={"max_attempts": self._max_retries},
+ max_pool_connections=self._max_pool_connections,
+ )
+ self._producer = boto3.client(
+ "kinesis",
+ use_ssl=True,
+ region_name=self._aws_region,
+ aws_access_key_id=self._aws_access_key,
+ aws_secret_access_key=self._aws_secret_key,
+ config=client_config,
+ )
+
+ def send(self, logentry):
+ try:
+ data = logs_json_serializer(logentry)
+ self._producer.put_record(
+ StreamName=self._stream_name, Data=data, PartitionKey=_partition_key()
+ )
+ except ClientError as ce:
+ logger.exception(
+ "KinesisStreamLogsProducer client error sending log to Kinesis: %s", ce
+ )
+ raise LogSendException(
+ "KinesisStreamLogsProducer client error sending log to Kinesis: %s" % ce
+ )
+ except Exception as e:
+ logger.exception(
+ "KinesisStreamLogsProducer exception sending log to Kinesis: %s", e
+ )
+ raise LogSendException(
+ "KinesisStreamLogsProducer exception sending log to Kinesis: %s" % e
+ )
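
The key-space behaviour described in _partition_key's docstring, re-implemented as a standalone Python 3 sketch (the code above hashes a str directly, which is Python 2 style, so the bytes encoding below is an adaptation, not a copy):

import hashlib
import random

PREFIX = "logentry_partition_key_"

def sketch_partition_key(number_of_shards=None):
    if number_of_shards is not None:
        # Bounded key space: at most `number_of_shards` distinct keys.
        shard_number = random.randrange(0, number_of_shards)
        return hashlib.sha1((PREFIX + str(shard_number)).encode("utf-8")).hexdigest()
    # Unbounded key space: effectively one key per record.
    return hashlib.sha1((PREFIX + str(random.getrandbits(256))).encode("utf-8")).hexdigest()

print(len({sketch_partition_key(number_of_shards=3) for _ in range(100)}))  # at most 3 distinct keys
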
diff --git a/data/logs_model/logs_producer/test/test_json_logs_serializer.py b/data/logs_model/logs_producer/test/test_json_logs_serializer.py
index a45b0c5bb..4c2c18333 100644
--- a/data/logs_model/logs_producer/test/test_json_logs_serializer.py
+++ b/data/logs_model/logs_producer/test/test_json_logs_serializer.py
@@ -17,29 +17,50 @@ TEST_DATETIME = datetime.utcnow()
TEST_JSON_STRING = '{"a": "b", "c": "d"}'
TEST_JSON_STRING_WITH_UNICODE = u'{"éëê": "îôû"}'
-VALID_LOGENTRY = LogEntry(random_id='123-45', ip='0.0.0.0', metadata_json=TEST_JSON_STRING, datetime=TEST_DATETIME)
-VALID_LOGENTRY_WITH_UNICODE = LogEntry(random_id='123-45', ip='0.0.0.0', metadata_json=TEST_JSON_STRING_WITH_UNICODE, datetime=TEST_DATETIME)
+VALID_LOGENTRY = LogEntry(
+ random_id="123-45",
+ ip="0.0.0.0",
+ metadata_json=TEST_JSON_STRING,
+ datetime=TEST_DATETIME,
+)
+VALID_LOGENTRY_WITH_UNICODE = LogEntry(
+ random_id="123-45",
+ ip="0.0.0.0",
+ metadata_json=TEST_JSON_STRING_WITH_UNICODE,
+ datetime=TEST_DATETIME,
+)
-VALID_LOGENTRY_EXPECTED_OUTPUT = '{"datetime": "%s", "ip": "0.0.0.0", "metadata_json": "{\\"a\\": \\"b\\", \\"c\\": \\"d\\"}", "random_id": "123-45"}' % TEST_DATETIME.isoformat()
-VALID_LOGENTRY_WITH_UNICODE_EXPECTED_OUTPUT = '{"datetime": "%s", "ip": "0.0.0.0", "metadata_json": "{\\"\\u00e9\\u00eb\\u00ea\\": \\"\\u00ee\\u00f4\\u00fb\\"}", "random_id": "123-45"}' % TEST_DATETIME.isoformat()
+VALID_LOGENTRY_EXPECTED_OUTPUT = (
+ '{"datetime": "%s", "ip": "0.0.0.0", "metadata_json": "{\\"a\\": \\"b\\", \\"c\\": \\"d\\"}", "random_id": "123-45"}'
+ % TEST_DATETIME.isoformat()
+)
+VALID_LOGENTRY_WITH_UNICODE_EXPECTED_OUTPUT = (
+ '{"datetime": "%s", "ip": "0.0.0.0", "metadata_json": "{\\"\\u00e9\\u00eb\\u00ea\\": \\"\\u00ee\\u00f4\\u00fb\\"}", "random_id": "123-45"}'
+ % TEST_DATETIME.isoformat()
+)
@pytest.mark.parametrize(
- 'is_valid, given_input, expected_output',
- [
- # Valid inputs
- pytest.param(True, VALID_LOGENTRY, VALID_LOGENTRY_EXPECTED_OUTPUT),
- # With unicode
- pytest.param(True, VALID_LOGENTRY_WITH_UNICODE, VALID_LOGENTRY_WITH_UNICODE_EXPECTED_OUTPUT),
- ])
+ "is_valid, given_input, expected_output",
+ [
+ # Valid inputs
+ pytest.param(True, VALID_LOGENTRY, VALID_LOGENTRY_EXPECTED_OUTPUT),
+ # With unicode
+ pytest.param(
+ True,
+ VALID_LOGENTRY_WITH_UNICODE,
+ VALID_LOGENTRY_WITH_UNICODE_EXPECTED_OUTPUT,
+ ),
+ ],
+)
def test_logs_json_serializer(is_valid, given_input, expected_output):
- if not is_valid:
- with pytest.raises(ValueError) as ve:
- data = logs_json_serializer(given_input)
- else:
- data = logs_json_serializer(given_input, sort_keys=True)
- assert data == expected_output
+ if not is_valid:
+ with pytest.raises(ValueError) as ve:
+ data = logs_json_serializer(given_input)
+ else:
+ data = logs_json_serializer(given_input, sort_keys=True)
+ assert data == expected_output
- # Make sure the datetime was serialized in the correct ISO8601
- datetime_str = json.loads(data)['datetime']
- assert datetime_str == TEST_DATETIME.isoformat()
+ # Make sure the datetime was serialized in the correct ISO8601
+ datetime_str = json.loads(data)["datetime"]
+ assert datetime_str == TEST_DATETIME.isoformat()
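
The sort_keys=True argument in the valid cases is what makes the byte-for-byte comparison against the expected strings stable; without it, key order follows dict insertion order on Python 3.7+. A trivial illustration:

import json

payload = {"c": "d", "a": "b"}
print(json.dumps(payload, sort_keys=True))  # {"a": "b", "c": "d"}
print(json.dumps(payload))                  # {"c": "d", "a": "b"} (insertion order)
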
diff --git a/data/logs_model/logs_producer/util.py b/data/logs_model/logs_producer/util.py
index d6c3e2d8d..fb8fa0038 100644
--- a/data/logs_model/logs_producer/util.py
+++ b/data/logs_model/logs_producer/util.py
@@ -1,15 +1,22 @@
import json
from datetime import datetime
-class LogEntryJSONEncoder(json.JSONEncoder):
- """ JSON encoder to encode datetimes to ISO8601 format. """
- def default(self, obj):
- if isinstance(obj, datetime):
- return obj.isoformat()
- return super(LogEntryJSONEncoder, self).default(obj)
+class LogEntryJSONEncoder(json.JSONEncoder):
+ """ JSON encoder to encode datetimes to ISO8601 format. """
+
+ def default(self, obj):
+ if isinstance(obj, datetime):
+ return obj.isoformat()
+
+ return super(LogEntryJSONEncoder, self).default(obj)
+
def logs_json_serializer(logentry, sort_keys=False):
- """ Serializes a LogEntry to json bytes. """
- return json.dumps(logentry.to_dict(), cls=LogEntryJSONEncoder,
- ensure_ascii=True, sort_keys=sort_keys).encode('ascii')
+ """ Serializes a LogEntry to json bytes. """
+ return json.dumps(
+ logentry.to_dict(),
+ cls=LogEntryJSONEncoder,
+ ensure_ascii=True,
+ sort_keys=sort_keys,
+ ).encode("ascii")
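
A minimal, hypothetical round trip through logs_json_serializer: FakeLogEntry below stands in for the real LogEntry document and only needs a to_dict() method. Datetimes are rendered as ISO8601 by LogEntryJSONEncoder and the result is ASCII bytes.

from datetime import datetime
from data.logs_model.logs_producer.util import logs_json_serializer

class FakeLogEntry(object):
    def to_dict(self):
        return {"kind": "push_repo", "datetime": datetime(2019, 1, 1, 3, 30)}

print(logs_json_serializer(FakeLogEntry(), sort_keys=True))
# b'{"datetime": "2019-01-01T03:30:00", "kind": "push_repo"}'
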
diff --git a/data/logs_model/shared.py b/data/logs_model/shared.py
index 550cac95e..210b57dc9 100644
--- a/data/logs_model/shared.py
+++ b/data/logs_model/shared.py
@@ -7,47 +7,62 @@ from data import model
class SharedModel:
- def queue_logs_export(self, start_datetime, end_datetime, export_action_logs_queue,
- namespace_name=None, repository_name=None, callback_url=None,
- callback_email=None, filter_kinds=None):
- """ Queues logs between the start_datetime and end_time, filtered by a repository or namespace,
+ def queue_logs_export(
+ self,
+ start_datetime,
+ end_datetime,
+ export_action_logs_queue,
+ namespace_name=None,
+ repository_name=None,
+ callback_url=None,
+ callback_email=None,
+ filter_kinds=None,
+ ):
+ """ Queues logs between the start_datetime and end_time, filtered by a repository or namespace,
for export to the specified URL and/or email address. Returns the ID of the export job
queued or None if error.
"""
- export_id = str(uuid.uuid4())
- namespace = model.user.get_namespace_user(namespace_name)
- if namespace is None:
- return None
+ export_id = str(uuid.uuid4())
+ namespace = model.user.get_namespace_user(namespace_name)
+ if namespace is None:
+ return None
- repository = None
- if repository_name is not None:
- repository = model.repository.get_repository(namespace_name, repository_name)
- if repository is None:
- return None
+ repository = None
+ if repository_name is not None:
+ repository = model.repository.get_repository(
+ namespace_name, repository_name
+ )
+ if repository is None:
+ return None
- export_action_logs_queue.put([namespace_name],
- json.dumps({
- 'export_id': export_id,
- 'repository_id': repository.id if repository else None,
- 'namespace_id': namespace.id,
- 'namespace_name': namespace.username,
- 'repository_name': repository.name if repository else None,
- 'start_time': start_datetime.strftime('%m/%d/%Y'),
- 'end_time': end_datetime.strftime('%m/%d/%Y'),
- 'callback_url': callback_url,
- 'callback_email': callback_email,
- }), retries_remaining=3)
+ export_action_logs_queue.put(
+ [namespace_name],
+ json.dumps(
+ {
+ "export_id": export_id,
+ "repository_id": repository.id if repository else None,
+ "namespace_id": namespace.id,
+ "namespace_name": namespace.username,
+ "repository_name": repository.name if repository else None,
+ "start_time": start_datetime.strftime("%m/%d/%Y"),
+ "end_time": end_datetime.strftime("%m/%d/%Y"),
+ "callback_url": callback_url,
+ "callback_email": callback_email,
+ }
+ ),
+ retries_remaining=3,
+ )
- return export_id
+ return export_id
def epoch_ms(dt):
- return (timegm(dt.timetuple()) * 1000) + (dt.microsecond / 1000)
+ return (timegm(dt.timetuple()) * 1000) + (dt.microsecond / 1000)
def get_kinds_filter(kinds):
- """ Given a list of kinds, return the set of kinds not that are not part of that list.
+ """ Given a list of kinds, return the set of kinds not that are not part of that list.
i.e Returns the list of kinds to be filtered out. """
- kind_map = model.log.get_log_entry_kinds()
- kind_map = {key: kind_map[key] for key in kind_map if not isinstance(key, int)}
- return [kind_name for kind_name in kind_map if kind_name not in kinds]
+ kind_map = model.log.get_log_entry_kinds()
+ kind_map = {key: kind_map[key] for key in kind_map if not isinstance(key, int)}
+ return [kind_name for kind_name in kind_map if kind_name not in kinds]
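
epoch_ms is the helper behind the millisecond timestamps used elsewhere in this patch (for example the 1520479800000 sort keys in the Elasticsearch mock fixtures). A standalone check that mirrors the helper above:

from calendar import timegm
from datetime import datetime

def epoch_ms(dt):
    # UTC epoch milliseconds, including the sub-second contribution.
    return (timegm(dt.timetuple()) * 1000) + (dt.microsecond / 1000)

print(epoch_ms(datetime(2018, 3, 8, 3, 30)))  # 1520479800000.0
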
diff --git a/data/logs_model/table_logs_model.py b/data/logs_model/table_logs_model.py
index 697bf2dc6..6cfcdda17 100644
--- a/data/logs_model/table_logs_model.py
+++ b/data/logs_model/table_logs_model.py
@@ -10,16 +10,19 @@ from dateutil.relativedelta import relativedelta
from data import model
from data.model import config
from data.database import LogEntry, LogEntry2, LogEntry3, UseThenDisconnect
-from data.logs_model.interface import ActionLogsDataInterface, LogsIterationTimeout, \
- LogRotationContextInterface
+from data.logs_model.interface import (
+ ActionLogsDataInterface,
+ LogsIterationTimeout,
+ LogRotationContextInterface,
+)
from data.logs_model.datatypes import Log, AggregatedLogCount, LogEntriesPage
from data.logs_model.shared import SharedModel
from data.model.log import get_stale_logs, get_stale_logs_start_id, delete_stale_logs
logger = logging.getLogger(__name__)
-MINIMUM_RANGE_SIZE = 1 # second
-MAXIMUM_RANGE_SIZE = 60 * 60 * 24 * 30 # seconds ~= 1 month
+MINIMUM_RANGE_SIZE = 1 # second
+MAXIMUM_RANGE_SIZE = 60 * 60 * 24 * 30 # seconds ~= 1 month
EXPECTED_ITERATION_LOG_COUNT = 1000
@@ -27,265 +30,379 @@ LOG_MODELS = [LogEntry3, LogEntry2, LogEntry]
class TableLogsModel(SharedModel, ActionLogsDataInterface):
- """
+ """
TableLogsModel implements the data model for the logs API backed by a single table
in the database.
"""
- def __init__(self, should_skip_logging=None, **kwargs):
- self._should_skip_logging = should_skip_logging
- def lookup_logs(self, start_datetime, end_datetime, performer_name=None, repository_name=None,
- namespace_name=None, filter_kinds=None, page_token=None, max_page_count=None):
- if filter_kinds is not None:
- assert all(isinstance(kind_name, str) for kind_name in filter_kinds)
+ def __init__(self, should_skip_logging=None, **kwargs):
+ self._should_skip_logging = should_skip_logging
- assert start_datetime is not None
- assert end_datetime is not None
+ def lookup_logs(
+ self,
+ start_datetime,
+ end_datetime,
+ performer_name=None,
+ repository_name=None,
+ namespace_name=None,
+ filter_kinds=None,
+ page_token=None,
+ max_page_count=None,
+ ):
+ if filter_kinds is not None:
+ assert all(isinstance(kind_name, str) for kind_name in filter_kinds)
- repository = None
- if repository_name and namespace_name:
- repository = model.repository.get_repository(namespace_name, repository_name)
- assert repository
+ assert start_datetime is not None
+ assert end_datetime is not None
- performer = None
- if performer_name:
- performer = model.user.get_user(performer_name)
- assert performer
+ repository = None
+ if repository_name and namespace_name:
+ repository = model.repository.get_repository(
+ namespace_name, repository_name
+ )
+ assert repository
- def get_logs(m, page_token):
- logs_query = model.log.get_logs_query(start_datetime, end_datetime, performer=performer,
- repository=repository, namespace=namespace_name,
- ignore=filter_kinds, model=m)
+ performer = None
+ if performer_name:
+ performer = model.user.get_user(performer_name)
+ assert performer
- logs, next_page_token = model.modelutil.paginate(logs_query, m,
- descending=True,
- page_token=page_token,
- limit=20,
- max_page=max_page_count,
- sort_field_name='datetime')
+ def get_logs(m, page_token):
+ logs_query = model.log.get_logs_query(
+ start_datetime,
+ end_datetime,
+ performer=performer,
+ repository=repository,
+ namespace=namespace_name,
+ ignore=filter_kinds,
+ model=m,
+ )
- return logs, next_page_token
+ logs, next_page_token = model.modelutil.paginate(
+ logs_query,
+ m,
+ descending=True,
+ page_token=page_token,
+ limit=20,
+ max_page=max_page_count,
+ sort_field_name="datetime",
+ )
- TOKEN_TABLE_ID = 'tti'
- table_index = 0
- logs = []
- next_page_token = page_token or None
+ return logs, next_page_token
- # Skip empty pages (empty table)
- while len(logs) == 0 and table_index < len(LOG_MODELS) - 1:
- table_specified = next_page_token is not None and next_page_token.get(TOKEN_TABLE_ID) is not None
- if table_specified:
- table_index = next_page_token.get(TOKEN_TABLE_ID)
+ TOKEN_TABLE_ID = "tti"
+ table_index = 0
+ logs = []
+ next_page_token = page_token or None
- logs_result, next_page_token = get_logs(LOG_MODELS[table_index], next_page_token)
- logs.extend(logs_result)
+ # Skip empty pages (empty table)
+ while len(logs) == 0 and table_index < len(LOG_MODELS) - 1:
+ table_specified = (
+ next_page_token is not None
+ and next_page_token.get(TOKEN_TABLE_ID) is not None
+ )
+ if table_specified:
+ table_index = next_page_token.get(TOKEN_TABLE_ID)
- if next_page_token is None and table_index < len(LOG_MODELS) - 1:
- next_page_token = {TOKEN_TABLE_ID: table_index + 1}
+ logs_result, next_page_token = get_logs(
+ LOG_MODELS[table_index], next_page_token
+ )
+ logs.extend(logs_result)
- return LogEntriesPage([Log.for_logentry(log) for log in logs], next_page_token)
+ if next_page_token is None and table_index < len(LOG_MODELS) - 1:
+ next_page_token = {TOKEN_TABLE_ID: table_index + 1}
- def lookup_latest_logs(self, performer_name=None, repository_name=None, namespace_name=None,
- filter_kinds=None, size=20):
- if filter_kinds is not None:
- assert all(isinstance(kind_name, str) for kind_name in filter_kinds)
+ return LogEntriesPage([Log.for_logentry(log) for log in logs], next_page_token)
- repository = None
- if repository_name and namespace_name:
- repository = model.repository.get_repository(namespace_name, repository_name)
- assert repository
+ def lookup_latest_logs(
+ self,
+ performer_name=None,
+ repository_name=None,
+ namespace_name=None,
+ filter_kinds=None,
+ size=20,
+ ):
+ if filter_kinds is not None:
+ assert all(isinstance(kind_name, str) for kind_name in filter_kinds)
- performer = None
- if performer_name:
- performer = model.user.get_user(performer_name)
- assert performer
+ repository = None
+ if repository_name and namespace_name:
+ repository = model.repository.get_repository(
+ namespace_name, repository_name
+ )
+ assert repository
- def get_latest_logs(m):
- logs_query = model.log.get_latest_logs_query(performer=performer, repository=repository,
- namespace=namespace_name, ignore=filter_kinds,
- model=m, size=size)
+ performer = None
+ if performer_name:
+ performer = model.user.get_user(performer_name)
+ assert performer
- logs = list(logs_query)
- return [Log.for_logentry(log) for log in logs]
+ def get_latest_logs(m):
+ logs_query = model.log.get_latest_logs_query(
+ performer=performer,
+ repository=repository,
+ namespace=namespace_name,
+ ignore=filter_kinds,
+ model=m,
+ size=size,
+ )
- return get_latest_logs(LOG_MODELS[0])
+ logs = list(logs_query)
+ return [Log.for_logentry(log) for log in logs]
- def get_aggregated_log_counts(self, start_datetime, end_datetime, performer_name=None,
- repository_name=None, namespace_name=None, filter_kinds=None):
- if filter_kinds is not None:
- assert all(isinstance(kind_name, str) for kind_name in filter_kinds)
+ return get_latest_logs(LOG_MODELS[0])
- if end_datetime - start_datetime >= timedelta(weeks=4):
- raise Exception('Cannot lookup aggregated logs over a period longer than a month')
+ def get_aggregated_log_counts(
+ self,
+ start_datetime,
+ end_datetime,
+ performer_name=None,
+ repository_name=None,
+ namespace_name=None,
+ filter_kinds=None,
+ ):
+ if filter_kinds is not None:
+ assert all(isinstance(kind_name, str) for kind_name in filter_kinds)
- repository = None
- if repository_name and namespace_name:
- repository = model.repository.get_repository(namespace_name, repository_name)
+ if end_datetime - start_datetime >= timedelta(weeks=4):
+ raise Exception(
+ "Cannot lookup aggregated logs over a period longer than a month"
+ )
- performer = None
- if performer_name:
- performer = model.user.get_user(performer_name)
+ repository = None
+ if repository_name and namespace_name:
+ repository = model.repository.get_repository(
+ namespace_name, repository_name
+ )
- entries = {}
- for log_model in LOG_MODELS:
- aggregated = model.log.get_aggregated_logs(start_datetime, end_datetime,
- performer=performer,
- repository=repository,
- namespace=namespace_name,
- ignore=filter_kinds,
- model=log_model)
+ performer = None
+ if performer_name:
+ performer = model.user.get_user(performer_name)
- for entry in aggregated:
- synthetic_date = datetime(start_datetime.year, start_datetime.month, int(entry.day),
- tzinfo=get_localzone())
- if synthetic_date.day < start_datetime.day:
- synthetic_date = synthetic_date + relativedelta(months=1)
+ entries = {}
+ for log_model in LOG_MODELS:
+ aggregated = model.log.get_aggregated_logs(
+ start_datetime,
+ end_datetime,
+ performer=performer,
+ repository=repository,
+ namespace=namespace_name,
+ ignore=filter_kinds,
+ model=log_model,
+ )
- key = '%s-%s' % (entry.kind_id, entry.day)
+ for entry in aggregated:
+ synthetic_date = datetime(
+ start_datetime.year,
+ start_datetime.month,
+ int(entry.day),
+ tzinfo=get_localzone(),
+ )
+ if synthetic_date.day < start_datetime.day:
+ synthetic_date = synthetic_date + relativedelta(months=1)
- if key in entries:
- entries[key] = AggregatedLogCount(entry.kind_id, entry.count + entries[key].count,
- synthetic_date)
- else:
- entries[key] = AggregatedLogCount(entry.kind_id, entry.count, synthetic_date)
+ key = "%s-%s" % (entry.kind_id, entry.day)
- return entries.values()
+ if key in entries:
+ entries[key] = AggregatedLogCount(
+ entry.kind_id, entry.count + entries[key].count, synthetic_date
+ )
+ else:
+ entries[key] = AggregatedLogCount(
+ entry.kind_id, entry.count, synthetic_date
+ )
- def count_repository_actions(self, repository, day):
- return model.repositoryactioncount.count_repository_actions(repository, day)
+ return entries.values()
- def log_action(self, kind_name, namespace_name=None, performer=None, ip=None, metadata=None,
- repository=None, repository_name=None, timestamp=None, is_free_namespace=False):
- if self._should_skip_logging and self._should_skip_logging(kind_name, namespace_name,
- is_free_namespace):
- return
+ def count_repository_actions(self, repository, day):
+ return model.repositoryactioncount.count_repository_actions(repository, day)
- if repository_name is not None:
- assert repository is None
- assert namespace_name is not None
- repository = model.repository.get_repository(namespace_name, repository_name)
+ def log_action(
+ self,
+ kind_name,
+ namespace_name=None,
+ performer=None,
+ ip=None,
+ metadata=None,
+ repository=None,
+ repository_name=None,
+ timestamp=None,
+ is_free_namespace=False,
+ ):
+ if self._should_skip_logging and self._should_skip_logging(
+ kind_name, namespace_name, is_free_namespace
+ ):
+ return
- model.log.log_action(kind_name, namespace_name, performer=performer, repository=repository,
- ip=ip, metadata=metadata or {}, timestamp=timestamp)
+ if repository_name is not None:
+ assert repository is None
+ assert namespace_name is not None
+ repository = model.repository.get_repository(
+ namespace_name, repository_name
+ )
- def yield_logs_for_export(self, start_datetime, end_datetime, repository_id=None,
- namespace_id=None, max_query_time=None):
- # Using an adjusting scale, start downloading log rows in batches, starting at
- # MINIMUM_RANGE_SIZE and doubling until we've reached EXPECTED_ITERATION_LOG_COUNT or
- # the lookup range has reached MAXIMUM_RANGE_SIZE. If at any point this operation takes
- # longer than the MAXIMUM_WORK_PERIOD_SECONDS, terminate the batch operation as timed out.
- batch_start_time = datetime.utcnow()
+ model.log.log_action(
+ kind_name,
+ namespace_name,
+ performer=performer,
+ repository=repository,
+ ip=ip,
+ metadata=metadata or {},
+ timestamp=timestamp,
+ )
- current_start_datetime = start_datetime
- current_batch_size = timedelta(seconds=MINIMUM_RANGE_SIZE)
+ def yield_logs_for_export(
+ self,
+ start_datetime,
+ end_datetime,
+ repository_id=None,
+ namespace_id=None,
+ max_query_time=None,
+ ):
+ # Using an adjusting scale, start downloading log rows in batches, starting at
+ # MINIMUM_RANGE_SIZE and doubling until we've reached EXPECTED_ITERATION_LOG_COUNT or
+ # the lookup range has reached MAXIMUM_RANGE_SIZE. If at any point this operation takes
+ # longer than the MAXIMUM_WORK_PERIOD_SECONDS, terminate the batch operation as timed out.
+ batch_start_time = datetime.utcnow()
- while current_start_datetime < end_datetime:
- # Verify we haven't been working for too long.
- work_elapsed = datetime.utcnow() - batch_start_time
- if max_query_time is not None and work_elapsed > max_query_time:
- logger.error('Retrieval of logs `%s/%s` timed out with time of `%s`',
- namespace_id, repository_id, work_elapsed)
- raise LogsIterationTimeout()
+ current_start_datetime = start_datetime
+ current_batch_size = timedelta(seconds=MINIMUM_RANGE_SIZE)
- current_end_datetime = current_start_datetime + current_batch_size
- current_end_datetime = min(current_end_datetime, end_datetime)
+ while current_start_datetime < end_datetime:
+ # Verify we haven't been working for too long.
+ work_elapsed = datetime.utcnow() - batch_start_time
+ if max_query_time is not None and work_elapsed > max_query_time:
+ logger.error(
+ "Retrieval of logs `%s/%s` timed out with time of `%s`",
+ namespace_id,
+ repository_id,
+ work_elapsed,
+ )
+ raise LogsIterationTimeout()
- # Load the next set of logs.
- def load_logs():
- logger.debug('Retrieving logs over range %s -> %s with namespace %s and repository %s',
- current_start_datetime, current_end_datetime, namespace_id, repository_id)
+ current_end_datetime = current_start_datetime + current_batch_size
+ current_end_datetime = min(current_end_datetime, end_datetime)
- logs_query = model.log.get_logs_query(namespace=namespace_id,
- repository=repository_id,
- start_time=current_start_datetime,
- end_time=current_end_datetime)
- logs = list(logs_query)
- for log in logs:
- if namespace_id is not None:
- assert log.account_id == namespace_id
+ # Load the next set of logs.
+ def load_logs():
+ logger.debug(
+ "Retrieving logs over range %s -> %s with namespace %s and repository %s",
+ current_start_datetime,
+ current_end_datetime,
+ namespace_id,
+ repository_id,
+ )
- if repository_id is not None:
- assert log.repository_id == repository_id
+ logs_query = model.log.get_logs_query(
+ namespace=namespace_id,
+ repository=repository_id,
+ start_time=current_start_datetime,
+ end_time=current_end_datetime,
+ )
+ logs = list(logs_query)
+ for log in logs:
+ if namespace_id is not None:
+ assert log.account_id == namespace_id
- logs = [Log.for_logentry(log) for log in logs]
- return logs
+ if repository_id is not None:
+ assert log.repository_id == repository_id
- logs, elapsed = _run_and_time(load_logs)
- if max_query_time is not None and elapsed > max_query_time:
- logger.error('Retrieval of logs for export `%s/%s` with range `%s-%s` timed out at `%s`',
- namespace_id, repository_id, current_start_datetime, current_end_datetime,
- elapsed)
- raise LogsIterationTimeout()
+ logs = [Log.for_logentry(log) for log in logs]
+ return logs
- yield logs
+ logs, elapsed = _run_and_time(load_logs)
+ if max_query_time is not None and elapsed > max_query_time:
+ logger.error(
+ "Retrieval of logs for export `%s/%s` with range `%s-%s` timed out at `%s`",
+ namespace_id,
+ repository_id,
+ current_start_datetime,
+ current_end_datetime,
+ elapsed,
+ )
+ raise LogsIterationTimeout()
- # Move forward.
- current_start_datetime = current_end_datetime
+ yield logs
- # Increase the batch size if necessary.
- if len(logs) < EXPECTED_ITERATION_LOG_COUNT:
- seconds = min(MAXIMUM_RANGE_SIZE, current_batch_size.total_seconds() * 2)
- current_batch_size = timedelta(seconds=seconds)
+ # Move forward.
+ current_start_datetime = current_end_datetime
- def yield_log_rotation_context(self, cutoff_date, min_logs_per_rotation):
- """ Yield a context manager for a group of outdated logs. """
- for log_model in LOG_MODELS:
- while True:
- with UseThenDisconnect(config.app_config):
- start_id = get_stale_logs_start_id(log_model)
+ # Increase the batch size if necessary.
+ if len(logs) < EXPECTED_ITERATION_LOG_COUNT:
+ seconds = min(
+ MAXIMUM_RANGE_SIZE, current_batch_size.total_seconds() * 2
+ )
+ current_batch_size = timedelta(seconds=seconds)
- if start_id is None:
- logger.warning('Failed to find start id')
- break
+ def yield_log_rotation_context(self, cutoff_date, min_logs_per_rotation):
+ """ Yield a context manager for a group of outdated logs. """
+ for log_model in LOG_MODELS:
+ while True:
+ with UseThenDisconnect(config.app_config):
+ start_id = get_stale_logs_start_id(log_model)
- logger.debug('Found starting ID %s', start_id)
- lookup_end_id = start_id + min_logs_per_rotation
- logs = [log for log in get_stale_logs(start_id, lookup_end_id,
- log_model, cutoff_date)]
+ if start_id is None:
+ logger.warning("Failed to find start id")
+ break
- if not logs:
- logger.debug('No further logs found')
- break
+ logger.debug("Found starting ID %s", start_id)
+ lookup_end_id = start_id + min_logs_per_rotation
+ logs = [
+ log
+ for log in get_stale_logs(
+ start_id, lookup_end_id, log_model, cutoff_date
+ )
+ ]
- end_id = max([log.id for log in logs])
- context = DatabaseLogRotationContext(logs, log_model, start_id, end_id)
- yield context
+ if not logs:
+ logger.debug("No further logs found")
+ break
+
+ end_id = max([log.id for log in logs])
+ context = DatabaseLogRotationContext(logs, log_model, start_id, end_id)
+ yield context
def _run_and_time(fn):
- start_time = datetime.utcnow()
- result = fn()
- return result, datetime.utcnow() - start_time
+ start_time = datetime.utcnow()
+ result = fn()
+ return result, datetime.utcnow() - start_time
table_logs_model = TableLogsModel()
class DatabaseLogRotationContext(LogRotationContextInterface):
- """
+ """
DatabaseLogRotationContext represents a batch of logs to be archived together.
i.e A set of logs to be archived in the same file (based on the number of logs per rotation).
When completed without exceptions, this context will delete the stale logs
from rows `start_id` to `end_id`.
"""
- def __init__(self, logs, log_model, start_id, end_id):
- self.logs = logs
- self.log_model = log_model
- self.start_id = start_id
- self.end_id = end_id
- def __enter__(self):
- return self
+ def __init__(self, logs, log_model, start_id, end_id):
+ self.logs = logs
+ self.log_model = log_model
+ self.start_id = start_id
+ self.end_id = end_id
- def __exit__(self, ex_type, ex_value, ex_traceback):
- if ex_type is None and ex_value is None and ex_traceback is None:
- with UseThenDisconnect(config.app_config):
- logger.debug('Deleting logs from IDs %s to %s', self.start_id, self.end_id)
- delete_stale_logs(self.start_id, self.end_id, self.log_model)
+ def __enter__(self):
+ return self
- def yield_logs_batch(self):
- """ Yield a batch of logs and a filename for that batch. """
- filename = '%d-%d-%s.txt.gz' % (self.start_id, self.end_id,
- self.log_model.__name__.lower())
- yield self.logs, filename
+ def __exit__(self, ex_type, ex_value, ex_traceback):
+ if ex_type is None and ex_value is None and ex_traceback is None:
+ with UseThenDisconnect(config.app_config):
+ logger.debug(
+ "Deleting logs from IDs %s to %s", self.start_id, self.end_id
+ )
+ delete_stale_logs(self.start_id, self.end_id, self.log_model)
+
+ def yield_logs_batch(self):
+ """ Yield a batch of logs and a filename for that batch. """
+ filename = "%d-%d-%s.txt.gz" % (
+ self.start_id,
+ self.end_id,
+ self.log_model.__name__.lower(),
+ )
+ yield self.logs, filename
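
To make the adaptive range sizing in yield_logs_for_export easier to follow: the lookup window starts at MINIMUM_RANGE_SIZE seconds and doubles (capped at MAXIMUM_RANGE_SIZE) whenever a batch comes back with fewer than EXPECTED_ITERATION_LOG_COUNT rows. A purely illustrative dry run with made-up batch counts:

from datetime import timedelta

MINIMUM_RANGE_SIZE = 1
MAXIMUM_RANGE_SIZE = 60 * 60 * 24 * 30
EXPECTED_ITERATION_LOG_COUNT = 1000

batch = timedelta(seconds=MINIMUM_RANGE_SIZE)
for logs_found in [10, 50, 200, 5000, 100]:  # hypothetical per-iteration row counts
    if logs_found < EXPECTED_ITERATION_LOG_COUNT:
        seconds = min(MAXIMUM_RANGE_SIZE, batch.total_seconds() * 2)
        batch = timedelta(seconds=seconds)
    print(int(batch.total_seconds()))  # 2, 4, 8, 8, 16
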
diff --git a/data/logs_model/test/mock_elasticsearch.py b/data/logs_model/test/mock_elasticsearch.py
index bd26a10c7..d3b4da883 100644
--- a/data/logs_model/test/mock_elasticsearch.py
+++ b/data/logs_model/test/mock_elasticsearch.py
@@ -8,368 +8,304 @@ from data.logs_model.datatypes import LogEntriesPage, Log, AggregatedLogCount
def _status(d, code=200):
- return {"status_code": code, "content": json.dumps(d)}
+ return {"status_code": code, "content": json.dumps(d)}
def _shards(d, total=5, failed=0, successful=5):
- d.update({"_shards": {"total": total, "failed": failed, "successful": successful}})
- return d
+ d.update({"_shards": {"total": total, "failed": failed, "successful": successful}})
+ return d
def _hits(hits):
- return {"hits": {"total": len(hits), "max_score": None, "hits": hits}}
+ return {"hits": {"total": len(hits), "max_score": None, "hits": hits}}
-INDEX_LIST_RESPONSE_HIT1_HIT2 = _status({
- "logentry_2018-03-08": {},
- "logentry_2018-04-02": {}
-})
+INDEX_LIST_RESPONSE_HIT1_HIT2 = _status(
+ {"logentry_2018-03-08": {}, "logentry_2018-04-02": {}}
+)
-INDEX_LIST_RESPONSE_HIT2 = _status({
- "logentry_2018-04-02": {}
-})
+INDEX_LIST_RESPONSE_HIT2 = _status({"logentry_2018-04-02": {}})
-INDEX_LIST_RESPONSE = _status({
- "logentry_2019-01-01": {},
- "logentry_2017-03-08": {},
- "logentry_2018-03-08": {},
- "logentry_2018-04-02": {}
-})
+INDEX_LIST_RESPONSE = _status(
+ {
+ "logentry_2019-01-01": {},
+ "logentry_2017-03-08": {},
+ "logentry_2018-03-08": {},
+ "logentry_2018-04-02": {},
+ }
+)
DEFAULT_TEMPLATE_RESPONSE = _status({"acknowledged": True})
INDEX_RESPONSE_2019_01_01 = _status(
- _shards({
- "_index": "logentry_2019-01-01",
- "_type": "_doc",
- "_id": "1",
- "_version": 1,
- "_seq_no": 0,
- "_primary_term": 1,
- "result": "created"
- }))
+ _shards(
+ {
+ "_index": "logentry_2019-01-01",
+ "_type": "_doc",
+ "_id": "1",
+ "_version": 1,
+ "_seq_no": 0,
+ "_primary_term": 1,
+ "result": "created",
+ }
+ )
+)
INDEX_RESPONSE_2017_03_08 = _status(
- _shards({
- "_index": "logentry_2017-03-08",
- "_type": "_doc",
- "_id": "1",
- "_version": 1,
- "_seq_no": 0,
- "_primary_term": 1,
- "result": "created"
- }))
+ _shards(
+ {
+ "_index": "logentry_2017-03-08",
+ "_type": "_doc",
+ "_id": "1",
+ "_version": 1,
+ "_seq_no": 0,
+ "_primary_term": 1,
+ "result": "created",
+ }
+ )
+)
FAILURE_400 = _status({}, 400)
INDEX_REQUEST_2019_01_01 = [
- "logentry_2019-01-01", {
- "account_id":
- 1,
- "repository_id":
- 1,
- "ip":
- "192.168.1.1",
- "random_id":
- 233,
- "datetime":
- "2019-01-01T03:30:00",
- "metadata_json": json.loads("{\"\\ud83d\\ude02\": \"\\ud83d\\ude02\\ud83d\\udc4c\\ud83d\\udc4c\\ud83d\\udc4c\\ud83d\\udc4c\", \"key\": \"value\", \"time\": 1520479800}"),
- "performer_id":
- 1,
- "kind_id":
- 1
- }
+ "logentry_2019-01-01",
+ {
+ "account_id": 1,
+ "repository_id": 1,
+ "ip": "192.168.1.1",
+ "random_id": 233,
+ "datetime": "2019-01-01T03:30:00",
+ "metadata_json": json.loads(
+ '{"\\ud83d\\ude02": "\\ud83d\\ude02\\ud83d\\udc4c\\ud83d\\udc4c\\ud83d\\udc4c\\ud83d\\udc4c", "key": "value", "time": 1520479800}'
+ ),
+ "performer_id": 1,
+ "kind_id": 1,
+ },
]
INDEX_REQUEST_2017_03_08 = [
- "logentry_2017-03-08", {
- "repository_id":
- 1,
- "account_id":
- 1,
- "ip":
- "192.168.1.1",
- "random_id":
- 233,
- "datetime":
- "2017-03-08T03:30:00",
- "metadata_json": json.loads("{\"\\ud83d\\ude02\": \"\\ud83d\\ude02\\ud83d\\udc4c\\ud83d\\udc4c\\ud83d\\udc4c\\ud83d\\udc4c\", \"key\": \"value\", \"time\": 1520479800}"),
- "performer_id":
- 1,
- "kind_id":
- 2
- }
+ "logentry_2017-03-08",
+ {
+ "repository_id": 1,
+ "account_id": 1,
+ "ip": "192.168.1.1",
+ "random_id": 233,
+ "datetime": "2017-03-08T03:30:00",
+ "metadata_json": json.loads(
+ '{"\\ud83d\\ude02": "\\ud83d\\ude02\\ud83d\\udc4c\\ud83d\\udc4c\\ud83d\\udc4c\\ud83d\\udc4c", "key": "value", "time": 1520479800}'
+ ),
+ "performer_id": 1,
+ "kind_id": 2,
+ },
]
_hit1 = {
- "_index": "logentry_2018-03-08",
- "_type": "doc",
- "_id": "1",
- "_score": None,
- "_source": {
- "random_id":
- 233,
- "kind_id":
- 1,
- "account_id":
- 1,
- "performer_id":
- 1,
- "repository_id":
- 1,
- "ip":
- "192.168.1.1",
- "metadata_json":
- "{\"\\ud83d\\ude02\": \"\\ud83d\\ude02\\ud83d\\udc4c\\ud83d\\udc4c\\ud83d\\udc4c\\ud83d\\udc4c\", \"key\": \"value\", \"time\": 1520479800}",
- "datetime":
- "2018-03-08T03:30",
- },
- "sort": [1520479800000, 233]
+ "_index": "logentry_2018-03-08",
+ "_type": "doc",
+ "_id": "1",
+ "_score": None,
+ "_source": {
+ "random_id": 233,
+ "kind_id": 1,
+ "account_id": 1,
+ "performer_id": 1,
+ "repository_id": 1,
+ "ip": "192.168.1.1",
+ "metadata_json": '{"\\ud83d\\ude02": "\\ud83d\\ude02\\ud83d\\udc4c\\ud83d\\udc4c\\ud83d\\udc4c\\ud83d\\udc4c", "key": "value", "time": 1520479800}',
+ "datetime": "2018-03-08T03:30",
+ },
+ "sort": [1520479800000, 233],
}
_hit2 = {
- "_index": "logentry_2018-04-02",
- "_type": "doc",
- "_id": "2",
- "_score": None,
- "_source": {
- "random_id":
- 233,
- "kind_id":
- 2,
- "account_id":
- 1,
- "performer_id":
- 1,
- "repository_id":
- 1,
- "ip":
- "192.168.1.2",
- "metadata_json":
- "{\"\\ud83d\\ude02\": \"\\ud83d\\ude02\\ud83d\\udc4c\\ud83d\\udc4c\\ud83d\\udc4c\\ud83d\\udc4c\", \"key\": \"value\", \"time\": 1522639800}",
- "datetime":
- "2018-04-02T03:30",
- },
- "sort": [1522639800000, 233]
+ "_index": "logentry_2018-04-02",
+ "_type": "doc",
+ "_id": "2",
+ "_score": None,
+ "_source": {
+ "random_id": 233,
+ "kind_id": 2,
+ "account_id": 1,
+ "performer_id": 1,
+ "repository_id": 1,
+ "ip": "192.168.1.2",
+ "metadata_json": '{"\\ud83d\\ude02": "\\ud83d\\ude02\\ud83d\\udc4c\\ud83d\\udc4c\\ud83d\\udc4c\\ud83d\\udc4c", "key": "value", "time": 1522639800}',
+ "datetime": "2018-04-02T03:30",
+ },
+ "sort": [1522639800000, 233],
}
_log1 = Log(
- "{\"\\ud83d\\ude02\": \"\\ud83d\\ude02\\ud83d\\udc4c\\ud83d\\udc4c\\ud83d\\udc4c\\ud83d\\udc4c\", \"key\": \"value\", \"time\": 1520479800}",
- "192.168.1.1", parse("2018-03-08T03:30"), "user1.email", "user1.username", "user1.robot",
- "user1.organization", "user1.username", "user1.email", "user1.robot", 1)
+ '{"\\ud83d\\ude02": "\\ud83d\\ude02\\ud83d\\udc4c\\ud83d\\udc4c\\ud83d\\udc4c\\ud83d\\udc4c", "key": "value", "time": 1520479800}',
+ "192.168.1.1",
+ parse("2018-03-08T03:30"),
+ "user1.email",
+ "user1.username",
+ "user1.robot",
+ "user1.organization",
+ "user1.username",
+ "user1.email",
+ "user1.robot",
+ 1,
+)
_log2 = Log(
- "{\"\\ud83d\\ude02\": \"\\ud83d\\ude02\\ud83d\\udc4c\\ud83d\\udc4c\\ud83d\\udc4c\\ud83d\\udc4c\", \"key\": \"value\", \"time\": 1522639800}",
- "192.168.1.2", parse("2018-04-02T03:30"), "user1.email", "user1.username", "user1.robot",
- "user1.organization", "user1.username", "user1.email", "user1.robot", 2)
+ '{"\\ud83d\\ude02": "\\ud83d\\ude02\\ud83d\\udc4c\\ud83d\\udc4c\\ud83d\\udc4c\\ud83d\\udc4c", "key": "value", "time": 1522639800}',
+ "192.168.1.2",
+ parse("2018-04-02T03:30"),
+ "user1.email",
+ "user1.username",
+ "user1.robot",
+ "user1.organization",
+ "user1.username",
+ "user1.email",
+ "user1.robot",
+ 2,
+)
SEARCH_RESPONSE_START = _status(_shards(_hits([_hit1, _hit2])))
SEARCH_RESPONSE_END = _status(_shards(_hits([_hit2])))
SEARCH_REQUEST_START = {
- "sort": [{
- "datetime": "desc"
- }, {
- "random_id.keyword": "desc"
- }],
- "query": {
- "bool": {
- "filter": [{
- "term": {
- "performer_id": 1
+ "sort": [{"datetime": "desc"}, {"random_id.keyword": "desc"}],
+ "query": {
+ "bool": {
+ "filter": [{"term": {"performer_id": 1}}, {"term": {"repository_id": 1}}]
}
- }, {
- "term": {
- "repository_id": 1
- }
- }]
- }
- },
- "size": 2
+ },
+ "size": 2,
}
SEARCH_REQUEST_END = {
- "sort": [{
- "datetime": "desc"
- }, {
- "random_id.keyword": "desc"
- }],
- "query": {
- "bool": {
- "filter": [{
- "term": {
- "performer_id": 1
+ "sort": [{"datetime": "desc"}, {"random_id.keyword": "desc"}],
+ "query": {
+ "bool": {
+ "filter": [{"term": {"performer_id": 1}}, {"term": {"repository_id": 1}}]
}
- }, {
- "term": {
- "repository_id": 1
- }
- }]
- }
- },
- "search_after": [1520479800000, 233],
- "size": 2
+ },
+ "search_after": [1520479800000, 233],
+ "size": 2,
}
SEARCH_REQUEST_FILTER = {
- "sort": [{
- "datetime": "desc"
- }, {
- "random_id.keyword": "desc"
- }],
- "query": {
- "bool": {
- "filter": [{
- "term": {
- "performer_id": 1
- }
- }, {
- "term": {
- "repository_id": 1
- }
- }, {
+ "sort": [{"datetime": "desc"}, {"random_id.keyword": "desc"}],
+ "query": {
"bool": {
- "must_not": [{
- "terms": {
- "kind_id": [1]
- }
- }]
+ "filter": [
+ {"term": {"performer_id": 1}},
+ {"term": {"repository_id": 1}},
+ {"bool": {"must_not": [{"terms": {"kind_id": [1]}}]}},
+ ]
}
- }]
- }
- },
- "size": 2
+ },
+ "size": 2,
}
SEARCH_PAGE_TOKEN = {
- "datetime": datetime(2018, 3, 8, 3, 30).isoformat(),
- "random_id": 233,
- "page_number": 1
+ "datetime": datetime(2018, 3, 8, 3, 30).isoformat(),
+ "random_id": 233,
+ "page_number": 1,
}
SEARCH_PAGE_START = LogEntriesPage(logs=[_log1], next_page_token=SEARCH_PAGE_TOKEN)
SEARCH_PAGE_END = LogEntriesPage(logs=[_log2], next_page_token=None)
SEARCH_PAGE_EMPTY = LogEntriesPage([], None)
AGGS_RESPONSE = _status(
- _shards({
- "hits": {
- "total": 4,
- "max_score": None,
- "hits": []
- },
- "aggregations": {
- "by_id": {
- "doc_count_error_upper_bound":
- 0,
- "sum_other_doc_count":
- 0,
- "buckets": [{
- "key": 2,
- "doc_count": 3,
- "by_date": {
- "buckets": [{
- "key_as_string": "2009-11-12T00:00:00.000Z",
- "key": 1257984000000,
- "doc_count": 1
- }, {
- "key_as_string": "2009-11-13T00:00:00.000Z",
- "key": 1258070400000,
- "doc_count": 0
- }, {
- "key_as_string": "2009-11-14T00:00:00.000Z",
- "key": 1258156800000,
- "doc_count": 2
- }]
- }
- }, {
- "key": 1,
- "doc_count": 1,
- "by_date": {
- "buckets": [{
- "key_as_string": "2009-11-15T00:00:00.000Z",
- "key": 1258243200000,
- "doc_count": 1
- }]
- }
- }]
- }
- }
- }))
+ _shards(
+ {
+ "hits": {"total": 4, "max_score": None, "hits": []},
+ "aggregations": {
+ "by_id": {
+ "doc_count_error_upper_bound": 0,
+ "sum_other_doc_count": 0,
+ "buckets": [
+ {
+ "key": 2,
+ "doc_count": 3,
+ "by_date": {
+ "buckets": [
+ {
+ "key_as_string": "2009-11-12T00:00:00.000Z",
+ "key": 1257984000000,
+ "doc_count": 1,
+ },
+ {
+ "key_as_string": "2009-11-13T00:00:00.000Z",
+ "key": 1258070400000,
+ "doc_count": 0,
+ },
+ {
+ "key_as_string": "2009-11-14T00:00:00.000Z",
+ "key": 1258156800000,
+ "doc_count": 2,
+ },
+ ]
+ },
+ },
+ {
+ "key": 1,
+ "doc_count": 1,
+ "by_date": {
+ "buckets": [
+ {
+ "key_as_string": "2009-11-15T00:00:00.000Z",
+ "key": 1258243200000,
+ "doc_count": 1,
+ }
+ ]
+ },
+ },
+ ],
+ }
+ },
+ }
+ )
+)
AGGS_REQUEST = {
- "query": {
- "bool": {
- "filter": [{
- "term": {
- "performer_id": 1
- }
- }, {
- "term": {
- "repository_id": 1
- }
- }, {
+ "query": {
"bool": {
- "must_not": [{
- "terms": {
- "kind_id": [2]
- }
- }]
+ "filter": [
+ {"term": {"performer_id": 1}},
+ {"term": {"repository_id": 1}},
+ {"bool": {"must_not": [{"terms": {"kind_id": [2]}}]}},
+ ],
+ "must": [
+ {
+ "range": {
+ "datetime": {
+ "lt": "2018-04-08T03:30:00",
+ "gte": "2018-03-08T03:30:00",
+ }
+ }
+ }
+ ],
}
- }],
- "must": [{
- "range": {
- "datetime": {
- "lt": "2018-04-08T03:30:00",
- "gte": "2018-03-08T03:30:00"
- }
+ },
+ "aggs": {
+ "by_id": {
+ "terms": {"field": "kind_id"},
+ "aggs": {
+ "by_date": {"date_histogram": {"field": "datetime", "interval": "day"}}
+ },
}
- }]
- }
- },
- "aggs": {
- "by_id": {
- "terms": {
- "field": "kind_id"
- },
- "aggs": {
- "by_date": {
- "date_histogram": {
- "field": "datetime",
- "interval": "day"
- }
- }
- }
- }
- },
- "size": 0
+ },
+ "size": 0,
}
AGGS_COUNT = [
- AggregatedLogCount(1, 1, parse("2009-11-15T00:00:00.000")),
- AggregatedLogCount(2, 1, parse("2009-11-12T00:00:00.000")),
- AggregatedLogCount(2, 2, parse("2009-11-14T00:00:00.000"))
+ AggregatedLogCount(1, 1, parse("2009-11-15T00:00:00.000")),
+ AggregatedLogCount(2, 1, parse("2009-11-12T00:00:00.000")),
+ AggregatedLogCount(2, 2, parse("2009-11-14T00:00:00.000")),
]
-COUNT_REQUEST = {
- "query": {
- "bool": {
- "filter": [{
- "term": {
- "repository_id": 1
- }
- }]
- }
- }
-}
-COUNT_RESPONSE = _status(_shards({
- "count": 1,
-}))
+COUNT_REQUEST = {"query": {"bool": {"filter": [{"term": {"repository_id": 1}}]}}}
+COUNT_RESPONSE = _status(_shards({"count": 1}))
# assume there are 2 pages
_scroll_id = "DnF1ZXJ5VGhlbkZldGNoBQAAAAAAACEmFkk1aGlTRzdSUWllejZmYTlEYTN3SVEAAAAAAAAhJRZJNWhpU0c3UlFpZXo2ZmE5RGEzd0lRAAAAAAAAHtAWLWZpaFZXVzVSTy1OTXA5V3MwcHZrZwAAAAAAAB7RFi1maWhWV1c1Uk8tTk1wOVdzMHB2a2cAAAAAAAAhJxZJNWhpU0c3UlFpZXo2ZmE5RGEzd0lR"
def _scroll(d):
- d["_scroll_id"] = _scroll_id
- return d
+ d["_scroll_id"] = _scroll_id
+ return d
SCROLL_CREATE = _status(_shards(_scroll(_hits([_hit1]))))
@@ -379,22 +315,24 @@ SCROLL_DELETE = _status({"succeeded": True, "num_freed": 5})
SCROLL_LOGS = [[_log1], [_log2]]
SCROLL_REQUESTS = [
- [
- "5m", 1, {
- "sort": "_doc",
- "query": {
- "range": {
- "datetime": {
- "lt": "2018-04-02T00:00:00",
- "gte": "2018-03-08T00:00:00"
- }
- }
- }
- }
- ],
- [{"scroll": "5m", "scroll_id": _scroll_id}],
- [{"scroll":"5m", "scroll_id": _scroll_id}],
- [{"scroll_id": [_scroll_id]}],
+ [
+ "5m",
+ 1,
+ {
+ "sort": "_doc",
+ "query": {
+ "range": {
+ "datetime": {
+ "lt": "2018-04-02T00:00:00",
+ "gte": "2018-03-08T00:00:00",
+ }
+ }
+ },
+ },
+ ],
+ [{"scroll": "5m", "scroll_id": _scroll_id}],
+ [{"scroll": "5m", "scroll_id": _scroll_id}],
+ [{"scroll_id": [_scroll_id]}],
]
SCROLL_RESPONSES = [SCROLL_CREATE, SCROLL_GET, SCROLL_GET_2, SCROLL_DELETE]
diff --git a/data/logs_model/test/test_combined_model.py b/data/logs_model/test/test_combined_model.py
index 7b288e72f..c6a6c6297 100644
--- a/data/logs_model/test/test_combined_model.py
+++ b/data/logs_model/test/test_combined_model.py
@@ -10,121 +10,144 @@ from test.fixtures import *
@pytest.fixture()
def first_model():
- return InMemoryModel()
+ return InMemoryModel()
@pytest.fixture()
def second_model():
- return InMemoryModel()
+ return InMemoryModel()
@pytest.fixture()
def combined_model(first_model, second_model, initialized_db):
- return CombinedLogsModel(first_model, second_model)
+ return CombinedLogsModel(first_model, second_model)
def test_log_action(first_model, second_model, combined_model, initialized_db):
- day = date(2019, 1, 1)
+ day = date(2019, 1, 1)
- # Write to the combined model.
- with freeze_time(day):
- combined_model.log_action('push_repo', namespace_name='devtable', repository_name='simple',
- ip='1.2.3.4')
+ # Write to the combined model.
+ with freeze_time(day):
+ combined_model.log_action(
+ "push_repo",
+ namespace_name="devtable",
+ repository_name="simple",
+ ip="1.2.3.4",
+ )
- simple_repo = model.repository.get_repository('devtable', 'simple')
+ simple_repo = model.repository.get_repository("devtable", "simple")
- # Make sure it is found in the first model but not the second.
- assert combined_model.count_repository_actions(simple_repo, day) == 1
- assert first_model.count_repository_actions(simple_repo, day) == 1
- assert second_model.count_repository_actions(simple_repo, day) == 0
+ # Make sure it is found in the first model but not the second.
+ assert combined_model.count_repository_actions(simple_repo, day) == 1
+ assert first_model.count_repository_actions(simple_repo, day) == 1
+ assert second_model.count_repository_actions(simple_repo, day) == 0
-def test_count_repository_actions(first_model, second_model, combined_model, initialized_db):
- # Write to each model.
- first_model.log_action('push_repo', namespace_name='devtable', repository_name='simple',
- ip='1.2.3.4')
- first_model.log_action('push_repo', namespace_name='devtable', repository_name='simple',
- ip='1.2.3.4')
- first_model.log_action('push_repo', namespace_name='devtable', repository_name='simple',
- ip='1.2.3.4')
+def test_count_repository_actions(
+ first_model, second_model, combined_model, initialized_db
+):
+ # Write to each model.
+ first_model.log_action(
+ "push_repo", namespace_name="devtable", repository_name="simple", ip="1.2.3.4"
+ )
+ first_model.log_action(
+ "push_repo", namespace_name="devtable", repository_name="simple", ip="1.2.3.4"
+ )
+ first_model.log_action(
+ "push_repo", namespace_name="devtable", repository_name="simple", ip="1.2.3.4"
+ )
- second_model.log_action('push_repo', namespace_name='devtable', repository_name='simple',
- ip='1.2.3.4')
- second_model.log_action('push_repo', namespace_name='devtable', repository_name='simple',
- ip='1.2.3.4')
+ second_model.log_action(
+ "push_repo", namespace_name="devtable", repository_name="simple", ip="1.2.3.4"
+ )
+ second_model.log_action(
+ "push_repo", namespace_name="devtable", repository_name="simple", ip="1.2.3.4"
+ )
- # Ensure the counts match as expected.
- day = datetime.today() - timedelta(minutes=60)
- simple_repo = model.repository.get_repository('devtable', 'simple')
+ # Ensure the counts match as expected.
+ day = datetime.today() - timedelta(minutes=60)
+ simple_repo = model.repository.get_repository("devtable", "simple")
- assert first_model.count_repository_actions(simple_repo, day) == 3
- assert second_model.count_repository_actions(simple_repo, day) == 2
- assert combined_model.count_repository_actions(simple_repo, day) == 5
+ assert first_model.count_repository_actions(simple_repo, day) == 3
+ assert second_model.count_repository_actions(simple_repo, day) == 2
+ assert combined_model.count_repository_actions(simple_repo, day) == 5
-def test_yield_logs_for_export(first_model, second_model, combined_model, initialized_db):
- now = datetime.now()
+def test_yield_logs_for_export(
+ first_model, second_model, combined_model, initialized_db
+):
+ now = datetime.now()
- # Write to each model.
- first_model.log_action('push_repo', namespace_name='devtable', repository_name='simple',
- ip='1.2.3.4')
- first_model.log_action('push_repo', namespace_name='devtable', repository_name='simple',
- ip='1.2.3.4')
- first_model.log_action('push_repo', namespace_name='devtable', repository_name='simple',
- ip='1.2.3.4')
+ # Write to each model.
+ first_model.log_action(
+ "push_repo", namespace_name="devtable", repository_name="simple", ip="1.2.3.4"
+ )
+ first_model.log_action(
+ "push_repo", namespace_name="devtable", repository_name="simple", ip="1.2.3.4"
+ )
+ first_model.log_action(
+ "push_repo", namespace_name="devtable", repository_name="simple", ip="1.2.3.4"
+ )
- second_model.log_action('push_repo', namespace_name='devtable', repository_name='simple',
- ip='1.2.3.4')
- second_model.log_action('push_repo', namespace_name='devtable', repository_name='simple',
- ip='1.2.3.4')
+ second_model.log_action(
+ "push_repo", namespace_name="devtable", repository_name="simple", ip="1.2.3.4"
+ )
+ second_model.log_action(
+ "push_repo", namespace_name="devtable", repository_name="simple", ip="1.2.3.4"
+ )
- later = datetime.now()
+ later = datetime.now()
- # Ensure the full set of logs is yielded.
- first_logs = list(first_model.yield_logs_for_export(now, later))[0]
- second_logs = list(second_model.yield_logs_for_export(now, later))[0]
+ # Ensure the full set of logs is yielded.
+ first_logs = list(first_model.yield_logs_for_export(now, later))[0]
+ second_logs = list(second_model.yield_logs_for_export(now, later))[0]
- combined = list(combined_model.yield_logs_for_export(now, later))
- full_combined = []
- for subset in combined:
- full_combined.extend(subset)
+ combined = list(combined_model.yield_logs_for_export(now, later))
+ full_combined = []
+ for subset in combined:
+ full_combined.extend(subset)
- assert len(full_combined) == len(first_logs) + len(second_logs)
- assert full_combined == (first_logs + second_logs)
+ assert len(full_combined) == len(first_logs) + len(second_logs)
+ assert full_combined == (first_logs + second_logs)
def test_lookup_logs(first_model, second_model, combined_model, initialized_db):
- now = datetime.now()
+ now = datetime.now()
- # Write to each model.
- first_model.log_action('push_repo', namespace_name='devtable', repository_name='simple',
- ip='1.2.3.4')
- first_model.log_action('push_repo', namespace_name='devtable', repository_name='simple',
- ip='1.2.3.4')
- first_model.log_action('push_repo', namespace_name='devtable', repository_name='simple',
- ip='1.2.3.4')
+ # Write to each model.
+ first_model.log_action(
+ "push_repo", namespace_name="devtable", repository_name="simple", ip="1.2.3.4"
+ )
+ first_model.log_action(
+ "push_repo", namespace_name="devtable", repository_name="simple", ip="1.2.3.4"
+ )
+ first_model.log_action(
+ "push_repo", namespace_name="devtable", repository_name="simple", ip="1.2.3.4"
+ )
- second_model.log_action('push_repo', namespace_name='devtable', repository_name='simple',
- ip='1.2.3.4')
- second_model.log_action('push_repo', namespace_name='devtable', repository_name='simple',
- ip='1.2.3.4')
+ second_model.log_action(
+ "push_repo", namespace_name="devtable", repository_name="simple", ip="1.2.3.4"
+ )
+ second_model.log_action(
+ "push_repo", namespace_name="devtable", repository_name="simple", ip="1.2.3.4"
+ )
- later = datetime.now()
+ later = datetime.now()
- def _collect_logs(model):
- page_token = None
- all_logs = []
- while True:
- paginated_logs = model.lookup_logs(now, later, page_token=page_token)
- page_token = paginated_logs.next_page_token
- all_logs.extend(paginated_logs.logs)
- if page_token is None:
- break
- return all_logs
+ def _collect_logs(model):
+ page_token = None
+ all_logs = []
+ while True:
+ paginated_logs = model.lookup_logs(now, later, page_token=page_token)
+ page_token = paginated_logs.next_page_token
+ all_logs.extend(paginated_logs.logs)
+ if page_token is None:
+ break
+ return all_logs
- first_logs = _collect_logs(first_model)
- second_logs = _collect_logs(second_model)
- combined = _collect_logs(combined_model)
+ first_logs = _collect_logs(first_model)
+ second_logs = _collect_logs(second_model)
+ combined = _collect_logs(combined_model)
- assert len(combined) == len(first_logs) + len(second_logs)
- assert combined == (first_logs + second_logs)
+ assert len(combined) == len(first_logs) + len(second_logs)
+ assert combined == (first_logs + second_logs)
diff --git a/data/logs_model/test/test_elasticsearch.py b/data/logs_model/test/test_elasticsearch.py
index a305010f4..2449c6168 100644
--- a/data/logs_model/test/test_elasticsearch.py
+++ b/data/logs_model/test/test_elasticsearch.py
@@ -12,256 +12,340 @@ from dateutil.parser import parse
from httmock import urlmatch, HTTMock
from data.model.log import _json_serialize
-from data.logs_model.elastic_logs import ElasticsearchLogs, INDEX_NAME_PREFIX, INDEX_DATE_FORMAT
+from data.logs_model.elastic_logs import (
+ ElasticsearchLogs,
+ INDEX_NAME_PREFIX,
+ INDEX_DATE_FORMAT,
+)
from data.logs_model import configure, LogsModelProxy
from mock_elasticsearch import *
-FAKE_ES_HOST = 'fakees'
-FAKE_ES_HOST_PATTERN = r'fakees.*'
+FAKE_ES_HOST = "fakees"
+FAKE_ES_HOST_PATTERN = r"fakees.*"
FAKE_ES_PORT = 443
FAKE_AWS_ACCESS_KEY = None
FAKE_AWS_SECRET_KEY = None
FAKE_AWS_REGION = None
+
@pytest.fixture()
def logs_model_config():
- conf = {
- 'LOGS_MODEL': 'elasticsearch',
- 'LOGS_MODEL_CONFIG': {
- 'producer': 'elasticsearch',
- 'elasticsearch_config': {
- 'host': FAKE_ES_HOST,
- 'port': FAKE_ES_PORT,
- 'access_key': FAKE_AWS_ACCESS_KEY,
- 'secret_key': FAKE_AWS_SECRET_KEY,
- 'aws_region': FAKE_AWS_REGION
- }
+ conf = {
+ "LOGS_MODEL": "elasticsearch",
+ "LOGS_MODEL_CONFIG": {
+ "producer": "elasticsearch",
+ "elasticsearch_config": {
+ "host": FAKE_ES_HOST,
+ "port": FAKE_ES_PORT,
+ "access_key": FAKE_AWS_ACCESS_KEY,
+ "secret_key": FAKE_AWS_SECRET_KEY,
+ "aws_region": FAKE_AWS_REGION,
+ },
+ },
}
- }
- return conf
+ return conf
-FAKE_LOG_ENTRY_KINDS = {'push_repo': 1, 'pull_repo': 2}
+FAKE_LOG_ENTRY_KINDS = {"push_repo": 1, "pull_repo": 2}
FAKE_NAMESPACES = {
- 'user1':
- Mock(id=1, organization="user1.organization", username="user1.username", email="user1.email",
- robot="user1.robot"),
- 'user2':
- Mock(id=2, organization="user2.organization", username="user2.username", email="user2.email",
- robot="user2.robot")
+ "user1": Mock(
+ id=1,
+ organization="user1.organization",
+ username="user1.username",
+ email="user1.email",
+ robot="user1.robot",
+ ),
+ "user2": Mock(
+ id=2,
+ organization="user2.organization",
+ username="user2.username",
+ email="user2.email",
+ robot="user2.robot",
+ ),
}
FAKE_REPOSITORIES = {
- 'user1/repo1': Mock(id=1, namespace_user=FAKE_NAMESPACES['user1']),
- 'user2/repo2': Mock(id=2, namespace_user=FAKE_NAMESPACES['user2']),
+ "user1/repo1": Mock(id=1, namespace_user=FAKE_NAMESPACES["user1"]),
+ "user2/repo2": Mock(id=2, namespace_user=FAKE_NAMESPACES["user2"]),
}
@pytest.fixture()
def logs_model():
- # prevent logs model from changing
- logs_model = LogsModelProxy()
- with patch('data.logs_model.logs_model', logs_model):
- yield logs_model
+ # prevent logs model from changing
+ logs_model = LogsModelProxy()
+ with patch("data.logs_model.logs_model", logs_model):
+ yield logs_model
-@pytest.fixture(scope='function')
+@pytest.fixture(scope="function")
def app_config(logs_model_config):
- fake_config = {}
- fake_config.update(logs_model_config)
- with patch("data.logs_model.document_logs_model.config.app_config", fake_config):
- yield fake_config
+ fake_config = {}
+ fake_config.update(logs_model_config)
+ with patch("data.logs_model.document_logs_model.config.app_config", fake_config):
+ yield fake_config
@pytest.fixture()
def mock_page_size():
- with patch('data.logs_model.document_logs_model.PAGE_SIZE', 1):
- yield
+ with patch("data.logs_model.document_logs_model.PAGE_SIZE", 1):
+ yield
@pytest.fixture()
def mock_max_result_window():
- with patch('data.logs_model.document_logs_model.DEFAULT_RESULT_WINDOW', 1):
- yield
+ with patch("data.logs_model.document_logs_model.DEFAULT_RESULT_WINDOW", 1):
+ yield
@pytest.fixture
def mock_random_id():
- mock_random = Mock(return_value=233)
- with patch('data.logs_model.document_logs_model._random_id', mock_random):
- yield
+ mock_random = Mock(return_value=233)
+ with patch("data.logs_model.document_logs_model._random_id", mock_random):
+ yield
@pytest.fixture()
def mock_db_model():
- def get_user_map_by_ids(namespace_ids):
- mapping = {}
- for i in namespace_ids:
- for name in FAKE_NAMESPACES:
- if FAKE_NAMESPACES[name].id == i:
- mapping[i] = FAKE_NAMESPACES[name]
- return mapping
+ def get_user_map_by_ids(namespace_ids):
+ mapping = {}
+ for i in namespace_ids:
+ for name in FAKE_NAMESPACES:
+ if FAKE_NAMESPACES[name].id == i:
+ mapping[i] = FAKE_NAMESPACES[name]
+ return mapping
- model = Mock(
- user=Mock(
- get_namespace_user=FAKE_NAMESPACES.get,
- get_user_or_org=FAKE_NAMESPACES.get,
- get_user=FAKE_NAMESPACES.get,
- get_user_map_by_ids=get_user_map_by_ids,
- ),
- repository=Mock(get_repository=lambda user_name, repo_name: FAKE_REPOSITORIES.get(
- user_name + '/' + repo_name),
- ),
- log=Mock(
- _get_log_entry_kind=lambda name: FAKE_LOG_ENTRY_KINDS[name],
- _json_serialize=_json_serialize,
- get_log_entry_kinds=Mock(return_value=FAKE_LOG_ENTRY_KINDS),
- ),
- )
+ model = Mock(
+ user=Mock(
+ get_namespace_user=FAKE_NAMESPACES.get,
+ get_user_or_org=FAKE_NAMESPACES.get,
+ get_user=FAKE_NAMESPACES.get,
+ get_user_map_by_ids=get_user_map_by_ids,
+ ),
+ repository=Mock(
+ get_repository=lambda user_name, repo_name: FAKE_REPOSITORIES.get(
+ user_name + "/" + repo_name
+ )
+ ),
+ log=Mock(
+ _get_log_entry_kind=lambda name: FAKE_LOG_ENTRY_KINDS[name],
+ _json_serialize=_json_serialize,
+ get_log_entry_kinds=Mock(return_value=FAKE_LOG_ENTRY_KINDS),
+ ),
+ )
- with patch('data.logs_model.document_logs_model.model', model), patch(
- 'data.logs_model.datatypes.model', model):
- yield
+ with patch("data.logs_model.document_logs_model.model", model), patch(
+ "data.logs_model.datatypes.model", model
+ ):
+ yield
def parse_query(query):
- return {s.split('=')[0]: s.split('=')[1] for s in query.split("&") if s != ""}
+ return {s.split("=")[0]: s.split("=")[1] for s in query.split("&") if s != ""}
@pytest.fixture()
def mock_elasticsearch():
- mock = Mock()
- mock.template.side_effect = NotImplementedError
- mock.index.side_effect = NotImplementedError
- mock.count.side_effect = NotImplementedError
- mock.scroll_get.side_effect = NotImplementedError
- mock.scroll_delete.side_effect = NotImplementedError
- mock.search_scroll_create.side_effect = NotImplementedError
- mock.search_aggs.side_effect = NotImplementedError
- mock.search_after.side_effect = NotImplementedError
- mock.list_indices.side_effect = NotImplementedError
+ mock = Mock()
+ mock.template.side_effect = NotImplementedError
+ mock.index.side_effect = NotImplementedError
+ mock.count.side_effect = NotImplementedError
+ mock.scroll_get.side_effect = NotImplementedError
+ mock.scroll_delete.side_effect = NotImplementedError
+ mock.search_scroll_create.side_effect = NotImplementedError
+ mock.search_aggs.side_effect = NotImplementedError
+ mock.search_after.side_effect = NotImplementedError
+ mock.list_indices.side_effect = NotImplementedError
- @urlmatch(netloc=r'.*', path=r'.*')
- def default(url, req):
- raise Exception('\nurl={}\nmethod={}\nreq.url={}\nheaders={}\nbody={}'.format(
- url, req.method, req.url, req.headers, req.body))
+ @urlmatch(netloc=r".*", path=r".*")
+ def default(url, req):
+ raise Exception(
+ "\nurl={}\nmethod={}\nreq.url={}\nheaders={}\nbody={}".format(
+ url, req.method, req.url, req.headers, req.body
+ )
+ )
- @urlmatch(netloc=FAKE_ES_HOST_PATTERN, path=r'/_template/.*')
- def template(url, req):
- return mock.template(url.query.split('/')[-1], req.body)
+ @urlmatch(netloc=FAKE_ES_HOST_PATTERN, path=r"/_template/.*")
+ def template(url, req):
+ return mock.template(url.query.split("/")[-1], req.body)
- @urlmatch(netloc=FAKE_ES_HOST_PATTERN, path=r'/logentry_(\*|[0-9\-]+)')
- def list_indices(url, req):
- return mock.list_indices()
+ @urlmatch(netloc=FAKE_ES_HOST_PATTERN, path=r"/logentry_(\*|[0-9\-]+)")
+ def list_indices(url, req):
+ return mock.list_indices()
- @urlmatch(netloc=FAKE_ES_HOST_PATTERN, path=r'/logentry_[0-9\-]*/_doc')
- def index(url, req):
- index = url.path.split('/')[1]
- body = json.loads(req.body)
- body['metadata_json'] = json.loads(body['metadata_json'])
- return mock.index(index, body)
+ @urlmatch(netloc=FAKE_ES_HOST_PATTERN, path=r"/logentry_[0-9\-]*/_doc")
+ def index(url, req):
+ index = url.path.split("/")[1]
+ body = json.loads(req.body)
+ body["metadata_json"] = json.loads(body["metadata_json"])
+ return mock.index(index, body)
- @urlmatch(netloc=FAKE_ES_HOST_PATTERN, path=r'/logentry_([0-9\-]*|\*)/_count')
- def count(_, req):
- return mock.count(json.loads(req.body))
+ @urlmatch(netloc=FAKE_ES_HOST_PATTERN, path=r"/logentry_([0-9\-]*|\*)/_count")
+ def count(_, req):
+ return mock.count(json.loads(req.body))
- @urlmatch(netloc=FAKE_ES_HOST_PATTERN, path=r'/_search/scroll')
- def scroll(url, req):
- if req.method == 'DELETE':
- return mock.scroll_delete(json.loads(req.body))
- elif req.method == 'GET':
- request_obj = json.loads(req.body)
- return mock.scroll_get(request_obj)
- raise NotImplementedError()
+ @urlmatch(netloc=FAKE_ES_HOST_PATTERN, path=r"/_search/scroll")
+ def scroll(url, req):
+ if req.method == "DELETE":
+ return mock.scroll_delete(json.loads(req.body))
+ elif req.method == "GET":
+ request_obj = json.loads(req.body)
+ return mock.scroll_get(request_obj)
+ raise NotImplementedError()
- @urlmatch(netloc=FAKE_ES_HOST_PATTERN, path=r'/logentry_(\*|[0-9\-]*)/_search')
- def search(url, req):
- if "scroll" in url.query:
- query = parse_query(url.query)
- window_size = query['scroll']
- maximum_result_size = int(query['size'])
- return mock.search_scroll_create(window_size, maximum_result_size, json.loads(req.body))
- elif "aggs" in req.body:
- return mock.search_aggs(json.loads(req.body))
- else:
- return mock.search_after(json.loads(req.body))
+ @urlmatch(netloc=FAKE_ES_HOST_PATTERN, path=r"/logentry_(\*|[0-9\-]*)/_search")
+ def search(url, req):
+ if "scroll" in url.query:
+ query = parse_query(url.query)
+ window_size = query["scroll"]
+ maximum_result_size = int(query["size"])
+ return mock.search_scroll_create(
+ window_size, maximum_result_size, json.loads(req.body)
+ )
+ elif "aggs" in req.body:
+ return mock.search_aggs(json.loads(req.body))
+ else:
+ return mock.search_after(json.loads(req.body))
- with HTTMock(scroll, count, search, index, template, list_indices, default):
- yield mock
+ with HTTMock(scroll, count, search, index, template, list_indices, default):
+ yield mock
@pytest.mark.parametrize(
- """
+ """
unlogged_pulls_ok, kind_name, namespace_name, repository, repository_name,
timestamp,
index_response, expected_request, throws
""",
- [
- # Invalid inputs
- pytest.param(
- False, 'non-existing', None, None, None,
- None,
- None, None, True,
- id="Invalid Kind"
- ),
- pytest.param(
- False, 'pull_repo', 'user1', Mock(id=1), 'repo1',
- None,
- None, None, True,
- id="Invalid Parameters"
- ),
+ [
+ # Invalid inputs
+ pytest.param(
+ False,
+ "non-existing",
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ True,
+ id="Invalid Kind",
+ ),
+ pytest.param(
+ False,
+ "pull_repo",
+ "user1",
+ Mock(id=1),
+ "repo1",
+ None,
+ None,
+ None,
+ True,
+ id="Invalid Parameters",
+ ),
+ # Remote exceptions
+ pytest.param(
+ False,
+ "pull_repo",
+ "user1",
+ Mock(id=1),
+ None,
+ None,
+ FAILURE_400,
+ None,
+ True,
+ id="Throw on pull log failure",
+ ),
+ pytest.param(
+ True,
+ "pull_repo",
+ "user1",
+ Mock(id=1),
+ None,
+ parse("2017-03-08T03:30"),
+ FAILURE_400,
+ INDEX_REQUEST_2017_03_08,
+ False,
+ id="Ok on pull log failure",
+ ),
+ # Success executions
+ pytest.param(
+ False,
+ "pull_repo",
+ "user1",
+ Mock(id=1),
+ None,
+ parse("2017-03-08T03:30"),
+ INDEX_RESPONSE_2017_03_08,
+ INDEX_REQUEST_2017_03_08,
+ False,
+ id="Log with namespace name and repository",
+ ),
+ pytest.param(
+ False,
+ "push_repo",
+ "user1",
+ None,
+ "repo1",
+ parse("2019-01-01T03:30"),
+ INDEX_RESPONSE_2019_01_01,
+ INDEX_REQUEST_2019_01_01,
+ False,
+ id="Log with namespace name and repository name",
+ ),
+ ],
+)
+def test_log_action(
+ unlogged_pulls_ok,
+ kind_name,
+ namespace_name,
+ repository,
+ repository_name,
+ timestamp,
+ index_response,
+ expected_request,
+ throws,
+ app_config,
+ logs_model,
+ mock_elasticsearch,
+ mock_db_model,
+ mock_random_id,
+):
+ mock_elasticsearch.template = Mock(return_value=DEFAULT_TEMPLATE_RESPONSE)
+ mock_elasticsearch.index = Mock(return_value=index_response)
+ app_config["ALLOW_PULLS_WITHOUT_STRICT_LOGGING"] = unlogged_pulls_ok
+ configure(app_config)
- # Remote exceptions
- pytest.param(
- False, 'pull_repo', 'user1', Mock(id=1), None,
- None,
- FAILURE_400, None, True,
- id="Throw on pull log failure"
- ),
- pytest.param(
- True, 'pull_repo', 'user1', Mock(id=1), None,
- parse("2017-03-08T03:30"),
- FAILURE_400, INDEX_REQUEST_2017_03_08, False,
- id="Ok on pull log failure"
- ),
-
- # Success executions
- pytest.param(
- False, 'pull_repo', 'user1', Mock(id=1), None,
- parse("2017-03-08T03:30"),
- INDEX_RESPONSE_2017_03_08, INDEX_REQUEST_2017_03_08, False,
- id="Log with namespace name and repository"
- ),
- pytest.param(
- False, 'push_repo', 'user1', None, 'repo1',
- parse("2019-01-01T03:30"),
- INDEX_RESPONSE_2019_01_01, INDEX_REQUEST_2019_01_01, False,
- id="Log with namespace name and repository name"
- ),
- ])
-def test_log_action(unlogged_pulls_ok, kind_name, namespace_name, repository, repository_name,
- timestamp,
- index_response, expected_request, throws,
- app_config, logs_model, mock_elasticsearch, mock_db_model, mock_random_id):
- mock_elasticsearch.template = Mock(return_value=DEFAULT_TEMPLATE_RESPONSE)
- mock_elasticsearch.index = Mock(return_value=index_response)
- app_config['ALLOW_PULLS_WITHOUT_STRICT_LOGGING'] = unlogged_pulls_ok
- configure(app_config)
-
- performer = Mock(id=1)
- ip = "192.168.1.1"
- metadata = {'key': 'value', 'time': parse("2018-03-08T03:30"), '😂': '😂👌👌👌👌'}
- if throws:
- with pytest.raises(Exception):
- logs_model.log_action(kind_name, namespace_name, performer, ip, metadata, repository,
- repository_name, timestamp)
- else:
- logs_model.log_action(kind_name, namespace_name, performer, ip, metadata, repository,
- repository_name, timestamp)
- mock_elasticsearch.index.assert_called_with(*expected_request)
+ performer = Mock(id=1)
+ ip = "192.168.1.1"
+ metadata = {"key": "value", "time": parse("2018-03-08T03:30"), "😂": "😂👌👌👌👌"}
+ if throws:
+ with pytest.raises(Exception):
+ logs_model.log_action(
+ kind_name,
+ namespace_name,
+ performer,
+ ip,
+ metadata,
+ repository,
+ repository_name,
+ timestamp,
+ )
+ else:
+ logs_model.log_action(
+ kind_name,
+ namespace_name,
+ performer,
+ ip,
+ metadata,
+ repository,
+ repository_name,
+ timestamp,
+ )
+ mock_elasticsearch.index.assert_called_with(*expected_request)
@pytest.mark.parametrize(
- """
+ """
start_datetime, end_datetime,
performer_name, repository_name, namespace_name,
filter_kinds,
@@ -273,257 +357,377 @@ def test_log_action(unlogged_pulls_ok, kind_name, namespace_name, repository, re
expected_page,
throws
""",
- [
- # 1st page
- pytest.param(
- parse('2018-03-08T03:30'), parse('2018-04-08T03:30'),
- 'user1', 'repo1', 'user1',
- None,
- None,
- None,
- SEARCH_RESPONSE_START,
- INDEX_LIST_RESPONSE_HIT1_HIT2,
- SEARCH_REQUEST_START,
- SEARCH_PAGE_START,
- False,
- id="1st page"
- ),
+ [
+ # 1st page
+ pytest.param(
+ parse("2018-03-08T03:30"),
+ parse("2018-04-08T03:30"),
+ "user1",
+ "repo1",
+ "user1",
+ None,
+ None,
+ None,
+ SEARCH_RESPONSE_START,
+ INDEX_LIST_RESPONSE_HIT1_HIT2,
+ SEARCH_REQUEST_START,
+ SEARCH_PAGE_START,
+ False,
+ id="1st page",
+ ),
+ # Last page
+ pytest.param(
+ parse("2018-03-08T03:30"),
+ parse("2018-04-08T03:30"),
+ "user1",
+ "repo1",
+ "user1",
+ None,
+ SEARCH_PAGE_TOKEN,
+ None,
+ SEARCH_RESPONSE_END,
+ INDEX_LIST_RESPONSE_HIT1_HIT2,
+ SEARCH_REQUEST_END,
+ SEARCH_PAGE_END,
+ False,
+ id="Search using pagination token",
+ ),
+ # Filter
+ pytest.param(
+ parse("2018-03-08T03:30"),
+ parse("2018-04-08T03:30"),
+ "user1",
+ "repo1",
+ "user1",
+ ["push_repo"],
+ None,
+ None,
+ SEARCH_RESPONSE_END,
+ INDEX_LIST_RESPONSE_HIT2,
+ SEARCH_REQUEST_FILTER,
+ SEARCH_PAGE_END,
+ False,
+ id="Filtered search",
+ ),
+ # Max page count
+ pytest.param(
+ parse("2018-03-08T03:30"),
+ parse("2018-04-08T03:30"),
+ "user1",
+ "repo1",
+ "user1",
+ None,
+ SEARCH_PAGE_TOKEN,
+ 1,
+ AssertionError, # Assert that it should not reach the ES server
+ None,
+ None,
+ SEARCH_PAGE_EMPTY,
+ False,
+ id="Page token reaches maximum page count",
+ ),
+ ],
+)
+def test_lookup_logs(
+ start_datetime,
+ end_datetime,
+ performer_name,
+ repository_name,
+ namespace_name,
+ filter_kinds,
+ page_token,
+ max_page_count,
+ search_response,
+ list_indices_response,
+ expected_request,
+ expected_page,
+ throws,
+ logs_model,
+ mock_elasticsearch,
+ mock_db_model,
+ mock_page_size,
+ app_config,
+):
+ mock_elasticsearch.template = Mock(return_value=DEFAULT_TEMPLATE_RESPONSE)
+ mock_elasticsearch.search_after = Mock(return_value=search_response)
+ mock_elasticsearch.list_indices = Mock(return_value=list_indices_response)
- # Last page
- pytest.param(
- parse('2018-03-08T03:30'), parse('2018-04-08T03:30'),
- 'user1', 'repo1', 'user1',
- None,
- SEARCH_PAGE_TOKEN,
- None,
- SEARCH_RESPONSE_END,
- INDEX_LIST_RESPONSE_HIT1_HIT2,
- SEARCH_REQUEST_END,
- SEARCH_PAGE_END,
- False,
- id="Search using pagination token"
- ),
-
- # Filter
- pytest.param(
- parse('2018-03-08T03:30'), parse('2018-04-08T03:30'),
- 'user1', 'repo1', 'user1',
- ['push_repo'],
- None,
- None,
- SEARCH_RESPONSE_END,
- INDEX_LIST_RESPONSE_HIT2,
- SEARCH_REQUEST_FILTER,
- SEARCH_PAGE_END,
- False,
- id="Filtered search"
- ),
-
- # Max page count
- pytest.param(
- parse('2018-03-08T03:30'), parse('2018-04-08T03:30'),
- 'user1', 'repo1', 'user1',
- None,
- SEARCH_PAGE_TOKEN,
- 1,
- AssertionError, # Assert that it should not reach the ES server
- None,
- None,
- SEARCH_PAGE_EMPTY,
- False,
- id="Page token reaches maximum page count",
- ),
- ])
-def test_lookup_logs(start_datetime, end_datetime,
- performer_name, repository_name, namespace_name,
- filter_kinds,
- page_token,
- max_page_count,
- search_response,
- list_indices_response,
- expected_request,
- expected_page,
- throws,
- logs_model, mock_elasticsearch, mock_db_model, mock_page_size, app_config):
- mock_elasticsearch.template = Mock(return_value=DEFAULT_TEMPLATE_RESPONSE)
- mock_elasticsearch.search_after = Mock(return_value=search_response)
- mock_elasticsearch.list_indices = Mock(return_value=list_indices_response)
-
- configure(app_config)
- if throws:
- with pytest.raises(Exception):
- logs_model.lookup_logs(start_datetime, end_datetime, performer_name, repository_name,
- namespace_name, filter_kinds, page_token, max_page_count)
- else:
- page = logs_model.lookup_logs(start_datetime, end_datetime, performer_name, repository_name,
- namespace_name, filter_kinds, page_token, max_page_count)
- assert page == expected_page
- if expected_request:
- mock_elasticsearch.search_after.assert_called_with(expected_request)
+ configure(app_config)
+ if throws:
+ with pytest.raises(Exception):
+ logs_model.lookup_logs(
+ start_datetime,
+ end_datetime,
+ performer_name,
+ repository_name,
+ namespace_name,
+ filter_kinds,
+ page_token,
+ max_page_count,
+ )
+ else:
+ page = logs_model.lookup_logs(
+ start_datetime,
+ end_datetime,
+ performer_name,
+ repository_name,
+ namespace_name,
+ filter_kinds,
+ page_token,
+ max_page_count,
+ )
+ assert page == expected_page
+ if expected_request:
+ mock_elasticsearch.search_after.assert_called_with(expected_request)
@pytest.mark.parametrize(
- """
+ """
start_datetime, end_datetime,
performer_name, repository_name, namespace_name,
filter_kinds, search_response, expected_request, expected_counts, throws
""",
- [
- # Valid
- pytest.param(
- parse('2018-03-08T03:30'), parse('2018-04-08T03:30'),
- 'user1', 'repo1', 'user1',
- ['pull_repo'], AGGS_RESPONSE, AGGS_REQUEST, AGGS_COUNT, False,
- id="Valid Counts"
- ),
+ [
+ # Valid
+ pytest.param(
+ parse("2018-03-08T03:30"),
+ parse("2018-04-08T03:30"),
+ "user1",
+ "repo1",
+ "user1",
+ ["pull_repo"],
+ AGGS_RESPONSE,
+ AGGS_REQUEST,
+ AGGS_COUNT,
+ False,
+ id="Valid Counts",
+ ),
+ # Invalid case: date range too big
+ pytest.param(
+ parse("2018-03-08T03:30"),
+ parse("2018-04-09T03:30"),
+ "user1",
+ "repo1",
+ "user1",
+ [],
+ None,
+ None,
+ None,
+ True,
+ id="Throw on date range too big",
+ ),
+ ],
+)
+def test_get_aggregated_log_counts(
+ start_datetime,
+ end_datetime,
+ performer_name,
+ repository_name,
+ namespace_name,
+ filter_kinds,
+ search_response,
+ expected_request,
+ expected_counts,
+ throws,
+ logs_model,
+ mock_elasticsearch,
+ mock_db_model,
+ app_config,
+):
+ mock_elasticsearch.template = Mock(return_value=DEFAULT_TEMPLATE_RESPONSE)
+ mock_elasticsearch.search_aggs = Mock(return_value=search_response)
- # Invalid case: date range too big
- pytest.param(
- parse('2018-03-08T03:30'), parse('2018-04-09T03:30'),
- 'user1', 'repo1', 'user1',
- [], None, None, None, True,
- id="Throw on date range too big"
- )
- ])
-def test_get_aggregated_log_counts(start_datetime, end_datetime,
- performer_name, repository_name, namespace_name,
- filter_kinds, search_response, expected_request, expected_counts, throws,
- logs_model, mock_elasticsearch, mock_db_model, app_config):
- mock_elasticsearch.template = Mock(return_value=DEFAULT_TEMPLATE_RESPONSE)
- mock_elasticsearch.search_aggs = Mock(return_value=search_response)
-
- configure(app_config)
- if throws:
- with pytest.raises(Exception):
- logs_model.get_aggregated_log_counts(start_datetime, end_datetime, performer_name,
- repository_name, namespace_name, filter_kinds)
- else:
- counts = logs_model.get_aggregated_log_counts(start_datetime, end_datetime, performer_name,
- repository_name, namespace_name, filter_kinds)
- assert set(counts) == set(expected_counts)
- if expected_request:
- mock_elasticsearch.search_aggs.assert_called_with(expected_request)
+ configure(app_config)
+ if throws:
+ with pytest.raises(Exception):
+ logs_model.get_aggregated_log_counts(
+ start_datetime,
+ end_datetime,
+ performer_name,
+ repository_name,
+ namespace_name,
+ filter_kinds,
+ )
+ else:
+ counts = logs_model.get_aggregated_log_counts(
+ start_datetime,
+ end_datetime,
+ performer_name,
+ repository_name,
+ namespace_name,
+ filter_kinds,
+ )
+ assert set(counts) == set(expected_counts)
+ if expected_request:
+ mock_elasticsearch.search_aggs.assert_called_with(expected_request)
@pytest.mark.parametrize(
- """
+ """
repository,
day,
count_response, expected_request, expected_count, throws
""",
- [
- pytest.param(
- FAKE_REPOSITORIES['user1/repo1'],
- parse("2018-03-08").date(),
- COUNT_RESPONSE, COUNT_REQUEST, 1, False,
- id="Valid Count with 1 as result"),
- ])
-def test_count_repository_actions(repository,
- day,
- count_response, expected_request, expected_count, throws,
- logs_model, mock_elasticsearch, mock_db_model, app_config):
- mock_elasticsearch.template = Mock(return_value=DEFAULT_TEMPLATE_RESPONSE)
- mock_elasticsearch.count = Mock(return_value=count_response)
- mock_elasticsearch.list_indices = Mock(return_value=INDEX_LIST_RESPONSE)
+ [
+ pytest.param(
+ FAKE_REPOSITORIES["user1/repo1"],
+ parse("2018-03-08").date(),
+ COUNT_RESPONSE,
+ COUNT_REQUEST,
+ 1,
+ False,
+ id="Valid Count with 1 as result",
+ )
+ ],
+)
+def test_count_repository_actions(
+ repository,
+ day,
+ count_response,
+ expected_request,
+ expected_count,
+ throws,
+ logs_model,
+ mock_elasticsearch,
+ mock_db_model,
+ app_config,
+):
+ mock_elasticsearch.template = Mock(return_value=DEFAULT_TEMPLATE_RESPONSE)
+ mock_elasticsearch.count = Mock(return_value=count_response)
+ mock_elasticsearch.list_indices = Mock(return_value=INDEX_LIST_RESPONSE)
- configure(app_config)
- if throws:
- with pytest.raises(Exception):
- logs_model.count_repository_actions(repository, day)
- else:
- count = logs_model.count_repository_actions(repository, day)
- assert count == expected_count
- if expected_request:
- mock_elasticsearch.count.assert_called_with(expected_request)
+ configure(app_config)
+ if throws:
+ with pytest.raises(Exception):
+ logs_model.count_repository_actions(repository, day)
+ else:
+ count = logs_model.count_repository_actions(repository, day)
+ assert count == expected_count
+ if expected_request:
+ mock_elasticsearch.count.assert_called_with(expected_request)
@pytest.mark.parametrize(
- """
+ """
start_datetime, end_datetime,
repository_id, namespace_id,
max_query_time, scroll_responses, expected_requests, expected_logs, throws
""",
- [
- pytest.param(
- parse("2018-03-08"), parse("2018-04-02"),
- 1, 1,
- timedelta(seconds=10), SCROLL_RESPONSES, SCROLL_REQUESTS, SCROLL_LOGS, False,
- id="Scroll 3 pages with page size = 1"
- ),
- ])
-def test_yield_logs_for_export(start_datetime, end_datetime,
- repository_id, namespace_id,
- max_query_time, scroll_responses, expected_requests, expected_logs, throws,
- logs_model, mock_elasticsearch, mock_db_model, mock_max_result_window, app_config):
- mock_elasticsearch.template = Mock(return_value=DEFAULT_TEMPLATE_RESPONSE)
- mock_elasticsearch.search_scroll_create = Mock(return_value=scroll_responses[0])
- mock_elasticsearch.scroll_get = Mock(side_effect=scroll_responses[1:-1])
- mock_elasticsearch.scroll_delete = Mock(return_value=scroll_responses[-1])
+ [
+ pytest.param(
+ parse("2018-03-08"),
+ parse("2018-04-02"),
+ 1,
+ 1,
+ timedelta(seconds=10),
+ SCROLL_RESPONSES,
+ SCROLL_REQUESTS,
+ SCROLL_LOGS,
+ False,
+ id="Scroll 3 pages with page size = 1",
+ )
+ ],
+)
+def test_yield_logs_for_export(
+ start_datetime,
+ end_datetime,
+ repository_id,
+ namespace_id,
+ max_query_time,
+ scroll_responses,
+ expected_requests,
+ expected_logs,
+ throws,
+ logs_model,
+ mock_elasticsearch,
+ mock_db_model,
+ mock_max_result_window,
+ app_config,
+):
+ mock_elasticsearch.template = Mock(return_value=DEFAULT_TEMPLATE_RESPONSE)
+ mock_elasticsearch.search_scroll_create = Mock(return_value=scroll_responses[0])
+ mock_elasticsearch.scroll_get = Mock(side_effect=scroll_responses[1:-1])
+ mock_elasticsearch.scroll_delete = Mock(return_value=scroll_responses[-1])
- configure(app_config)
- if throws:
- with pytest.raises(Exception):
- logs_model.yield_logs_for_export(start_datetime, end_datetime, max_query_time=max_query_time)
- else:
- log_generator = logs_model.yield_logs_for_export(start_datetime, end_datetime,
- max_query_time=max_query_time)
- counter = 0
- for logs in log_generator:
- if counter == 0:
- mock_elasticsearch.search_scroll_create.assert_called_with(*expected_requests[counter])
- else:
- mock_elasticsearch.scroll_get.assert_called_with(*expected_requests[counter])
- assert expected_logs[counter] == logs
- counter += 1
- # the last two requests must be
- # 1. get with response scroll with 0 hits, which indicates the termination condition
- # 2. delete scroll request
- mock_elasticsearch.scroll_get.assert_called_with(*expected_requests[-2])
- mock_elasticsearch.scroll_delete.assert_called_with(*expected_requests[-1])
+ configure(app_config)
+ if throws:
+ with pytest.raises(Exception):
+ logs_model.yield_logs_for_export(
+ start_datetime, end_datetime, max_query_time=max_query_time
+ )
+ else:
+ log_generator = logs_model.yield_logs_for_export(
+ start_datetime, end_datetime, max_query_time=max_query_time
+ )
+ counter = 0
+ for logs in log_generator:
+ if counter == 0:
+ mock_elasticsearch.search_scroll_create.assert_called_with(
+ *expected_requests[counter]
+ )
+ else:
+ mock_elasticsearch.scroll_get.assert_called_with(
+ *expected_requests[counter]
+ )
+ assert expected_logs[counter] == logs
+ counter += 1
+ # The last two requests must be:
+ # 1. a scroll GET whose response has 0 hits, which indicates the termination condition
+ # 2. a scroll DELETE request
+ mock_elasticsearch.scroll_get.assert_called_with(*expected_requests[-2])
+ mock_elasticsearch.scroll_delete.assert_called_with(*expected_requests[-1])
-@pytest.mark.parametrize('prefix, is_valid', [
- pytest.param('..', False, id='Invalid `..`'),
- pytest.param('.', False, id='Invalid `.`'),
- pytest.param('-prefix', False, id='Invalid prefix start -'),
- pytest.param('_prefix', False, id='Invalid prefix start _'),
- pytest.param('+prefix', False, id='Invalid prefix start +'),
- pytest.param('prefix_with_UPPERCASES', False, id='Invalid uppercase'),
- pytest.param('valid_index', True, id='Valid prefix'),
- pytest.param('valid_index_with_numbers1234', True, id='Valid prefix with numbers'),
- pytest.param('a'*256, False, id='Prefix too long')
-])
+@pytest.mark.parametrize(
+ "prefix, is_valid",
+ [
+ pytest.param("..", False, id="Invalid `..`"),
+ pytest.param(".", False, id="Invalid `.`"),
+ pytest.param("-prefix", False, id="Invalid prefix start -"),
+ pytest.param("_prefix", False, id="Invalid prefix start _"),
+ pytest.param("+prefix", False, id="Invalid prefix start +"),
+ pytest.param("prefix_with_UPPERCASES", False, id="Invalid uppercase"),
+ pytest.param("valid_index", True, id="Valid prefix"),
+ pytest.param(
+ "valid_index_with_numbers1234", True, id="Valid prefix with numbers"
+ ),
+ pytest.param("a" * 256, False, id="Prefix too long"),
+ ],
+)
def test_valid_index_prefix(prefix, is_valid):
- assert ElasticsearchLogs._valid_index_prefix(prefix) == is_valid
+ assert ElasticsearchLogs._valid_index_prefix(prefix) == is_valid
-@pytest.mark.parametrize('index, cutoff_date, expected_result', [
- pytest.param(
- INDEX_NAME_PREFIX+'2019-06-06',
- datetime(2019, 6, 8),
- True,
- id="Index older than cutoff"
- ),
- pytest.param(
- INDEX_NAME_PREFIX+'2019-06-06',
- datetime(2019, 6, 4),
- False,
- id="Index younger than cutoff"
- ),
- pytest.param(
- INDEX_NAME_PREFIX+'2019-06-06',
- datetime(2019, 6, 6, 23),
- False,
- id="Index older than cutoff but timedelta less than 1 day"
- ),
- pytest.param(
- INDEX_NAME_PREFIX+'2019-06-06',
- datetime(2019, 6, 7),
- True,
- id="Index older than cutoff by exactly one day"
- ),
-])
+@pytest.mark.parametrize(
+ "index, cutoff_date, expected_result",
+ [
+ pytest.param(
+ INDEX_NAME_PREFIX + "2019-06-06",
+ datetime(2019, 6, 8),
+ True,
+ id="Index older than cutoff",
+ ),
+ pytest.param(
+ INDEX_NAME_PREFIX + "2019-06-06",
+ datetime(2019, 6, 4),
+ False,
+ id="Index younger than cutoff",
+ ),
+ pytest.param(
+ INDEX_NAME_PREFIX + "2019-06-06",
+ datetime(2019, 6, 6, 23),
+ False,
+ id="Index older than cutoff but timedelta less than 1 day",
+ ),
+ pytest.param(
+ INDEX_NAME_PREFIX + "2019-06-06",
+ datetime(2019, 6, 7),
+ True,
+ id="Index older than cutoff by exactly one day",
+ ),
+ ],
+)
def test_can_delete_index(index, cutoff_date, expected_result):
- es = ElasticsearchLogs(index_prefix=INDEX_NAME_PREFIX)
- assert datetime.strptime(index.split(es._index_prefix, 1)[-1], INDEX_DATE_FORMAT)
- assert es.can_delete_index(index, cutoff_date) == expected_result
+ es = ElasticsearchLogs(index_prefix=INDEX_NAME_PREFIX)
+ assert datetime.strptime(index.split(es._index_prefix, 1)[-1], INDEX_DATE_FORMAT)
+ assert es.can_delete_index(index, cutoff_date) == expected_result
diff --git a/data/logs_model/test/test_logs_interface.py b/data/logs_model/test/test_logs_interface.py
index 8f4f143c0..74dbe0298 100644
--- a/data/logs_model/test/test_logs_interface.py
+++ b/data/logs_model/test/test_logs_interface.py
@@ -4,7 +4,10 @@ from data.logs_model.table_logs_model import TableLogsModel
from data.logs_model.combined_model import CombinedLogsModel
from data.logs_model.inmemory_model import InMemoryModel
from data.logs_model.combined_model import _merge_aggregated_log_counts
-from data.logs_model.document_logs_model import _date_range_in_single_index, DocumentLogsModel
+from data.logs_model.document_logs_model import (
+ _date_range_in_single_index,
+ DocumentLogsModel,
+)
from data.logs_model.interface import LogsIterationTimeout
from data.logs_model.test.fake_elasticsearch import FAKE_ES_HOST, fake_elasticsearch
@@ -16,279 +19,293 @@ from test.fixtures import *
@pytest.fixture()
def mock_page_size():
- page_size = 2
- with patch('data.logs_model.document_logs_model.PAGE_SIZE', page_size):
- yield page_size
+ page_size = 2
+ with patch("data.logs_model.document_logs_model.PAGE_SIZE", page_size):
+ yield page_size
@pytest.fixture()
def clear_db_logs(initialized_db):
- LogEntry.delete().execute()
- LogEntry2.delete().execute()
- LogEntry3.delete().execute()
+ LogEntry.delete().execute()
+ LogEntry2.delete().execute()
+ LogEntry3.delete().execute()
def combined_model():
- return CombinedLogsModel(TableLogsModel(), InMemoryModel())
+ return CombinedLogsModel(TableLogsModel(), InMemoryModel())
def es_model():
- return DocumentLogsModel(producer='elasticsearch', elasticsearch_config={
- 'host': FAKE_ES_HOST,
- 'port': 12345,
- })
+ return DocumentLogsModel(
+ producer="elasticsearch",
+ elasticsearch_config={"host": FAKE_ES_HOST, "port": 12345},
+ )
+
@pytest.fixture()
def fake_es():
- with fake_elasticsearch():
- yield
+ with fake_elasticsearch():
+ yield
@pytest.fixture(params=[TableLogsModel, InMemoryModel, es_model, combined_model])
def logs_model(request, clear_db_logs, fake_es):
- return request.param()
+ return request.param()
def _lookup_logs(logs_model, start_time, end_time, **kwargs):
- logs_found = []
- page_token = None
- while True:
- found = logs_model.lookup_logs(start_time, end_time, page_token=page_token, **kwargs)
- logs_found.extend(found.logs)
- page_token = found.next_page_token
- if not found.logs or not page_token:
- break
+ logs_found = []
+ page_token = None
+ while True:
+ found = logs_model.lookup_logs(
+ start_time, end_time, page_token=page_token, **kwargs
+ )
+ logs_found.extend(found.logs)
+ page_token = found.next_page_token
+ if not found.logs or not page_token:
+ break
- assert len(logs_found) == len(set(logs_found))
- return logs_found
+ assert len(logs_found) == len(set(logs_found))
+ return logs_found
-@pytest.mark.skipif(os.environ.get('TEST_DATABASE_URI', '').find('mysql') >= 0,
- reason='Flaky on MySQL')
-@pytest.mark.parametrize('namespace_name, repo_name, performer_name, check_args, expect_results', [
- pytest.param('devtable', 'simple', 'devtable', {}, True, id='no filters'),
- pytest.param('devtable', 'simple', 'devtable', {
- 'performer_name': 'devtable',
- }, True, id='matching performer'),
+@pytest.mark.skipif(
+ os.environ.get("TEST_DATABASE_URI", "").find("mysql") >= 0, reason="Flaky on MySQL"
+)
+@pytest.mark.parametrize(
+ "namespace_name, repo_name, performer_name, check_args, expect_results",
+ [
+ pytest.param("devtable", "simple", "devtable", {}, True, id="no filters"),
+ pytest.param(
+ "devtable",
+ "simple",
+ "devtable",
+ {"performer_name": "devtable"},
+ True,
+ id="matching performer",
+ ),
+ pytest.param(
+ "devtable",
+ "simple",
+ "devtable",
+ {"namespace_name": "devtable"},
+ True,
+ id="matching namespace",
+ ),
+ pytest.param(
+ "devtable",
+ "simple",
+ "devtable",
+ {"namespace_name": "devtable", "repository_name": "simple"},
+ True,
+ id="matching repository",
+ ),
+ pytest.param(
+ "devtable",
+ "simple",
+ "devtable",
+ {"performer_name": "public"},
+ False,
+ id="different performer",
+ ),
+ pytest.param(
+ "devtable",
+ "simple",
+ "devtable",
+ {"namespace_name": "public"},
+ False,
+ id="different namespace",
+ ),
+ pytest.param(
+ "devtable",
+ "simple",
+ "devtable",
+ {"namespace_name": "devtable", "repository_name": "complex"},
+ False,
+ id="different repository",
+ ),
+ ],
+)
+def test_logs(
+ namespace_name, repo_name, performer_name, check_args, expect_results, logs_model
+):
+ # Add some logs.
+ kinds = list(LogEntryKind.select())
+ user = model.user.get_user(performer_name)
- pytest.param('devtable', 'simple', 'devtable', {
- 'namespace_name': 'devtable',
- }, True, id='matching namespace'),
+ start_timestamp = datetime.utcnow()
+ timestamp = start_timestamp
- pytest.param('devtable', 'simple', 'devtable', {
- 'namespace_name': 'devtable',
- 'repository_name': 'simple',
- }, True, id='matching repository'),
+ for kind in kinds:
+ for index in range(0, 3):
+ logs_model.log_action(
+ kind.name,
+ namespace_name=namespace_name,
+ repository_name=repo_name,
+ performer=user,
+ ip="1.2.3.4",
+ timestamp=timestamp,
+ )
+ timestamp = timestamp + timedelta(seconds=1)
- pytest.param('devtable', 'simple', 'devtable', {
- 'performer_name': 'public',
- }, False, id='different performer'),
+ found = _lookup_logs(
+ logs_model,
+ start_timestamp,
+ start_timestamp + timedelta(minutes=10),
+ **check_args
+ )
+ if expect_results:
+ assert len(found) == len(kinds) * 3
+ else:
+ assert not found
- pytest.param('devtable', 'simple', 'devtable', {
- 'namespace_name': 'public',
- }, False, id='different namespace'),
-
- pytest.param('devtable', 'simple', 'devtable', {
- 'namespace_name': 'devtable',
- 'repository_name': 'complex',
- }, False, id='different repository'),
-])
-def test_logs(namespace_name, repo_name, performer_name, check_args, expect_results, logs_model):
- # Add some logs.
- kinds = list(LogEntryKind.select())
- user = model.user.get_user(performer_name)
-
- start_timestamp = datetime.utcnow()
- timestamp = start_timestamp
-
- for kind in kinds:
- for index in range(0, 3):
- logs_model.log_action(kind.name, namespace_name=namespace_name, repository_name=repo_name,
- performer=user, ip='1.2.3.4', timestamp=timestamp)
- timestamp = timestamp + timedelta(seconds=1)
-
- found = _lookup_logs(logs_model, start_timestamp, start_timestamp + timedelta(minutes=10),
- **check_args)
- if expect_results:
- assert len(found) == len(kinds) * 3
- else:
- assert not found
-
- aggregated_counts = logs_model.get_aggregated_log_counts(start_timestamp,
- start_timestamp + timedelta(minutes=10),
- **check_args)
- if expect_results:
- assert len(aggregated_counts) == len(kinds)
- for ac in aggregated_counts:
- assert ac.count == 3
- else:
- assert not aggregated_counts
+ aggregated_counts = logs_model.get_aggregated_log_counts(
+ start_timestamp, start_timestamp + timedelta(minutes=10), **check_args
+ )
+ if expect_results:
+ assert len(aggregated_counts) == len(kinds)
+ for ac in aggregated_counts:
+ assert ac.count == 3
+ else:
+ assert not aggregated_counts
-@pytest.mark.parametrize('filter_kinds, expect_results', [
- pytest.param(None, True),
- pytest.param(['push_repo'], True, id='push_repo filter'),
- pytest.param(['pull_repo'], True, id='pull_repo filter'),
- pytest.param(['push_repo', 'pull_repo'], False, id='push and pull filters')
-])
+@pytest.mark.parametrize(
+ "filter_kinds, expect_results",
+ [
+ pytest.param(None, True),
+ pytest.param(["push_repo"], True, id="push_repo filter"),
+ pytest.param(["pull_repo"], True, id="pull_repo filter"),
+ pytest.param(["push_repo", "pull_repo"], False, id="push and pull filters"),
+ ],
+)
def test_lookup_latest_logs(filter_kinds, expect_results, logs_model):
- kind_map = model.log.get_log_entry_kinds()
- if filter_kinds:
- ignore_ids = [kind_map[kind_name] for kind_name in filter_kinds if filter_kinds]
- else:
- ignore_ids = []
+ kind_map = model.log.get_log_entry_kinds()
+ if filter_kinds:
+ ignore_ids = [kind_map[kind_name] for kind_name in filter_kinds if filter_kinds]
+ else:
+ ignore_ids = []
- now = datetime.now()
- namespace_name = 'devtable'
- repo_name = 'simple'
- performer_name = 'devtable'
+ now = datetime.now()
+ namespace_name = "devtable"
+ repo_name = "simple"
+ performer_name = "devtable"
- user = model.user.get_user(performer_name)
- size = 3
+ user = model.user.get_user(performer_name)
+ size = 3
- # Log some push actions
- logs_model.log_action('push_repo', namespace_name=namespace_name, repository_name=repo_name,
- performer=user, ip='0.0.0.0', timestamp=now-timedelta(days=1, seconds=11))
- logs_model.log_action('push_repo', namespace_name=namespace_name, repository_name=repo_name,
- performer=user, ip='0.0.0.0', timestamp=now-timedelta(days=7, seconds=33))
+ # Log some push actions
+ logs_model.log_action(
+ "push_repo",
+ namespace_name=namespace_name,
+ repository_name=repo_name,
+ performer=user,
+ ip="0.0.0.0",
+ timestamp=now - timedelta(days=1, seconds=11),
+ )
+ logs_model.log_action(
+ "push_repo",
+ namespace_name=namespace_name,
+ repository_name=repo_name,
+ performer=user,
+ ip="0.0.0.0",
+ timestamp=now - timedelta(days=7, seconds=33),
+ )
- # Log some pull actions
- logs_model.log_action('pull_repo', namespace_name=namespace_name, repository_name=repo_name,
- performer=user, ip='0.0.0.0', timestamp=now-timedelta(days=0, seconds=3))
- logs_model.log_action('pull_repo', namespace_name=namespace_name, repository_name=repo_name,
- performer=user, ip='0.0.0.0', timestamp=now-timedelta(days=3, seconds=55))
- logs_model.log_action('pull_repo', namespace_name=namespace_name, repository_name=repo_name,
- performer=user, ip='0.0.0.0', timestamp=now-timedelta(days=5, seconds=3))
- logs_model.log_action('pull_repo', namespace_name=namespace_name, repository_name=repo_name,
- performer=user, ip='0.0.0.0', timestamp=now-timedelta(days=11, seconds=11))
+ # Log some pull actions
+ logs_model.log_action(
+ "pull_repo",
+ namespace_name=namespace_name,
+ repository_name=repo_name,
+ performer=user,
+ ip="0.0.0.0",
+ timestamp=now - timedelta(days=0, seconds=3),
+ )
+ logs_model.log_action(
+ "pull_repo",
+ namespace_name=namespace_name,
+ repository_name=repo_name,
+ performer=user,
+ ip="0.0.0.0",
+ timestamp=now - timedelta(days=3, seconds=55),
+ )
+ logs_model.log_action(
+ "pull_repo",
+ namespace_name=namespace_name,
+ repository_name=repo_name,
+ performer=user,
+ ip="0.0.0.0",
+ timestamp=now - timedelta(days=5, seconds=3),
+ )
+ logs_model.log_action(
+ "pull_repo",
+ namespace_name=namespace_name,
+ repository_name=repo_name,
+ performer=user,
+ ip="0.0.0.0",
+ timestamp=now - timedelta(days=11, seconds=11),
+ )
- # Get the latest logs
- latest_logs = logs_model.lookup_latest_logs(performer_name, repo_name, namespace_name,
- filter_kinds=filter_kinds, size=size)
+ # Get the latest logs
+ latest_logs = logs_model.lookup_latest_logs(
+ performer_name, repo_name, namespace_name, filter_kinds=filter_kinds, size=size
+ )
- # Test max lookup size
- assert len(latest_logs) <= size
+ # Test max lookup size
+ assert len(latest_logs) <= size
- # Make sure that the latest logs returned are in decreasing order
- assert all(x >= y for x, y in zip(latest_logs, latest_logs[1:]))
+ # Make sure that the latest logs returned are in decreasing order
+ assert all(x >= y for x, y in zip(latest_logs, latest_logs[1:]))
- if expect_results:
- assert latest_logs
+ if expect_results:
+ assert latest_logs
- # Lookup all logs filtered by kinds and sort them in reverse chronological order
- all_logs = _lookup_logs(logs_model, now - timedelta(days=30), now + timedelta(days=30),
- filter_kinds=filter_kinds, namespace_name=namespace_name,
- repository_name=repo_name)
- all_logs = sorted(all_logs, key=lambda l: l.datetime, reverse=True)
+ # Lookup all logs filtered by kinds and sort them in reverse chronological order
+ all_logs = _lookup_logs(
+ logs_model,
+ now - timedelta(days=30),
+ now + timedelta(days=30),
+ filter_kinds=filter_kinds,
+ namespace_name=namespace_name,
+ repository_name=repo_name,
+ )
+ all_logs = sorted(all_logs, key=lambda l: l.datetime, reverse=True)
- # Check that querying all logs does not return the filtered kinds
- assert all([log.kind_id not in ignore_ids for log in all_logs])
+ # Check that querying all logs does not return the filtered kinds
+ assert all([log.kind_id not in ignore_ids for log in all_logs])
- # Check that the latest logs contains only th most recent ones
- assert latest_logs == all_logs[:len(latest_logs)]
+ # Check that the latest logs contain only the most recent ones
+ assert latest_logs == all_logs[: len(latest_logs)]
def test_count_repository_actions(logs_model):
- # Log some actions.
- logs_model.log_action('push_repo', namespace_name='devtable', repository_name='simple',
- ip='1.2.3.4')
- logs_model.log_action('pull_repo', namespace_name='devtable', repository_name='simple',
- ip='1.2.3.4')
- logs_model.log_action('pull_repo', namespace_name='devtable', repository_name='simple',
- ip='1.2.3.4')
-
- # Log some actions to a different repo.
- logs_model.log_action('pull_repo', namespace_name='devtable', repository_name='complex',
- ip='1.2.3.4')
- logs_model.log_action('pull_repo', namespace_name='devtable', repository_name='complex',
- ip='1.2.3.4')
-
- # Count the actions.
- day = date.today()
- simple_repo = model.repository.get_repository('devtable', 'simple')
-
- count = logs_model.count_repository_actions(simple_repo, day)
- assert count == 3
-
- complex_repo = model.repository.get_repository('devtable', 'complex')
- count = logs_model.count_repository_actions(complex_repo, day)
- assert count == 2
-
- # Try counting actions for a few days in the future to ensure it doesn't raise an error.
- count = logs_model.count_repository_actions(simple_repo, day + timedelta(days=5))
- assert count == 0
-
-
-def test_yield_log_rotation_context(logs_model):
- cutoff_date = datetime.now()
- min_logs_per_rotation = 3
-
- # Log some actions to be archived
- # One day
- logs_model.log_action('push_repo', namespace_name='devtable', repository_name='simple1',
- ip='1.2.3.4', timestamp=cutoff_date-timedelta(days=1, seconds=1))
- logs_model.log_action('pull_repo', namespace_name='devtable', repository_name='simple2',
- ip='5.6.7.8', timestamp=cutoff_date-timedelta(days=1, seconds=2))
- logs_model.log_action('pull_repo', namespace_name='devtable', repository_name='simple3',
- ip='9.10.11.12', timestamp=cutoff_date-timedelta(days=1, seconds=3))
- logs_model.log_action('pull_repo', namespace_name='devtable', repository_name='simple4',
- ip='0.0.0.0', timestamp=cutoff_date-timedelta(days=1, seconds=4))
- # Another day
- logs_model.log_action('pull_repo', namespace_name='devtable', repository_name='simple5',
- ip='1.1.1.1', timestamp=cutoff_date-timedelta(days=2, seconds=1))
- logs_model.log_action('pull_repo', namespace_name='devtable', repository_name='simple5',
- ip='1.1.1.1', timestamp=cutoff_date-timedelta(days=2, seconds=2))
- logs_model.log_action('pull_repo', namespace_name='devtable', repository_name='simple5',
- ip='1.1.1.1', timestamp=cutoff_date-timedelta(days=2, seconds=3))
-
- found = _lookup_logs(logs_model, cutoff_date - timedelta(days=3), cutoff_date + timedelta(days=1))
- assert found is not None and len(found) == 7
-
- # Iterate the logs using the log rotation contexts
- all_logs = []
- for log_rotation_context in logs_model.yield_log_rotation_context(cutoff_date,
- min_logs_per_rotation):
- with log_rotation_context as context:
- for logs, _ in context.yield_logs_batch():
- all_logs.extend(logs)
-
- assert len(all_logs) == 7
- found = _lookup_logs(logs_model, cutoff_date - timedelta(days=3), cutoff_date + timedelta(days=1))
- assert not found
-
- # Make sure all datetimes are monotonically increasing (by datetime) after sorting the lookup
- # to make sure no duplicates were returned
- all_logs.sort(key=lambda d: d.datetime)
- assert all(x.datetime < y.datetime for x, y in zip(all_logs, all_logs[1:]))
-
-
-def test_count_repository_actions_with_wildcard_disabled(initialized_db):
- with fake_elasticsearch(allow_wildcard=False):
- logs_model = es_model()
-
# Log some actions.
- logs_model.log_action('push_repo', namespace_name='devtable', repository_name='simple',
- ip='1.2.3.4')
-
- logs_model.log_action('pull_repo', namespace_name='devtable', repository_name='simple',
- ip='1.2.3.4')
- logs_model.log_action('pull_repo', namespace_name='devtable', repository_name='simple',
- ip='1.2.3.4')
+ logs_model.log_action(
+ "push_repo", namespace_name="devtable", repository_name="simple", ip="1.2.3.4"
+ )
+ logs_model.log_action(
+ "pull_repo", namespace_name="devtable", repository_name="simple", ip="1.2.3.4"
+ )
+ logs_model.log_action(
+ "pull_repo", namespace_name="devtable", repository_name="simple", ip="1.2.3.4"
+ )
# Log some actions to a different repo.
- logs_model.log_action('pull_repo', namespace_name='devtable', repository_name='complex',
- ip='1.2.3.4')
- logs_model.log_action('pull_repo', namespace_name='devtable', repository_name='complex',
- ip='1.2.3.4')
+ logs_model.log_action(
+ "pull_repo", namespace_name="devtable", repository_name="complex", ip="1.2.3.4"
+ )
+ logs_model.log_action(
+ "pull_repo", namespace_name="devtable", repository_name="complex", ip="1.2.3.4"
+ )
# Count the actions.
day = date.today()
- simple_repo = model.repository.get_repository('devtable', 'simple')
+ simple_repo = model.repository.get_repository("devtable", "simple")
count = logs_model.count_repository_actions(simple_repo, day)
assert count == 3
- complex_repo = model.repository.get_repository('devtable', 'complex')
+ complex_repo = model.repository.get_repository("devtable", "complex")
count = logs_model.count_repository_actions(complex_repo, day)
assert count == 2
@@ -297,177 +314,384 @@ def test_count_repository_actions_with_wildcard_disabled(initialized_db):
assert count == 0
-@pytest.mark.skipif(os.environ.get('TEST_DATABASE_URI', '').find('mysql') >= 0,
- reason='Flaky on MySQL')
+def test_yield_log_rotation_context(logs_model):
+ cutoff_date = datetime.now()
+ min_logs_per_rotation = 3
+
+ # Log some actions to be archived
+ # One day
+ logs_model.log_action(
+ "push_repo",
+ namespace_name="devtable",
+ repository_name="simple1",
+ ip="1.2.3.4",
+ timestamp=cutoff_date - timedelta(days=1, seconds=1),
+ )
+ logs_model.log_action(
+ "pull_repo",
+ namespace_name="devtable",
+ repository_name="simple2",
+ ip="5.6.7.8",
+ timestamp=cutoff_date - timedelta(days=1, seconds=2),
+ )
+ logs_model.log_action(
+ "pull_repo",
+ namespace_name="devtable",
+ repository_name="simple3",
+ ip="9.10.11.12",
+ timestamp=cutoff_date - timedelta(days=1, seconds=3),
+ )
+ logs_model.log_action(
+ "pull_repo",
+ namespace_name="devtable",
+ repository_name="simple4",
+ ip="0.0.0.0",
+ timestamp=cutoff_date - timedelta(days=1, seconds=4),
+ )
+ # Another day
+ logs_model.log_action(
+ "pull_repo",
+ namespace_name="devtable",
+ repository_name="simple5",
+ ip="1.1.1.1",
+ timestamp=cutoff_date - timedelta(days=2, seconds=1),
+ )
+ logs_model.log_action(
+ "pull_repo",
+ namespace_name="devtable",
+ repository_name="simple5",
+ ip="1.1.1.1",
+ timestamp=cutoff_date - timedelta(days=2, seconds=2),
+ )
+ logs_model.log_action(
+ "pull_repo",
+ namespace_name="devtable",
+ repository_name="simple5",
+ ip="1.1.1.1",
+ timestamp=cutoff_date - timedelta(days=2, seconds=3),
+ )
+
+ found = _lookup_logs(
+ logs_model, cutoff_date - timedelta(days=3), cutoff_date + timedelta(days=1)
+ )
+ assert found is not None and len(found) == 7
+
+ # Iterate the logs using the log rotation contexts
+ all_logs = []
+ for log_rotation_context in logs_model.yield_log_rotation_context(
+ cutoff_date, min_logs_per_rotation
+ ):
+ with log_rotation_context as context:
+ for logs, _ in context.yield_logs_batch():
+ all_logs.extend(logs)
+
+ assert len(all_logs) == 7
+ found = _lookup_logs(
+ logs_model, cutoff_date - timedelta(days=3), cutoff_date + timedelta(days=1)
+ )
+ assert not found
+
+ # After sorting by datetime, entries should be strictly increasing, which also
+ # verifies that the rotation contexts returned no duplicate log entries
+ all_logs.sort(key=lambda d: d.datetime)
+ assert all(x.datetime < y.datetime for x, y in zip(all_logs, all_logs[1:]))
+
+
+def test_count_repository_actions_with_wildcard_disabled(initialized_db):
+ with fake_elasticsearch(allow_wildcard=False):
+ logs_model = es_model()
+
+ # Log some actions.
+ logs_model.log_action(
+ "push_repo",
+ namespace_name="devtable",
+ repository_name="simple",
+ ip="1.2.3.4",
+ )
+
+ logs_model.log_action(
+ "pull_repo",
+ namespace_name="devtable",
+ repository_name="simple",
+ ip="1.2.3.4",
+ )
+ logs_model.log_action(
+ "pull_repo",
+ namespace_name="devtable",
+ repository_name="simple",
+ ip="1.2.3.4",
+ )
+
+ # Log some actions to a different repo.
+ logs_model.log_action(
+ "pull_repo",
+ namespace_name="devtable",
+ repository_name="complex",
+ ip="1.2.3.4",
+ )
+ logs_model.log_action(
+ "pull_repo",
+ namespace_name="devtable",
+ repository_name="complex",
+ ip="1.2.3.4",
+ )
+
+ # Count the actions.
+ day = date.today()
+ simple_repo = model.repository.get_repository("devtable", "simple")
+
+ count = logs_model.count_repository_actions(simple_repo, day)
+ assert count == 3
+
+ complex_repo = model.repository.get_repository("devtable", "complex")
+ count = logs_model.count_repository_actions(complex_repo, day)
+ assert count == 2
+
+ # Try counting actions for a few days in the future to ensure it doesn't raise an error.
+ count = logs_model.count_repository_actions(
+ simple_repo, day + timedelta(days=5)
+ )
+ assert count == 0
+
+
+@pytest.mark.skipif(
+ os.environ.get("TEST_DATABASE_URI", "").find("mysql") >= 0, reason="Flaky on MySQL"
+)
def test_yield_logs_for_export(logs_model):
- # Add some logs.
- kinds = list(LogEntryKind.select())
- user = model.user.get_user('devtable')
+ # Add some logs.
+ kinds = list(LogEntryKind.select())
+ user = model.user.get_user("devtable")
- start_timestamp = datetime.utcnow()
- timestamp = start_timestamp
+ start_timestamp = datetime.utcnow()
+ timestamp = start_timestamp
- for kind in kinds:
- for index in range(0, 10):
- logs_model.log_action(kind.name, namespace_name='devtable', repository_name='simple',
- performer=user, ip='1.2.3.4', timestamp=timestamp)
- timestamp = timestamp + timedelta(seconds=1)
+ for kind in kinds:
+ for index in range(0, 10):
+ logs_model.log_action(
+ kind.name,
+ namespace_name="devtable",
+ repository_name="simple",
+ performer=user,
+ ip="1.2.3.4",
+ timestamp=timestamp,
+ )
+ timestamp = timestamp + timedelta(seconds=1)
- # Yield the logs.
- simple_repo = model.repository.get_repository('devtable', 'simple')
- logs_found = []
- for logs in logs_model.yield_logs_for_export(start_timestamp, timestamp + timedelta(minutes=10),
- repository_id=simple_repo.id):
- logs_found.extend(logs)
+ # Yield the logs.
+ simple_repo = model.repository.get_repository("devtable", "simple")
+ logs_found = []
+ for logs in logs_model.yield_logs_for_export(
+ start_timestamp, timestamp + timedelta(minutes=10), repository_id=simple_repo.id
+ ):
+ logs_found.extend(logs)
- # Ensure we found all added logs.
- assert len(logs_found) == len(kinds) * 10
+ # Ensure we found all added logs.
+ assert len(logs_found) == len(kinds) * 10
def test_yield_logs_for_export_timeout(logs_model):
- # Add some logs.
- kinds = list(LogEntryKind.select())
- user = model.user.get_user('devtable')
+ # Add some logs.
+ kinds = list(LogEntryKind.select())
+ user = model.user.get_user("devtable")
- start_timestamp = datetime.utcnow()
- timestamp = start_timestamp
+ start_timestamp = datetime.utcnow()
+ timestamp = start_timestamp
- for kind in kinds:
- for _ in range(0, 2):
- logs_model.log_action(kind.name, namespace_name='devtable', repository_name='simple',
- performer=user, ip='1.2.3.4', timestamp=timestamp)
- timestamp = timestamp + timedelta(seconds=1)
+ for kind in kinds:
+ for _ in range(0, 2):
+ logs_model.log_action(
+ kind.name,
+ namespace_name="devtable",
+ repository_name="simple",
+ performer=user,
+ ip="1.2.3.4",
+ timestamp=timestamp,
+ )
+ timestamp = timestamp + timedelta(seconds=1)
- # Yield the logs. Since we set the timeout to nothing, it should immediately fail.
- simple_repo = model.repository.get_repository('devtable', 'simple')
- with pytest.raises(LogsIterationTimeout):
- list(logs_model.yield_logs_for_export(start_timestamp, timestamp + timedelta(minutes=1),
- repository_id=simple_repo.id,
- max_query_time=timedelta(seconds=0)))
+ # Yield the logs. Since the max query time is set to zero, the export should fail immediately.
+ simple_repo = model.repository.get_repository("devtable", "simple")
+ with pytest.raises(LogsIterationTimeout):
+ list(
+ logs_model.yield_logs_for_export(
+ start_timestamp,
+ timestamp + timedelta(minutes=1),
+ repository_id=simple_repo.id,
+ max_query_time=timedelta(seconds=0),
+ )
+ )
def test_disabled_namespace(clear_db_logs):
- logs_model = TableLogsModel(lambda kind, namespace, is_free: namespace == 'devtable')
+ logs_model = TableLogsModel(
+ lambda kind, namespace, is_free: namespace == "devtable"
+ )
- # Log some actions.
- logs_model.log_action('push_repo', namespace_name='devtable', repository_name='simple',
- ip='1.2.3.4')
+ # Log some actions.
+ logs_model.log_action(
+ "push_repo", namespace_name="devtable", repository_name="simple", ip="1.2.3.4"
+ )
- logs_model.log_action('pull_repo', namespace_name='devtable', repository_name='simple',
- ip='1.2.3.4')
- logs_model.log_action('pull_repo', namespace_name='devtable', repository_name='simple',
- ip='1.2.3.4')
+ logs_model.log_action(
+ "pull_repo", namespace_name="devtable", repository_name="simple", ip="1.2.3.4"
+ )
+ logs_model.log_action(
+ "pull_repo", namespace_name="devtable", repository_name="simple", ip="1.2.3.4"
+ )
- # Log some actions to a different namespace.
- logs_model.log_action('push_repo', namespace_name='buynlarge', repository_name='orgrepo',
- ip='1.2.3.4')
+ # Log some actions to a different namespace.
+ logs_model.log_action(
+ "push_repo", namespace_name="buynlarge", repository_name="orgrepo", ip="1.2.3.4"
+ )
- logs_model.log_action('pull_repo', namespace_name='buynlarge', repository_name='orgrepo',
- ip='1.2.3.4')
- logs_model.log_action('pull_repo', namespace_name='buynlarge', repository_name='orgrepo',
- ip='1.2.3.4')
+ logs_model.log_action(
+ "pull_repo", namespace_name="buynlarge", repository_name="orgrepo", ip="1.2.3.4"
+ )
+ logs_model.log_action(
+ "pull_repo", namespace_name="buynlarge", repository_name="orgrepo", ip="1.2.3.4"
+ )
- # Count the actions.
- day = datetime.today() - timedelta(minutes=60)
- simple_repo = model.repository.get_repository('devtable', 'simple')
- count = logs_model.count_repository_actions(simple_repo, day)
- assert count == 0
+ # Count the actions.
+ day = datetime.today() - timedelta(minutes=60)
+ simple_repo = model.repository.get_repository("devtable", "simple")
+ count = logs_model.count_repository_actions(simple_repo, day)
+ assert count == 0
- org_repo = model.repository.get_repository('buynlarge', 'orgrepo')
- count = logs_model.count_repository_actions(org_repo, day)
- assert count == 3
+ org_repo = model.repository.get_repository("buynlarge", "orgrepo")
+ count = logs_model.count_repository_actions(org_repo, day)
+ assert count == 3
-@pytest.mark.parametrize('aggregated_log_counts1, aggregated_log_counts2, expected_result', [
- pytest.param(
+@pytest.mark.parametrize(
+ "aggregated_log_counts1, aggregated_log_counts2, expected_result",
[
- AggregatedLogCount(1, 3, datetime(2019, 6, 6, 0, 0)), # 1
- AggregatedLogCount(1, 3, datetime(2019, 6, 7, 0, 0)), # 2
+ pytest.param(
+ [
+ AggregatedLogCount(1, 3, datetime(2019, 6, 6, 0, 0)), # 1
+ AggregatedLogCount(1, 3, datetime(2019, 6, 7, 0, 0)), # 2
+ ],
+ [
+ AggregatedLogCount(1, 5, datetime(2019, 6, 6, 0, 0)), # 1
+ AggregatedLogCount(1, 7, datetime(2019, 6, 7, 0, 0)), # 2
+ AggregatedLogCount(3, 3, datetime(2019, 6, 1, 0, 0)), # 3
+ ],
+ [
+ AggregatedLogCount(1, 8, datetime(2019, 6, 6, 0, 0)), # 1
+ AggregatedLogCount(1, 10, datetime(2019, 6, 7, 0, 0)), # 2
+ AggregatedLogCount(3, 3, datetime(2019, 6, 1, 0, 0)), # 3
+ ],
+ ),
+ pytest.param(
+ [AggregatedLogCount(1, 3, datetime(2019, 6, 6, 0, 0))], # 1
+ [AggregatedLogCount(1, 7, datetime(2019, 6, 7, 0, 0))], # 2
+ [
+ AggregatedLogCount(1, 3, datetime(2019, 6, 6, 0, 0)), # 1
+ AggregatedLogCount(1, 7, datetime(2019, 6, 7, 0, 0)), # 2
+ ],
+ ),
+ pytest.param(
+ [],
+ [AggregatedLogCount(1, 3, datetime(2019, 6, 6, 0, 0))],
+ [AggregatedLogCount(1, 3, datetime(2019, 6, 6, 0, 0))],
+ ),
],
+)
+def test_merge_aggregated_log_counts(
+ aggregated_log_counts1, aggregated_log_counts2, expected_result
+):
+ assert sorted(
+ _merge_aggregated_log_counts(aggregated_log_counts1, aggregated_log_counts2)
+ ) == sorted(expected_result)
+
+
+@pytest.mark.parametrize(
+ "dt1, dt2, expected_result",
[
- AggregatedLogCount(1, 5, datetime(2019, 6, 6, 0, 0)), # 1
- AggregatedLogCount(1, 7, datetime(2019, 6, 7, 0, 0)), # 2
- AggregatedLogCount(3, 3, datetime(2019, 6, 1, 0, 0)), # 3
+ # Valid dates
+ pytest.param(date(2019, 6, 17), date(2019, 6, 18), True),
+ # Invalid dates
+ pytest.param(date(2019, 6, 17), date(2019, 6, 17), False),
+ pytest.param(date(2019, 6, 17), date(2019, 6, 19), False),
+ pytest.param(date(2019, 6, 18), date(2019, 6, 17), False),
+ # Valid datetimes
+ pytest.param(datetime(2019, 6, 17, 0, 1), datetime(2019, 6, 17, 0, 2), True),
+ # Invalid datetimes
+ pytest.param(datetime(2019, 6, 17, 0, 2), datetime(2019, 6, 17, 0, 1), False),
+ pytest.param(
+ datetime(2019, 6, 17, 11),
+ datetime(2019, 6, 17, 11) + timedelta(hours=14),
+ False,
+ ),
],
- [
- AggregatedLogCount(1, 8, datetime(2019, 6, 6, 0, 0)), # 1
- AggregatedLogCount(1, 10, datetime(2019, 6, 7, 0, 0)), # 2
- AggregatedLogCount(3, 3, datetime(2019, 6, 1, 0, 0)) # 3
- ]
- ),
- pytest.param(
- [
- AggregatedLogCount(1, 3, datetime(2019, 6, 6, 0, 0)), # 1
- ],
- [
- AggregatedLogCount(1, 7, datetime(2019, 6, 7, 0, 0)), # 2
- ],
- [
- AggregatedLogCount(1, 3, datetime(2019, 6, 6, 0, 0)), # 1
- AggregatedLogCount(1, 7, datetime(2019, 6, 7, 0, 0)), # 2
- ]
- ),
- pytest.param(
- [],
- [AggregatedLogCount(1, 3, datetime(2019, 6, 6, 0, 0))],
- [AggregatedLogCount(1, 3, datetime(2019, 6, 6, 0, 0))]
- ),
-])
-def test_merge_aggregated_log_counts(aggregated_log_counts1, aggregated_log_counts2, expected_result):
- assert (sorted(_merge_aggregated_log_counts(aggregated_log_counts1, aggregated_log_counts2)) ==
- sorted(expected_result))
-
-
-@pytest.mark.parametrize('dt1, dt2, expected_result', [
- # Valid dates
- pytest.param(date(2019, 6, 17), date(2019, 6, 18), True),
-
- # Invalid dates
- pytest.param(date(2019, 6, 17), date(2019, 6, 17), False),
- pytest.param(date(2019, 6, 17), date(2019, 6, 19), False),
- pytest.param(date(2019, 6, 18), date(2019, 6, 17), False),
-
- # Valid datetimes
- pytest.param(datetime(2019, 6, 17, 0, 1), datetime(2019, 6, 17, 0, 2), True),
-
- # Invalid datetimes
- pytest.param(datetime(2019, 6, 17, 0, 2), datetime(2019, 6, 17, 0, 1), False),
- pytest.param(datetime(2019, 6, 17, 11), datetime(2019, 6, 17, 11) + timedelta(hours=14), False),
-])
+)
def test_date_range_in_single_index(dt1, dt2, expected_result):
- assert _date_range_in_single_index(dt1, dt2) == expected_result
+ assert _date_range_in_single_index(dt1, dt2) == expected_result
def test_pagination(logs_model, mock_page_size):
- """
+ """
Make sure that pagination does not stop if searching through multiple indices by day,
and the current log count matches the page size while there are still indices to be searched.
"""
- day1 = datetime.now()
- day2 = day1 + timedelta(days=1)
- day3 = day2 + timedelta(days=1)
+ day1 = datetime.now()
+ day2 = day1 + timedelta(days=1)
+ day3 = day2 + timedelta(days=1)
- # Log some actions in day indices
- # One day
- logs_model.log_action('push_repo', namespace_name='devtable', repository_name='simple1',
- ip='1.2.3.4', timestamp=day1)
- logs_model.log_action('pull_repo', namespace_name='devtable', repository_name='simple1',
- ip='5.6.7.8', timestamp=day1)
+ # Log some actions across the per-day indices
+ # One day
+ logs_model.log_action(
+ "push_repo",
+ namespace_name="devtable",
+ repository_name="simple1",
+ ip="1.2.3.4",
+ timestamp=day1,
+ )
+ logs_model.log_action(
+ "pull_repo",
+ namespace_name="devtable",
+ repository_name="simple1",
+ ip="5.6.7.8",
+ timestamp=day1,
+ )
- found = _lookup_logs(logs_model, day1-timedelta(seconds=1), day3+timedelta(seconds=1))
- assert len(found) == mock_page_size
+ found = _lookup_logs(
+ logs_model, day1 - timedelta(seconds=1), day3 + timedelta(seconds=1)
+ )
+ assert len(found) == mock_page_size
- # Another day
- logs_model.log_action('pull_repo', namespace_name='devtable', repository_name='simple2',
- ip='1.1.1.1', timestamp=day2)
- logs_model.log_action('pull_repo', namespace_name='devtable', repository_name='simple2',
- ip='0.0.0.0', timestamp=day2)
+ # Another day
+ logs_model.log_action(
+ "pull_repo",
+ namespace_name="devtable",
+ repository_name="simple2",
+ ip="1.1.1.1",
+ timestamp=day2,
+ )
+ logs_model.log_action(
+ "pull_repo",
+ namespace_name="devtable",
+ repository_name="simple2",
+ ip="0.0.0.0",
+ timestamp=day2,
+ )
- # Yet another day
- logs_model.log_action('pull_repo', namespace_name='devtable', repository_name='simple2',
- ip='1.1.1.1', timestamp=day3)
- logs_model.log_action('pull_repo', namespace_name='devtable', repository_name='simple2',
- ip='0.0.0.0', timestamp=day3)
+ # Yet another day
+ logs_model.log_action(
+ "pull_repo",
+ namespace_name="devtable",
+ repository_name="simple2",
+ ip="1.1.1.1",
+ timestamp=day3,
+ )
+ logs_model.log_action(
+ "pull_repo",
+ namespace_name="devtable",
+ repository_name="simple2",
+ ip="0.0.0.0",
+ timestamp=day3,
+ )
- found = _lookup_logs(logs_model, day1-timedelta(seconds=1), day3+timedelta(seconds=1))
- assert len(found) == 6
+ found = _lookup_logs(
+ logs_model, day1 - timedelta(seconds=1), day3 + timedelta(seconds=1)
+ )
+ assert len(found) == 6
diff --git a/data/logs_model/test/test_logs_producer.py b/data/logs_model/test/test_logs_producer.py
index 382684244..5d6759edb 100644
--- a/data/logs_model/test/test_logs_producer.py
+++ b/data/logs_model/test/test_logs_producer.py
@@ -7,71 +7,105 @@ import botocore
from data.logs_model import configure
-from test_elasticsearch import app_config, logs_model_config, logs_model, mock_elasticsearch, mock_db_model
+from test_elasticsearch import (
+ app_config,
+ logs_model_config,
+ logs_model,
+ mock_elasticsearch,
+ mock_db_model,
+)
from mock_elasticsearch import *
logger = logging.getLogger(__name__)
-FAKE_KAFKA_BROKERS = ['fake_server1', 'fake_server2']
-FAKE_KAFKA_TOPIC = 'sometopic'
+FAKE_KAFKA_BROKERS = ["fake_server1", "fake_server2"]
+FAKE_KAFKA_TOPIC = "sometopic"
FAKE_MAX_BLOCK_SECONDS = 1
+
@pytest.fixture()
def kafka_logs_producer_config(app_config):
- producer_config = {}
- producer_config.update(app_config)
-
- kafka_config = {
- 'bootstrap_servers': FAKE_KAFKA_BROKERS,
- 'topic': FAKE_KAFKA_TOPIC,
- 'max_block_seconds': FAKE_MAX_BLOCK_SECONDS
- }
+ producer_config = {}
+ producer_config.update(app_config)
- producer_config['LOGS_MODEL_CONFIG']['producer'] = 'kafka'
- producer_config['LOGS_MODEL_CONFIG']['kafka_config'] = kafka_config
- return producer_config
+ kafka_config = {
+ "bootstrap_servers": FAKE_KAFKA_BROKERS,
+ "topic": FAKE_KAFKA_TOPIC,
+ "max_block_seconds": FAKE_MAX_BLOCK_SECONDS,
+ }
+
+ producer_config["LOGS_MODEL_CONFIG"]["producer"] = "kafka"
+ producer_config["LOGS_MODEL_CONFIG"]["kafka_config"] = kafka_config
+ return producer_config
@pytest.fixture()
def kinesis_logs_producer_config(app_config):
- producer_config = {}
- producer_config.update(app_config)
-
- kinesis_stream_config = {
- 'stream_name': 'test-stream',
- 'aws_region': 'fake_region',
- 'aws_access_key': 'some_key',
- 'aws_secret_key': 'some_secret'
- }
+ producer_config = {}
+ producer_config.update(app_config)
- producer_config['LOGS_MODEL_CONFIG']['producer'] = 'kinesis_stream'
- producer_config['LOGS_MODEL_CONFIG']['kinesis_stream_config'] = kinesis_stream_config
- return producer_config
+ kinesis_stream_config = {
+ "stream_name": "test-stream",
+ "aws_region": "fake_region",
+ "aws_access_key": "some_key",
+ "aws_secret_key": "some_secret",
+ }
+
+ producer_config["LOGS_MODEL_CONFIG"]["producer"] = "kinesis_stream"
+ producer_config["LOGS_MODEL_CONFIG"][
+ "kinesis_stream_config"
+ ] = kinesis_stream_config
+ return producer_config
-def test_kafka_logs_producers(logs_model, mock_elasticsearch, mock_db_model, kafka_logs_producer_config):
- mock_elasticsearch.template = Mock(return_value=DEFAULT_TEMPLATE_RESPONSE)
+def test_kafka_logs_producers(
+ logs_model, mock_elasticsearch, mock_db_model, kafka_logs_producer_config
+):
+ mock_elasticsearch.template = Mock(return_value=DEFAULT_TEMPLATE_RESPONSE)
- producer_config = kafka_logs_producer_config
- with patch('kafka.client_async.KafkaClient.check_version'), patch('kafka.KafkaProducer.send') as mock_send:
- configure(producer_config)
- logs_model.log_action('pull_repo', 'user1', Mock(id=1), '192.168.1.1', {'key': 'value'},
- None, 'repo1', parse("2019-01-01T03:30"))
-
- mock_send.assert_called_once()
+ producer_config = kafka_logs_producer_config
+ with patch("kafka.client_async.KafkaClient.check_version"), patch(
+ "kafka.KafkaProducer.send"
+ ) as mock_send:
+ configure(producer_config)
+ logs_model.log_action(
+ "pull_repo",
+ "user1",
+ Mock(id=1),
+ "192.168.1.1",
+ {"key": "value"},
+ None,
+ "repo1",
+ parse("2019-01-01T03:30"),
+ )
+
+ mock_send.assert_called_once()
-def test_kinesis_logs_producers(logs_model, mock_elasticsearch, mock_db_model, kinesis_logs_producer_config):
- mock_elasticsearch.template = Mock(return_value=DEFAULT_TEMPLATE_RESPONSE)
+def test_kinesis_logs_producers(
+ logs_model, mock_elasticsearch, mock_db_model, kinesis_logs_producer_config
+):
+ mock_elasticsearch.template = Mock(return_value=DEFAULT_TEMPLATE_RESPONSE)
- producer_config = kinesis_logs_producer_config
- with patch('botocore.endpoint.EndpointCreator.create_endpoint'), \
- patch('botocore.client.BaseClient._make_api_call') as mock_send:
- configure(producer_config)
- logs_model.log_action('pull_repo', 'user1', Mock(id=1), '192.168.1.1', {'key': 'value'},
- None, 'repo1', parse("2019-01-01T03:30"))
+ producer_config = kinesis_logs_producer_config
+ with patch("botocore.endpoint.EndpointCreator.create_endpoint"), patch(
+ "botocore.client.BaseClient._make_api_call"
+ ) as mock_send:
+ configure(producer_config)
+ logs_model.log_action(
+ "pull_repo",
+ "user1",
+ Mock(id=1),
+ "192.168.1.1",
+ {"key": "value"},
+ None,
+ "repo1",
+ parse("2019-01-01T03:30"),
+ )
- # Check that a PutRecord api call is made.
- # NOTE: The second arg of _make_api_call uses a randomized PartitionKey
- mock_send.assert_called_once_with(u'PutRecord', mock_send.call_args_list[0][0][1])
+ # Check that a PutRecord API call is made.
+ # NOTE: The second arg of _make_api_call uses a randomized PartitionKey
+ mock_send.assert_called_once_with(
+ u"PutRecord", mock_send.call_args_list[0][0][1]
+ )
diff --git a/data/migrations/progress.py b/data/migrations/progress.py
index 91278beea..2f849253a 100644
--- a/data/migrations/progress.py
+++ b/data/migrations/progress.py
@@ -9,93 +9,97 @@ from util.abchelpers import nooper
@add_metaclass(ABCMeta)
class ProgressReporter(object):
- """ Implements an interface for reporting progress with the migrations.
+ """ Implements an interface for reporting progress with the migrations.
"""
- @abstractmethod
- def report_version_complete(self, success):
- """ Called when an entire migration is complete. """
- @abstractmethod
- def report_step_progress(self):
- """ Called when a single step in the migration has been completed. """
+ @abstractmethod
+ def report_version_complete(self, success):
+ """ Called when an entire migration is complete. """
+
+ @abstractmethod
+ def report_step_progress(self):
+ """ Called when a single step in the migration has been completed. """
@nooper
class NullReporter(ProgressReporter):
- """ No-op version of the progress reporter, designed for use when no progress
+ """ No-op version of the progress reporter, designed for use when no progress
reporting endpoint is provided. """
class PrometheusReporter(ProgressReporter):
- def __init__(self, prom_pushgateway_addr, prom_job, labels, total_steps_num=None):
- self._total_steps_num = total_steps_num
- self._completed_steps = 0.0
+ def __init__(self, prom_pushgateway_addr, prom_job, labels, total_steps_num=None):
+ self._total_steps_num = total_steps_num
+ self._completed_steps = 0.0
- registry = CollectorRegistry()
+ registry = CollectorRegistry()
- self._migration_completion_percent = Gauge(
- 'migration_completion_percent',
- 'Estimate of the completion percentage of the job',
- registry=registry,
- )
- self._migration_complete_total = Counter(
- 'migration_complete_total',
- 'Binary value of whether or not the job is complete',
- registry=registry,
- )
- self._migration_failed_total = Counter(
- 'migration_failed_total',
- 'Binary value of whether or not the job has failed',
- registry=registry,
- )
- self._migration_items_completed_total = Counter(
- 'migration_items_completed_total',
- 'Number of items this migration has completed',
- registry=registry,
- )
+ self._migration_completion_percent = Gauge(
+ "migration_completion_percent",
+ "Estimate of the completion percentage of the job",
+ registry=registry,
+ )
+ self._migration_complete_total = Counter(
+ "migration_complete_total",
+ "Binary value of whether or not the job is complete",
+ registry=registry,
+ )
+ self._migration_failed_total = Counter(
+ "migration_failed_total",
+ "Binary value of whether or not the job has failed",
+ registry=registry,
+ )
+ self._migration_items_completed_total = Counter(
+ "migration_items_completed_total",
+ "Number of items this migration has completed",
+ registry=registry,
+ )
- self._push = partial(push_to_gateway,
- prom_pushgateway_addr,
- job=prom_job,
- registry=registry,
- grouping_key=labels,
- )
+ self._push = partial(
+ push_to_gateway,
+ prom_pushgateway_addr,
+ job=prom_job,
+ registry=registry,
+ grouping_key=labels,
+ )
- def report_version_complete(self, success=True):
- if success:
- self._migration_complete_total.inc()
- else:
- self._migration_failed_total.inc()
- self._migration_completion_percent.set(1.0)
+ def report_version_complete(self, success=True):
+ if success:
+ self._migration_complete_total.inc()
+ else:
+ self._migration_failed_total.inc()
+ self._migration_completion_percent.set(1.0)
- self._push()
+ self._push()
- def report_step_progress(self):
- self._migration_items_completed_total.inc()
+ def report_step_progress(self):
+ self._migration_items_completed_total.inc()
- if self._total_steps_num is not None:
- self._completed_steps += 1
- self._migration_completion_percent = self._completed_steps / self._total_steps_num
+ if self._total_steps_num is not None:
+ self._completed_steps += 1
+ self._migration_completion_percent.set(
+ self._completed_steps / self._total_steps_num
+ )
- self._push()
+ self._push()
class ProgressWrapper(object):
- def __init__(self, delegate_module, progress_monitor):
- self._delegate_module = delegate_module
- self._progress_monitor = progress_monitor
+ def __init__(self, delegate_module, progress_monitor):
+ self._delegate_module = delegate_module
+ self._progress_monitor = progress_monitor
- def __getattr__(self, attr_name):
- # Will raise proper attribute error
- maybe_callable = self._delegate_module.__dict__[attr_name]
- if callable(maybe_callable):
- # Build a callable which when executed places the request
- # onto a queue
- @wraps(maybe_callable)
- def wrapped_method(*args, **kwargs):
- result = maybe_callable(*args, **kwargs)
- self._progress_monitor.report_step_progress()
- return result
+ def __getattr__(self, attr_name):
+ # Raises the expected AttributeError if the attribute is missing
+ maybe_callable = self._delegate_module.__dict__[attr_name]
+ if callable(maybe_callable):
+ # Wrap the callable so that each invocation also reports
+ # step progress to the progress monitor
+ @wraps(maybe_callable)
+ def wrapped_method(*args, **kwargs):
+ result = maybe_callable(*args, **kwargs)
+ self._progress_monitor.report_step_progress()
+ return result
- return wrapped_method
- return maybe_callable
+ return wrapped_method
+ return maybe_callable
diff --git a/data/migrations/test/test_db_config.py b/data/migrations/test/test_db_config.py
index 747c5eb73..4f524608d 100644
--- a/data/migrations/test/test_db_config.py
+++ b/data/migrations/test/test_db_config.py
@@ -5,17 +5,21 @@ from data.runmigration import run_alembic_migration
from alembic.script import ScriptDirectory
from test.fixtures import *
-@pytest.mark.parametrize('db_uri, is_valid', [
- ('postgresql://devtable:password@quay-postgres/registry_database', True),
- ('postgresql://devtable:password%25@quay-postgres/registry_database', False),
- ('postgresql://devtable:password%%25@quay-postgres/registry_database', True),
- ('postgresql://devtable@db:password@quay-postgres/registry_database', True),
-])
+
+@pytest.mark.parametrize(
+ "db_uri, is_valid",
+ [
+ ("postgresql://devtable:password@quay-postgres/registry_database", True),
+ ("postgresql://devtable:password%25@quay-postgres/registry_database", False),
+ ("postgresql://devtable:password%%25@quay-postgres/registry_database", True),
+ ("postgresql://devtable@db:password@quay-postgres/registry_database", True),
+ ],
+)
def test_alembic_db_uri(db_uri, is_valid):
- """ Test if the given URI is escaped for string interpolation (Python's configparser). """
- with patch('alembic.script.ScriptDirectory.run_env') as m:
- if is_valid:
- run_alembic_migration(db_uri)
- else:
- with pytest.raises(ValueError):
- run_alembic_migration(db_uri)
+ """ Test if the given URI is escaped for string interpolation (Python's configparser). """
+ with patch("alembic.script.ScriptDirectory.run_env") as m:
+ if is_valid:
+ run_alembic_migration(db_uri)
+ else:
+ with pytest.raises(ValueError):
+ run_alembic_migration(db_uri)
diff --git a/data/migrations/tester.py b/data/migrations/tester.py
index 2643b80e2..34f3065b1 100644
--- a/data/migrations/tester.py
+++ b/data/migrations/tester.py
@@ -13,128 +13,142 @@ from util.abchelpers import nooper
logger = logging.getLogger(__name__)
-def escape_table_name(table_name):
- if op.get_bind().engine.name == 'postgresql':
- # Needed for the `user` table.
- return '"%s"' % table_name
- return table_name
+def escape_table_name(table_name):
+ if op.get_bind().engine.name == "postgresql":
+ # Needed for the `user` table.
+ return '"%s"' % table_name
+
+ return table_name
class DataTypes(object):
- @staticmethod
- def DateTime():
- return datetime.now()
+ @staticmethod
+ def DateTime():
+ return datetime.now()
- @staticmethod
- def Date():
- return datetime.now()
+ @staticmethod
+ def Date():
+ return datetime.now()
- @staticmethod
- def String():
- return 'somestringvalue'
+ @staticmethod
+ def String():
+ return "somestringvalue"
- @staticmethod
- def Token():
- return '%s%s' % ('a' * 60, 'b' * 60)
+ @staticmethod
+ def Token():
+ return "%s%s" % ("a" * 60, "b" * 60)
- @staticmethod
- def UTF8Char():
- return 'some other value'
+ @staticmethod
+ def UTF8Char():
+ return "some other value"
- @staticmethod
- def UUID():
- return str(uuid.uuid4())
+ @staticmethod
+ def UUID():
+ return str(uuid.uuid4())
- @staticmethod
- def JSON():
- return json.dumps(dict(foo='bar', baz='meh'))
+ @staticmethod
+ def JSON():
+ return json.dumps(dict(foo="bar", baz="meh"))
- @staticmethod
- def Boolean():
- if op.get_bind().engine.name == 'postgresql':
- return True
+ @staticmethod
+ def Boolean():
+ if op.get_bind().engine.name == "postgresql":
+ return True
- return 1
+ return 1
- @staticmethod
- def BigInteger():
- return 21474836470
+ @staticmethod
+ def BigInteger():
+ return 21474836470
- @staticmethod
- def Integer():
- return 42
+ @staticmethod
+ def Integer():
+ return 42
- @staticmethod
- def Constant(value):
- def get_value():
- return value
- return get_value
+ @staticmethod
+ def Constant(value):
+ def get_value():
+ return value
- @staticmethod
- def Foreign(table_name):
- def get_index():
- result = op.get_bind().execute("SELECT id FROM %s LIMIT 1" % escape_table_name(table_name))
- try:
- return list(result)[0][0]
- except IndexError:
- raise Exception('Could not find row for table %s' % table_name)
- finally:
- result.close()
+ return get_value
- return get_index
+ @staticmethod
+ def Foreign(table_name):
+ def get_index():
+ result = op.get_bind().execute(
+ "SELECT id FROM %s LIMIT 1" % escape_table_name(table_name)
+ )
+ try:
+ return list(result)[0][0]
+ except IndexError:
+ raise Exception("Could not find row for table %s" % table_name)
+ finally:
+ result.close()
+
+ return get_index
@add_metaclass(ABCMeta)
class MigrationTester(object):
- """ Implements an interface for adding testing capabilities to the
+ """ Implements an interface for adding testing capabilities to the
data model migration system in Alembic.
"""
- TestDataType = DataTypes
- @abstractproperty
- def is_testing(self):
- """ Returns whether we are currently under a migration test. """
+ TestDataType = DataTypes
- @abstractmethod
- def populate_table(self, table_name, fields):
- """ Called to populate a table with the given fields filled in with testing data. """
+ @abstractproperty
+ def is_testing(self):
+ """ Returns whether we are currently under a migration test. """
- @abstractmethod
- def populate_column(self, table_name, col_name, field_type):
- """ Called to populate a column in a table to be filled in with testing data. """
+ @abstractmethod
+ def populate_table(self, table_name, fields):
+ """ Called to populate a table with the given fields filled in with testing data. """
+
+ @abstractmethod
+ def populate_column(self, table_name, col_name, field_type):
+ """ Called to populate a column in a table to be filled in with testing data. """
@nooper
class NoopTester(MigrationTester):
- """ No-op version of the tester, designed for production workloads. """
+ """ No-op version of the tester, designed for production workloads. """
class PopulateTestDataTester(MigrationTester):
- @property
- def is_testing(self):
- return True
+ @property
+ def is_testing(self):
+ return True
- def populate_table(self, table_name, fields):
- columns = {field_name: field_type() for field_name, field_type in fields}
- field_name_vars = [':' + field_name for field_name, _ in fields]
+ def populate_table(self, table_name, fields):
+ columns = {field_name: field_type() for field_name, field_type in fields}
+ field_name_vars = [":" + field_name for field_name, _ in fields]
- if op.get_bind().engine.name == 'postgresql':
- field_names = ["%s" % field_name for field_name, _ in fields]
- else:
- field_names = ["`%s`" % field_name for field_name, _ in fields]
+ if op.get_bind().engine.name == "postgresql":
+ field_names = ["%s" % field_name for field_name, _ in fields]
+ else:
+ field_names = ["`%s`" % field_name for field_name, _ in fields]
- table_name = escape_table_name(table_name)
- query = text('INSERT INTO %s (%s) VALUES (%s)' % (table_name, ', '.join(field_names),
- ', '.join(field_name_vars)))
- logger.info("Executing test query %s with values %s", query, columns.values())
- op.get_bind().execute(query, **columns)
+ table_name = escape_table_name(table_name)
+ query = text(
+ "INSERT INTO %s (%s) VALUES (%s)"
+ % (table_name, ", ".join(field_names), ", ".join(field_name_vars))
+ )
+ logger.info("Executing test query %s with values %s", query, columns.values())
+ op.get_bind().execute(query, **columns)
- def populate_column(self, table_name, col_name, field_type):
- col_value = field_type()
- row_id = DataTypes.Foreign(table_name)()
+ def populate_column(self, table_name, col_name, field_type):
+ col_value = field_type()
+ row_id = DataTypes.Foreign(table_name)()
- table_name = escape_table_name(table_name)
- update_text = text("UPDATE %s SET %s=:col_value where ID=:row_id" % (table_name, col_name))
- logger.info("Executing test query %s with value %s on row %s", update_text, col_value, row_id)
- op.get_bind().execute(update_text, col_value=col_value, row_id=row_id)
+ table_name = escape_table_name(table_name)
+ update_text = text(
+ "UPDATE %s SET %s=:col_value where ID=:row_id" % (table_name, col_name)
+ )
+ logger.info(
+ "Executing test query %s with value %s on row %s",
+ update_text,
+ col_value,
+ row_id,
+ )
+ op.get_bind().execute(update_text, col_value=col_value, row_id=row_id)
diff --git a/data/migrations/versions/0cf50323c78b_add_creation_date_to_user_table.py b/data/migrations/versions/0cf50323c78b_add_creation_date_to_user_table.py
index 2a995e58c..77fcee664 100644
--- a/data/migrations/versions/0cf50323c78b_add_creation_date_to_user_table.py
+++ b/data/migrations/versions/0cf50323c78b_add_creation_date_to_user_table.py
@@ -7,8 +7,8 @@ Create Date: 2018-03-09 13:19:41.903196
"""
# revision identifiers, used by Alembic.
-revision = '0cf50323c78b'
-down_revision = '87fbbc224f10'
+revision = "0cf50323c78b"
+down_revision = "87fbbc224f10"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
@@ -18,16 +18,16 @@ import sqlalchemy as sa
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.add_column('user', sa.Column('creation_date', sa.DateTime(), nullable=True))
+ op.add_column("user", sa.Column("creation_date", sa.DateTime(), nullable=True))
# ### end Alembic commands ###
# ### population of test data ### #
- tester.populate_column('user', 'creation_date', tester.TestDataType.DateTime)
+ tester.populate_column("user", "creation_date", tester.TestDataType.DateTime)
# ### end population of test data ### #
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_column('user', 'creation_date')
+ op.drop_column("user", "creation_date")
# ### end Alembic commands ###
diff --git a/data/migrations/versions/10f45ee2310b_add_tag_tagkind_and_manifestchild_tables.py b/data/migrations/versions/10f45ee2310b_add_tag_tagkind_and_manifestchild_tables.py
index e2b4073da..a693498bf 100644
--- a/data/migrations/versions/10f45ee2310b_add_tag_tagkind_and_manifestchild_tables.py
+++ b/data/migrations/versions/10f45ee2310b_add_tag_tagkind_and_manifestchild_tables.py
@@ -7,94 +7,179 @@ Create Date: 2018-10-29 15:22:53.552216
"""
# revision identifiers, used by Alembic.
-revision = '10f45ee2310b'
-down_revision = '13411de1c0ff'
+revision = "10f45ee2310b"
+down_revision = "13411de1c0ff"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
import sqlalchemy as sa
from util.migrate import UTF8CharField
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.create_table('tagkind',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_tagkind'))
+ op.create_table(
+ "tagkind",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_tagkind")),
)
- op.create_index('tagkind_name', 'tagkind', ['name'], unique=True)
- op.create_table('manifestchild',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('manifest_id', sa.Integer(), nullable=False),
- sa.Column('child_manifest_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['child_manifest_id'], ['manifest.id'], name=op.f('fk_manifestchild_child_manifest_id_manifest')),
- sa.ForeignKeyConstraint(['manifest_id'], ['manifest.id'], name=op.f('fk_manifestchild_manifest_id_manifest')),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_manifestchild_repository_id_repository')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_manifestchild'))
+ op.create_index("tagkind_name", "tagkind", ["name"], unique=True)
+ op.create_table(
+ "manifestchild",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("manifest_id", sa.Integer(), nullable=False),
+ sa.Column("child_manifest_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["child_manifest_id"],
+ ["manifest.id"],
+ name=op.f("fk_manifestchild_child_manifest_id_manifest"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["manifest_id"],
+ ["manifest.id"],
+ name=op.f("fk_manifestchild_manifest_id_manifest"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_manifestchild_repository_id_repository"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_manifestchild")),
)
- op.create_index('manifestchild_child_manifest_id', 'manifestchild', ['child_manifest_id'], unique=False)
- op.create_index('manifestchild_manifest_id', 'manifestchild', ['manifest_id'], unique=False)
- op.create_index('manifestchild_manifest_id_child_manifest_id', 'manifestchild', ['manifest_id', 'child_manifest_id'], unique=True)
- op.create_index('manifestchild_repository_id', 'manifestchild', ['repository_id'], unique=False)
- op.create_index('manifestchild_repository_id_child_manifest_id', 'manifestchild', ['repository_id', 'child_manifest_id'], unique=False)
- op.create_index('manifestchild_repository_id_manifest_id', 'manifestchild', ['repository_id', 'manifest_id'], unique=False)
- op.create_index('manifestchild_repository_id_manifest_id_child_manifest_id', 'manifestchild', ['repository_id', 'manifest_id', 'child_manifest_id'], unique=False)
- op.create_table('tag',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('manifest_id', sa.Integer(), nullable=True),
- sa.Column('lifetime_start_ms', sa.BigInteger(), nullable=False),
- sa.Column('lifetime_end_ms', sa.BigInteger(), nullable=True),
- sa.Column('hidden', sa.Boolean(), nullable=False, server_default=sa.sql.expression.false()),
- sa.Column('reversion', sa.Boolean(), nullable=False, server_default=sa.sql.expression.false()),
- sa.Column('tag_kind_id', sa.Integer(), nullable=False),
- sa.Column('linked_tag_id', sa.Integer(), nullable=True),
- sa.ForeignKeyConstraint(['linked_tag_id'], ['tag.id'], name=op.f('fk_tag_linked_tag_id_tag')),
- sa.ForeignKeyConstraint(['manifest_id'], ['manifest.id'], name=op.f('fk_tag_manifest_id_manifest')),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_tag_repository_id_repository')),
- sa.ForeignKeyConstraint(['tag_kind_id'], ['tagkind.id'], name=op.f('fk_tag_tag_kind_id_tagkind')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_tag'))
+ op.create_index(
+ "manifestchild_child_manifest_id",
+ "manifestchild",
+ ["child_manifest_id"],
+ unique=False,
)
- op.create_index('tag_lifetime_end_ms', 'tag', ['lifetime_end_ms'], unique=False)
- op.create_index('tag_linked_tag_id', 'tag', ['linked_tag_id'], unique=False)
- op.create_index('tag_manifest_id', 'tag', ['manifest_id'], unique=False)
- op.create_index('tag_repository_id', 'tag', ['repository_id'], unique=False)
- op.create_index('tag_repository_id_name', 'tag', ['repository_id', 'name'], unique=False)
- op.create_index('tag_repository_id_name_hidden', 'tag', ['repository_id', 'name', 'hidden'], unique=False)
- op.create_index('tag_repository_id_name_lifetime_end_ms', 'tag', ['repository_id', 'name', 'lifetime_end_ms'], unique=True)
- op.create_index('tag_repository_id_name_tag_kind_id', 'tag', ['repository_id', 'name', 'tag_kind_id'], unique=False)
- op.create_index('tag_tag_kind_id', 'tag', ['tag_kind_id'], unique=False)
+ op.create_index(
+ "manifestchild_manifest_id", "manifestchild", ["manifest_id"], unique=False
+ )
+ op.create_index(
+ "manifestchild_manifest_id_child_manifest_id",
+ "manifestchild",
+ ["manifest_id", "child_manifest_id"],
+ unique=True,
+ )
+ op.create_index(
+ "manifestchild_repository_id", "manifestchild", ["repository_id"], unique=False
+ )
+ op.create_index(
+ "manifestchild_repository_id_child_manifest_id",
+ "manifestchild",
+ ["repository_id", "child_manifest_id"],
+ unique=False,
+ )
+ op.create_index(
+ "manifestchild_repository_id_manifest_id",
+ "manifestchild",
+ ["repository_id", "manifest_id"],
+ unique=False,
+ )
+ op.create_index(
+ "manifestchild_repository_id_manifest_id_child_manifest_id",
+ "manifestchild",
+ ["repository_id", "manifest_id", "child_manifest_id"],
+ unique=False,
+ )
+ op.create_table(
+ "tag",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("manifest_id", sa.Integer(), nullable=True),
+ sa.Column("lifetime_start_ms", sa.BigInteger(), nullable=False),
+ sa.Column("lifetime_end_ms", sa.BigInteger(), nullable=True),
+ sa.Column(
+ "hidden",
+ sa.Boolean(),
+ nullable=False,
+ server_default=sa.sql.expression.false(),
+ ),
+ sa.Column(
+ "reversion",
+ sa.Boolean(),
+ nullable=False,
+ server_default=sa.sql.expression.false(),
+ ),
+ sa.Column("tag_kind_id", sa.Integer(), nullable=False),
+ sa.Column("linked_tag_id", sa.Integer(), nullable=True),
+ sa.ForeignKeyConstraint(
+ ["linked_tag_id"], ["tag.id"], name=op.f("fk_tag_linked_tag_id_tag")
+ ),
+ sa.ForeignKeyConstraint(
+ ["manifest_id"], ["manifest.id"], name=op.f("fk_tag_manifest_id_manifest")
+ ),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_tag_repository_id_repository"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["tag_kind_id"], ["tagkind.id"], name=op.f("fk_tag_tag_kind_id_tagkind")
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_tag")),
+ )
+ op.create_index("tag_lifetime_end_ms", "tag", ["lifetime_end_ms"], unique=False)
+ op.create_index("tag_linked_tag_id", "tag", ["linked_tag_id"], unique=False)
+ op.create_index("tag_manifest_id", "tag", ["manifest_id"], unique=False)
+ op.create_index("tag_repository_id", "tag", ["repository_id"], unique=False)
+ op.create_index(
+ "tag_repository_id_name", "tag", ["repository_id", "name"], unique=False
+ )
+ op.create_index(
+ "tag_repository_id_name_hidden",
+ "tag",
+ ["repository_id", "name", "hidden"],
+ unique=False,
+ )
+ op.create_index(
+ "tag_repository_id_name_lifetime_end_ms",
+ "tag",
+ ["repository_id", "name", "lifetime_end_ms"],
+ unique=True,
+ )
+ op.create_index(
+ "tag_repository_id_name_tag_kind_id",
+ "tag",
+ ["repository_id", "name", "tag_kind_id"],
+ unique=False,
+ )
+ op.create_index("tag_tag_kind_id", "tag", ["tag_kind_id"], unique=False)
# ### end Alembic commands ###
- op.bulk_insert(tables.tagkind,
- [
- {'name': 'tag'},
- ])
+ op.bulk_insert(tables.tagkind, [{"name": "tag"}])
# ### population of test data ### #
- tester.populate_table('tag', [
- ('repository_id', tester.TestDataType.Foreign('repository')),
- ('tag_kind_id', tester.TestDataType.Foreign('tagkind')),
- ('name', tester.TestDataType.String),
- ('manifest_id', tester.TestDataType.Foreign('manifest')),
- ('lifetime_start_ms', tester.TestDataType.BigInteger),
- ])
+ tester.populate_table(
+ "tag",
+ [
+ ("repository_id", tester.TestDataType.Foreign("repository")),
+ ("tag_kind_id", tester.TestDataType.Foreign("tagkind")),
+ ("name", tester.TestDataType.String),
+ ("manifest_id", tester.TestDataType.Foreign("manifest")),
+ ("lifetime_start_ms", tester.TestDataType.BigInteger),
+ ],
+ )
- tester.populate_table('manifestchild', [
- ('repository_id', tester.TestDataType.Foreign('repository')),
- ('manifest_id', tester.TestDataType.Foreign('manifest')),
- ('child_manifest_id', tester.TestDataType.Foreign('manifest')),
- ])
+ tester.populate_table(
+ "manifestchild",
+ [
+ ("repository_id", tester.TestDataType.Foreign("repository")),
+ ("manifest_id", tester.TestDataType.Foreign("manifest")),
+ ("child_manifest_id", tester.TestDataType.Foreign("manifest")),
+ ],
+ )
# ### end population of test data ### #
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_table('tag')
- op.drop_table('manifestchild')
- op.drop_table('tagkind')
+ op.drop_table("tag")
+ op.drop_table("manifestchild")
+ op.drop_table("tagkind")
# ### end Alembic commands ###
diff --git a/data/migrations/versions/13411de1c0ff_remove_unique_from_tagmanifesttomanifest.py b/data/migrations/versions/13411de1c0ff_remove_unique_from_tagmanifesttomanifest.py
index 70e0a21d7..08a661695 100644
--- a/data/migrations/versions/13411de1c0ff_remove_unique_from_tagmanifesttomanifest.py
+++ b/data/migrations/versions/13411de1c0ff_remove_unique_from_tagmanifesttomanifest.py
@@ -7,38 +7,71 @@ Create Date: 2018-08-19 23:30:24.969549
"""
# revision identifiers, used by Alembic.
-revision = '13411de1c0ff'
-down_revision = '654e6df88b71'
+revision = "13411de1c0ff"
+down_revision = "654e6df88b71"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
import sqlalchemy as sa
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# Note: Because of a restriction in MySQL, we cannot simply remove the index and re-add
# it without the unique=False, nor can we simply alter the index. To make it work, we'd have to
# remove the primary key on the field, so instead we simply drop the table entirely and
# recreate it with the modified index. The backfill will re-fill this in.
- op.drop_table('tagmanifesttomanifest')
+ op.drop_table("tagmanifesttomanifest")
- op.create_table('tagmanifesttomanifest',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('tag_manifest_id', sa.Integer(), nullable=False),
- sa.Column('manifest_id', sa.Integer(), nullable=False),
- sa.Column('broken', sa.Boolean(), nullable=False, server_default=sa.sql.expression.false()),
- sa.ForeignKeyConstraint(['manifest_id'], ['manifest.id'], name=op.f('fk_tagmanifesttomanifest_manifest_id_manifest')),
- sa.ForeignKeyConstraint(['tag_manifest_id'], ['tagmanifest.id'], name=op.f('fk_tagmanifesttomanifest_tag_manifest_id_tagmanifest')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_tagmanifesttomanifest'))
+ op.create_table(
+ "tagmanifesttomanifest",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("tag_manifest_id", sa.Integer(), nullable=False),
+ sa.Column("manifest_id", sa.Integer(), nullable=False),
+ sa.Column(
+ "broken",
+ sa.Boolean(),
+ nullable=False,
+ server_default=sa.sql.expression.false(),
+ ),
+ sa.ForeignKeyConstraint(
+ ["manifest_id"],
+ ["manifest.id"],
+ name=op.f("fk_tagmanifesttomanifest_manifest_id_manifest"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["tag_manifest_id"],
+ ["tagmanifest.id"],
+ name=op.f("fk_tagmanifesttomanifest_tag_manifest_id_tagmanifest"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_tagmanifesttomanifest")),
+ )
+ op.create_index(
+ "tagmanifesttomanifest_broken",
+ "tagmanifesttomanifest",
+ ["broken"],
+ unique=False,
+ )
+ op.create_index(
+ "tagmanifesttomanifest_manifest_id",
+ "tagmanifesttomanifest",
+ ["manifest_id"],
+ unique=False,
+ )
+ op.create_index(
+ "tagmanifesttomanifest_tag_manifest_id",
+ "tagmanifesttomanifest",
+ ["tag_manifest_id"],
+ unique=True,
)
- op.create_index('tagmanifesttomanifest_broken', 'tagmanifesttomanifest', ['broken'], unique=False)
- op.create_index('tagmanifesttomanifest_manifest_id', 'tagmanifesttomanifest', ['manifest_id'], unique=False)
- op.create_index('tagmanifesttomanifest_tag_manifest_id', 'tagmanifesttomanifest', ['tag_manifest_id'], unique=True)
- tester.populate_table('tagmanifesttomanifest', [
- ('manifest_id', tester.TestDataType.Foreign('manifest')),
- ('tag_manifest_id', tester.TestDataType.Foreign('tagmanifest')),
- ])
+ tester.populate_table(
+ "tagmanifesttomanifest",
+ [
+ ("manifest_id", tester.TestDataType.Foreign("manifest")),
+ ("tag_manifest_id", tester.TestDataType.Foreign("tagmanifest")),
+ ],
+ )
def downgrade(tables, tester, progress_reporter):
diff --git a/data/migrations/versions/152bb29a1bb3_add_maximum_build_queue_count_setting_.py b/data/migrations/versions/152bb29a1bb3_add_maximum_build_queue_count_setting_.py
index 489303dde..bc246cdbe 100644
--- a/data/migrations/versions/152bb29a1bb3_add_maximum_build_queue_count_setting_.py
+++ b/data/migrations/versions/152bb29a1bb3_add_maximum_build_queue_count_setting_.py
@@ -7,27 +7,32 @@ Create Date: 2018-02-20 13:34:34.902415
"""
# revision identifiers, used by Alembic.
-revision = '152bb29a1bb3'
-down_revision = 'cbc8177760d9'
+revision = "152bb29a1bb3"
+down_revision = "cbc8177760d9"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.add_column('user', sa.Column('maximum_queued_builds_count', sa.Integer(), nullable=True))
+ op.add_column(
+ "user", sa.Column("maximum_queued_builds_count", sa.Integer(), nullable=True)
+ )
# ### end Alembic commands ###
# ### population of test data ### #
- tester.populate_column('user', 'maximum_queued_builds_count', tester.TestDataType.Integer)
+ tester.populate_column(
+ "user", "maximum_queued_builds_count", tester.TestDataType.Integer
+ )
# ### end population of test data ### #
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_column('user', 'maximum_queued_builds_count')
+ op.drop_column("user", "maximum_queued_builds_count")
# ### end Alembic commands ###
diff --git a/data/migrations/versions/152edccba18c_make_blodupload_byte_count_not_nullable.py b/data/migrations/versions/152edccba18c_make_blodupload_byte_count_not_nullable.py
index 6eca834fa..593b9f231 100644
--- a/data/migrations/versions/152edccba18c_make_blodupload_byte_count_not_nullable.py
+++ b/data/migrations/versions/152edccba18c_make_blodupload_byte_count_not_nullable.py
@@ -7,8 +7,8 @@ Create Date: 2018-02-23 12:41:25.571835
"""
# revision identifiers, used by Alembic.
-revision = '152edccba18c'
-down_revision = 'c91c564aad34'
+revision = "152edccba18c"
+down_revision = "c91c564aad34"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
@@ -17,11 +17,13 @@ import sqlalchemy as sa
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
- op.alter_column('blobupload', 'byte_count', existing_type=sa.BigInteger(),
- nullable=False)
+ op.alter_column(
+ "blobupload", "byte_count", existing_type=sa.BigInteger(), nullable=False
+ )
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
- op.alter_column('blobupload', 'byte_count', existing_type=sa.BigInteger(),
- nullable=True)
+ op.alter_column(
+ "blobupload", "byte_count", existing_type=sa.BigInteger(), nullable=True
+ )
diff --git a/data/migrations/versions/1783530bee68_add_logentry2_table_quay_io_only.py b/data/migrations/versions/1783530bee68_add_logentry2_table_quay_io_only.py
index ffe5d9176..09f2d7818 100644
--- a/data/migrations/versions/1783530bee68_add_logentry2_table_quay_io_only.py
+++ b/data/migrations/versions/1783530bee68_add_logentry2_table_quay_io_only.py
@@ -7,8 +7,8 @@ Create Date: 2018-05-17 16:32:28.532264
"""
# revision identifiers, used by Alembic.
-revision = '1783530bee68'
-down_revision = '5b7503aada1b'
+revision = "1783530bee68"
+down_revision = "5b7503aada1b"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
@@ -18,32 +18,61 @@ import sqlalchemy as sa
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.create_table('logentry2',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('kind_id', sa.Integer(), nullable=False),
- sa.Column('account_id', sa.Integer(), nullable=False),
- sa.Column('performer_id', sa.Integer(), nullable=True),
- sa.Column('repository_id', sa.Integer(), nullable=True),
- sa.Column('datetime', sa.DateTime(), nullable=False),
- sa.Column('ip', sa.String(length=255), nullable=True),
- sa.Column('metadata_json', sa.Text(), nullable=False),
- sa.ForeignKeyConstraint(['kind_id'], ['logentrykind.id'], name=op.f('fk_logentry2_kind_id_logentrykind')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_logentry2'))
+ op.create_table(
+ "logentry2",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("kind_id", sa.Integer(), nullable=False),
+ sa.Column("account_id", sa.Integer(), nullable=False),
+ sa.Column("performer_id", sa.Integer(), nullable=True),
+ sa.Column("repository_id", sa.Integer(), nullable=True),
+ sa.Column("datetime", sa.DateTime(), nullable=False),
+ sa.Column("ip", sa.String(length=255), nullable=True),
+ sa.Column("metadata_json", sa.Text(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["kind_id"],
+ ["logentrykind.id"],
+ name=op.f("fk_logentry2_kind_id_logentrykind"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_logentry2")),
+ )
+ op.create_index("logentry2_account_id", "logentry2", ["account_id"], unique=False)
+ op.create_index(
+ "logentry2_account_id_datetime",
+ "logentry2",
+ ["account_id", "datetime"],
+ unique=False,
+ )
+ op.create_index("logentry2_datetime", "logentry2", ["datetime"], unique=False)
+ op.create_index("logentry2_kind_id", "logentry2", ["kind_id"], unique=False)
+ op.create_index(
+ "logentry2_performer_id", "logentry2", ["performer_id"], unique=False
+ )
+ op.create_index(
+ "logentry2_performer_id_datetime",
+ "logentry2",
+ ["performer_id", "datetime"],
+ unique=False,
+ )
+ op.create_index(
+ "logentry2_repository_id", "logentry2", ["repository_id"], unique=False
+ )
+ op.create_index(
+ "logentry2_repository_id_datetime",
+ "logentry2",
+ ["repository_id", "datetime"],
+ unique=False,
+ )
+ op.create_index(
+ "logentry2_repository_id_datetime_kind_id",
+ "logentry2",
+ ["repository_id", "datetime", "kind_id"],
+ unique=False,
)
- op.create_index('logentry2_account_id', 'logentry2', ['account_id'], unique=False)
- op.create_index('logentry2_account_id_datetime', 'logentry2', ['account_id', 'datetime'], unique=False)
- op.create_index('logentry2_datetime', 'logentry2', ['datetime'], unique=False)
- op.create_index('logentry2_kind_id', 'logentry2', ['kind_id'], unique=False)
- op.create_index('logentry2_performer_id', 'logentry2', ['performer_id'], unique=False)
- op.create_index('logentry2_performer_id_datetime', 'logentry2', ['performer_id', 'datetime'], unique=False)
- op.create_index('logentry2_repository_id', 'logentry2', ['repository_id'], unique=False)
- op.create_index('logentry2_repository_id_datetime', 'logentry2', ['repository_id', 'datetime'], unique=False)
- op.create_index('logentry2_repository_id_datetime_kind_id', 'logentry2', ['repository_id', 'datetime', 'kind_id'], unique=False)
# ### end Alembic commands ###
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_table('logentry2')
+ op.drop_table("logentry2")
# ### end Alembic commands ###
diff --git a/data/migrations/versions/17aff2e1354e_add_automatic_disable_of_build_triggers.py b/data/migrations/versions/17aff2e1354e_add_automatic_disable_of_build_triggers.py
index 27f1aafa6..8f3e83eec 100644
--- a/data/migrations/versions/17aff2e1354e_add_automatic_disable_of_build_triggers.py
+++ b/data/migrations/versions/17aff2e1354e_add_automatic_disable_of_build_triggers.py
@@ -7,48 +7,73 @@ Create Date: 2017-10-18 15:58:03.971526
"""
# revision identifiers, used by Alembic.
-revision = '17aff2e1354e'
-down_revision = '61cadbacb9fc'
+revision = "17aff2e1354e"
+down_revision = "61cadbacb9fc"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.add_column('repositorybuildtrigger', sa.Column('successive_failure_count', sa.Integer(), server_default='0', nullable=False))
- op.add_column('repositorybuildtrigger', sa.Column('successive_internal_error_count', sa.Integer(), server_default='0', nullable=False))
+ op.add_column(
+ "repositorybuildtrigger",
+ sa.Column(
+ "successive_failure_count", sa.Integer(), server_default="0", nullable=False
+ ),
+ )
+ op.add_column(
+ "repositorybuildtrigger",
+ sa.Column(
+ "successive_internal_error_count",
+ sa.Integer(),
+ server_default="0",
+ nullable=False,
+ ),
+ )
# ### end Alembic commands ###
op.bulk_insert(
tables.disablereason,
[
- {'id': 2, 'name': 'successive_build_failures'},
- {'id': 3, 'name': 'successive_build_internal_errors'},
+ {"id": 2, "name": "successive_build_failures"},
+ {"id": 3, "name": "successive_build_internal_errors"},
],
)
# ### population of test data ### #
- tester.populate_column('repositorybuildtrigger', 'successive_failure_count', tester.TestDataType.Integer)
- tester.populate_column('repositorybuildtrigger', 'successive_internal_error_count', tester.TestDataType.Integer)
+ tester.populate_column(
+ "repositorybuildtrigger",
+ "successive_failure_count",
+ tester.TestDataType.Integer,
+ )
+ tester.populate_column(
+ "repositorybuildtrigger",
+ "successive_internal_error_count",
+ tester.TestDataType.Integer,
+ )
# ### end population of test data ### #
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_column('repositorybuildtrigger', 'successive_internal_error_count')
- op.drop_column('repositorybuildtrigger', 'successive_failure_count')
+ op.drop_column("repositorybuildtrigger", "successive_internal_error_count")
+ op.drop_column("repositorybuildtrigger", "successive_failure_count")
# ### end Alembic commands ###
- op.execute(tables
- .disablereason
- .delete()
- .where(tables.disablereason.c.name == op.inline_literal('successive_internal_error_count')))
+ op.execute(
+ tables.disablereason.delete().where(
+ tables.disablereason.c.name
+ == op.inline_literal("successive_build_internal_errors")
+ )
+ )
- op.execute(tables
- .disablereason
- .delete()
- .where(tables.disablereason.c.name == op.inline_literal('successive_failure_count')))
+ op.execute(
+ tables.disablereason.delete().where(
+ tables.disablereason.c.name == op.inline_literal("successive_build_failures")
+ )
+ )
diff --git a/data/migrations/versions/224ce4c72c2f_add_last_accessed_field_to_user_table.py b/data/migrations/versions/224ce4c72c2f_add_last_accessed_field_to_user_table.py
index 9b9bb1978..6208a6a6d 100644
--- a/data/migrations/versions/224ce4c72c2f_add_last_accessed_field_to_user_table.py
+++ b/data/migrations/versions/224ce4c72c2f_add_last_accessed_field_to_user_table.py
@@ -7,8 +7,8 @@ Create Date: 2018-03-12 22:44:07.070490
"""
# revision identifiers, used by Alembic.
-revision = '224ce4c72c2f'
-down_revision = 'b547bc139ad8'
+revision = "224ce4c72c2f"
+down_revision = "b547bc139ad8"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
@@ -18,18 +18,18 @@ import sqlalchemy as sa
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.add_column('user', sa.Column('last_accessed', sa.DateTime(), nullable=True))
- op.create_index('user_last_accessed', 'user', ['last_accessed'], unique=False)
+ op.add_column("user", sa.Column("last_accessed", sa.DateTime(), nullable=True))
+ op.create_index("user_last_accessed", "user", ["last_accessed"], unique=False)
# ### end Alembic commands ###
# ### population of test data ### #
- tester.populate_column('user', 'last_accessed', tester.TestDataType.DateTime)
+ tester.populate_column("user", "last_accessed", tester.TestDataType.DateTime)
# ### end population of test data ### #
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_index('user_last_accessed', table_name='user')
- op.drop_column('user', 'last_accessed')
+ op.drop_index("user_last_accessed", table_name="user")
+ op.drop_column("user", "last_accessed")
# ### end Alembic commands ###
diff --git a/data/migrations/versions/34c8ef052ec9_repo_mirror_columns.py b/data/migrations/versions/34c8ef052ec9_repo_mirror_columns.py
index 2b73b8afa..071c1906c 100644
--- a/data/migrations/versions/34c8ef052ec9_repo_mirror_columns.py
+++ b/data/migrations/versions/34c8ef052ec9_repo_mirror_columns.py
@@ -7,8 +7,8 @@ Create Date: 2019-10-07 13:11:20.424715
"""
# revision identifiers, used by Alembic.
-revision = '34c8ef052ec9'
-down_revision = 'cc6778199cdb'
+revision = "34c8ef052ec9"
+down_revision = "cc6778199cdb"
from alembic import op
from alembic import op as original_op
@@ -17,8 +17,17 @@ from datetime import datetime
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
from peewee import ForeignKeyField, DateTimeField, BooleanField
-from data.database import (BaseModel, RepoMirrorType, RepoMirrorStatus, RepoMirrorRule, uuid_generator,
- QuayUserField, Repository, IntegerField, JSONField)
+from data.database import (
+ BaseModel,
+ RepoMirrorType,
+ RepoMirrorStatus,
+ RepoMirrorRule,
+ uuid_generator,
+ QuayUserField,
+ Repository,
+ IntegerField,
+ JSONField,
+)
from data.fields import EnumField as ClientEnumField, CharField, EncryptedCharField
import logging
@@ -30,100 +39,150 @@ BATCH_SIZE = 10
# Original model
class RepoMirrorConfig(BaseModel):
- """
+ """
Represents a repository to be mirrored and any additional configuration
required to perform the mirroring.
"""
- repository = ForeignKeyField(Repository, index=True, unique=True, backref='mirror')
- creation_date = DateTimeField(default=datetime.utcnow)
- is_enabled = BooleanField(default=True)
- # Mirror Configuration
- mirror_type = ClientEnumField(RepoMirrorType, default=RepoMirrorType.PULL)
- internal_robot = QuayUserField(allows_robots=True, null=True, backref='mirrorpullrobot',
- robot_null_delete=True)
- external_reference = CharField()
- external_registry = CharField()
- external_namespace = CharField()
- external_repository = CharField()
- external_registry_username = EncryptedCharField(max_length=2048, null=True)
- external_registry_password = EncryptedCharField(max_length=2048, null=True)
- external_registry_config = JSONField(default={})
+ repository = ForeignKeyField(Repository, index=True, unique=True, backref="mirror")
+ creation_date = DateTimeField(default=datetime.utcnow)
+ is_enabled = BooleanField(default=True)
- # Worker Queuing
- sync_interval = IntegerField() # seconds between syncs
- sync_start_date = DateTimeField(null=True) # next start time
- sync_expiration_date = DateTimeField(null=True) # max duration
- sync_retries_remaining = IntegerField(default=3)
- sync_status = ClientEnumField(RepoMirrorStatus, default=RepoMirrorStatus.NEVER_RUN)
- sync_transaction_id = CharField(default=uuid_generator, max_length=36)
+ # Mirror Configuration
+ mirror_type = ClientEnumField(RepoMirrorType, default=RepoMirrorType.PULL)
+ internal_robot = QuayUserField(
+ allows_robots=True, null=True, backref="mirrorpullrobot", robot_null_delete=True
+ )
+ external_reference = CharField()
+ external_registry = CharField()
+ external_namespace = CharField()
+ external_repository = CharField()
+ external_registry_username = EncryptedCharField(max_length=2048, null=True)
+ external_registry_password = EncryptedCharField(max_length=2048, null=True)
+ external_registry_config = JSONField(default={})
- # Tag-Matching Rules
- root_rule = ForeignKeyField(RepoMirrorRule)
+ # Worker Queuing
+ sync_interval = IntegerField() # seconds between syncs
+ sync_start_date = DateTimeField(null=True) # next start time
+ sync_expiration_date = DateTimeField(null=True) # max duration
+ sync_retries_remaining = IntegerField(default=3)
+ sync_status = ClientEnumField(RepoMirrorStatus, default=RepoMirrorStatus.NEVER_RUN)
+ sync_transaction_id = CharField(default=uuid_generator, max_length=36)
+
+ # Tag-Matching Rules
+ root_rule = ForeignKeyField(RepoMirrorRule)
def _iterate(model_class, clause):
- while True:
- has_rows = False
- for row in list(model_class.select().where(clause).limit(BATCH_SIZE)):
- has_rows = True
- yield row
+ while True:
+ has_rows = False
+ for row in list(model_class.select().where(clause).limit(BATCH_SIZE)):
+ has_rows = True
+ yield row
- if not has_rows:
- break
+ if not has_rows:
+ break
def upgrade(tables, tester, progress_reporter):
- op = ProgressWrapper(original_op, progress_reporter)
+ op = ProgressWrapper(original_op, progress_reporter)
- logger.info('Migrating to external_reference from existing columns')
+ logger.info("Migrating to external_reference from existing columns")
- op.add_column('repomirrorconfig', sa.Column('external_reference', sa.Text(), nullable=True))
+ op.add_column(
+ "repomirrorconfig", sa.Column("external_reference", sa.Text(), nullable=True)
+ )
- from app import app
- if app.config.get('SETUP_COMPLETE', False) or tester.is_testing:
- for repo_mirror in _iterate(RepoMirrorConfig, (RepoMirrorConfig.external_reference >> None)):
- repo = '%s/%s/%s' % (repo_mirror.external_registry, repo_mirror.external_namespace, repo_mirror.external_repository)
- logger.info('migrating %s' % repo)
- repo_mirror.external_reference = repo
- repo_mirror.save()
+ from app import app
- op.drop_column('repomirrorconfig', 'external_registry')
- op.drop_column('repomirrorconfig', 'external_namespace')
- op.drop_column('repomirrorconfig', 'external_repository')
+ if app.config.get("SETUP_COMPLETE", False) or tester.is_testing:
+ for repo_mirror in _iterate(
+ RepoMirrorConfig, (RepoMirrorConfig.external_reference >> None)
+ ):
+ repo = "%s/%s/%s" % (
+ repo_mirror.external_registry,
+ repo_mirror.external_namespace,
+ repo_mirror.external_repository,
+ )
+ logger.info("migrating %s" % repo)
+ repo_mirror.external_reference = repo
+ repo_mirror.save()
- op.alter_column('repomirrorconfig', 'external_reference', nullable=False, existing_type=sa.Text())
+ op.drop_column("repomirrorconfig", "external_registry")
+ op.drop_column("repomirrorconfig", "external_namespace")
+ op.drop_column("repomirrorconfig", "external_repository")
+ op.alter_column(
+ "repomirrorconfig",
+ "external_reference",
+ nullable=False,
+ existing_type=sa.Text(),
+ )
- tester.populate_column('repomirrorconfig', 'external_reference', tester.TestDataType.String)
+ tester.populate_column(
+ "repomirrorconfig", "external_reference", tester.TestDataType.String
+ )
def downgrade(tables, tester, progress_reporter):
- op = ProgressWrapper(original_op, progress_reporter)
+ op = ProgressWrapper(original_op, progress_reporter)
- '''
+ """
This will downgrade existing data but may not exactly match previous data structure. If the
external_reference does not have three parts (registry, namespace, repository) then a failed
value is inserted.
- '''
+ """
- op.add_column('repomirrorconfig', sa.Column('external_registry', sa.String(length=255), nullable=True))
- op.add_column('repomirrorconfig', sa.Column('external_namespace', sa.String(length=255), nullable=True))
- op.add_column('repomirrorconfig', sa.Column('external_repository', sa.String(length=255), nullable=True))
+ op.add_column(
+ "repomirrorconfig",
+ sa.Column("external_registry", sa.String(length=255), nullable=True),
+ )
+ op.add_column(
+ "repomirrorconfig",
+ sa.Column("external_namespace", sa.String(length=255), nullable=True),
+ )
+ op.add_column(
+ "repomirrorconfig",
+ sa.Column("external_repository", sa.String(length=255), nullable=True),
+ )
- from app import app
- if app.config.get('SETUP_COMPLETE', False):
- logger.info('Restoring columns from external_reference')
- for repo_mirror in _iterate(RepoMirrorConfig, (RepoMirrorConfig.external_registry >> None)):
- logger.info('Restoring %s' % repo_mirror.external_reference)
- parts = repo_mirror.external_reference.split('/', 2)
- repo_mirror.external_registry = parts[0] if len(parts) >= 1 else 'DOWNGRADE-FAILED'
- repo_mirror.external_namespace = parts[1] if len(parts) >= 2 else 'DOWNGRADE-FAILED'
- repo_mirror.external_repository = parts[2] if len(parts) >= 3 else 'DOWNGRADE-FAILED'
- repo_mirror.save()
+ from app import app
- op.drop_column('repomirrorconfig', 'external_reference')
+ if app.config.get("SETUP_COMPLETE", False):
+ logger.info("Restoring columns from external_reference")
+ for repo_mirror in _iterate(
+ RepoMirrorConfig, (RepoMirrorConfig.external_registry >> None)
+ ):
+ logger.info("Restoring %s" % repo_mirror.external_reference)
+ parts = repo_mirror.external_reference.split("/", 2)
+ repo_mirror.external_registry = (
+ parts[0] if len(parts) >= 1 else "DOWNGRADE-FAILED"
+ )
+ repo_mirror.external_namespace = (
+ parts[1] if len(parts) >= 2 else "DOWNGRADE-FAILED"
+ )
+ repo_mirror.external_repository = (
+ parts[2] if len(parts) >= 3 else "DOWNGRADE-FAILED"
+ )
+ repo_mirror.save()
- op.alter_column('repomirrorconfig', 'external_registry', nullable=False, existing_type=sa.String(length=255))
- op.alter_column('repomirrorconfig', 'external_namespace', nullable=False, existing_type=sa.String(length=255))
- op.alter_column('repomirrorconfig', 'external_repository', nullable=False, existing_type=sa.String(length=255))
+ op.drop_column("repomirrorconfig", "external_reference")
+
+ op.alter_column(
+ "repomirrorconfig",
+ "external_registry",
+ nullable=False,
+ existing_type=sa.String(length=255),
+ )
+ op.alter_column(
+ "repomirrorconfig",
+ "external_namespace",
+ nullable=False,
+ existing_type=sa.String(length=255),
+ )
+ op.alter_column(
+ "repomirrorconfig",
+ "external_repository",
+ nullable=False,
+ existing_type=sa.String(length=255),
+ )
diff --git a/data/migrations/versions/3e8cc74a1e7b_add_severity_and_media_type_to_global_.py b/data/migrations/versions/3e8cc74a1e7b_add_severity_and_media_type_to_global_.py
index 87e6f8890..47961514b 100644
--- a/data/migrations/versions/3e8cc74a1e7b_add_severity_and_media_type_to_global_.py
+++ b/data/migrations/versions/3e8cc74a1e7b_add_severity_and_media_type_to_global_.py
@@ -7,57 +7,78 @@ Create Date: 2017-01-17 16:22:28.584237
"""
# revision identifiers, used by Alembic.
-revision = '3e8cc74a1e7b'
-down_revision = 'fc47c1ec019f'
+revision = "3e8cc74a1e7b"
+down_revision = "fc47c1ec019f"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.add_column('messages', sa.Column('media_type_id', sa.Integer(), nullable=False, server_default='1'))
- op.add_column('messages', sa.Column('severity', sa.String(length=255), nullable=False, server_default='info'))
- op.alter_column('messages', 'uuid',
- existing_type=mysql.VARCHAR(length=36),
- server_default='',
- nullable=False)
- op.create_index('messages_media_type_id', 'messages', ['media_type_id'], unique=False)
- op.create_index('messages_severity', 'messages', ['severity'], unique=False)
- op.create_index('messages_uuid', 'messages', ['uuid'], unique=False)
- op.create_foreign_key(op.f('fk_messages_media_type_id_mediatype'), 'messages', 'mediatype', ['media_type_id'], ['id'])
+ op.add_column(
+ "messages",
+ sa.Column("media_type_id", sa.Integer(), nullable=False, server_default="1"),
+ )
+ op.add_column(
+ "messages",
+ sa.Column(
+ "severity", sa.String(length=255), nullable=False, server_default="info"
+ ),
+ )
+ op.alter_column(
+ "messages",
+ "uuid",
+ existing_type=mysql.VARCHAR(length=36),
+ server_default="",
+ nullable=False,
+ )
+ op.create_index(
+ "messages_media_type_id", "messages", ["media_type_id"], unique=False
+ )
+ op.create_index("messages_severity", "messages", ["severity"], unique=False)
+ op.create_index("messages_uuid", "messages", ["uuid"], unique=False)
+ op.create_foreign_key(
+ op.f("fk_messages_media_type_id_mediatype"),
+ "messages",
+ "mediatype",
+ ["media_type_id"],
+ ["id"],
+ )
# ### end Alembic commands ###
- op.bulk_insert(tables.mediatype,
- [
- {'name': 'text/markdown'},
- ])
+ op.bulk_insert(tables.mediatype, [{"name": "text/markdown"}])
# ### population of test data ### #
- tester.populate_column('messages', 'media_type_id', tester.TestDataType.Foreign('mediatype'))
- tester.populate_column('messages', 'severity', lambda: 'info')
- tester.populate_column('messages', 'uuid', tester.TestDataType.UUID)
+ tester.populate_column(
+ "messages", "media_type_id", tester.TestDataType.Foreign("mediatype")
+ )
+ tester.populate_column("messages", "severity", lambda: "info")
+ tester.populate_column("messages", "uuid", tester.TestDataType.UUID)
# ### end population of test data ### #
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_constraint(op.f('fk_messages_media_type_id_mediatype'), 'messages', type_='foreignkey')
- op.drop_index('messages_uuid', table_name='messages')
- op.drop_index('messages_severity', table_name='messages')
- op.drop_index('messages_media_type_id', table_name='messages')
- op.alter_column('messages', 'uuid',
- existing_type=mysql.VARCHAR(length=36),
- nullable=True)
- op.drop_column('messages', 'severity')
- op.drop_column('messages', 'media_type_id')
+ op.drop_constraint(
+ op.f("fk_messages_media_type_id_mediatype"), "messages", type_="foreignkey"
+ )
+ op.drop_index("messages_uuid", table_name="messages")
+ op.drop_index("messages_severity", table_name="messages")
+ op.drop_index("messages_media_type_id", table_name="messages")
+ op.alter_column(
+ "messages", "uuid", existing_type=mysql.VARCHAR(length=36), nullable=True
+ )
+ op.drop_column("messages", "severity")
+ op.drop_column("messages", "media_type_id")
# ### end Alembic commands ###
- op.execute(tables
- .mediatype
- .delete()
- .where(tables.
- mediatype.c.name == op.inline_literal('text/markdown')))
+ op.execute(
+ tables.mediatype.delete().where(
+ tables.mediatype.c.name == op.inline_literal("text/markdown")
+ )
+ )
diff --git a/data/migrations/versions/45fd8b9869d4_add_notification_type.py b/data/migrations/versions/45fd8b9869d4_add_notification_type.py
index 66f5c0870..81c038942 100644
--- a/data/migrations/versions/45fd8b9869d4_add_notification_type.py
+++ b/data/migrations/versions/45fd8b9869d4_add_notification_type.py
@@ -7,24 +7,22 @@ Create Date: 2016-12-01 12:02:19.724528
"""
# revision identifiers, used by Alembic.
-revision = '45fd8b9869d4'
-down_revision = '94836b099894'
+revision = "45fd8b9869d4"
+down_revision = "94836b099894"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
+
def upgrade(tables, tester, progress_reporter):
- op = ProgressWrapper(original_op, progress_reporter)
- op.bulk_insert(tables.notificationkind,
- [
- {'name': 'build_cancelled'},
- ])
+ op = ProgressWrapper(original_op, progress_reporter)
+ op.bulk_insert(tables.notificationkind, [{"name": "build_cancelled"}])
def downgrade(tables, tester, progress_reporter):
- op = ProgressWrapper(original_op, progress_reporter)
- op.execute(tables
- .notificationkind
- .delete()
- .where(tables.
- notificationkind.c.name == op.inline_literal('build_cancelled')))
+ op = ProgressWrapper(original_op, progress_reporter)
+ op.execute(
+ tables.notificationkind.delete().where(
+ tables.notificationkind.c.name == op.inline_literal("build_cancelled")
+ )
+ )
diff --git a/data/migrations/versions/481623ba00ba_add_index_on_logs_archived_on_.py b/data/migrations/versions/481623ba00ba_add_index_on_logs_archived_on_.py
index da8476f8a..d2bfd0474 100644
--- a/data/migrations/versions/481623ba00ba_add_index_on_logs_archived_on_.py
+++ b/data/migrations/versions/481623ba00ba_add_index_on_logs_archived_on_.py
@@ -7,21 +7,27 @@ Create Date: 2019-02-15 16:09:47.326805
"""
# revision identifiers, used by Alembic.
-revision = '481623ba00ba'
-down_revision = 'b9045731c4de'
+revision = "481623ba00ba"
+down_revision = "b9045731c4de"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.create_index('repositorybuild_logs_archived', 'repositorybuild', ['logs_archived'], unique=False)
+ op.create_index(
+ "repositorybuild_logs_archived",
+ "repositorybuild",
+ ["logs_archived"],
+ unique=False,
+ )
# ### end Alembic commands ###
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_index('repositorybuild_logs_archived', table_name='repositorybuild')
+ op.drop_index("repositorybuild_logs_archived", table_name="repositorybuild")
# ### end Alembic commands ###
diff --git a/data/migrations/versions/5248ddf35167_repository_mirror.py b/data/migrations/versions/5248ddf35167_repository_mirror.py
index 8bb806105..6842cc1d5 100644
--- a/data/migrations/versions/5248ddf35167_repository_mirror.py
+++ b/data/migrations/versions/5248ddf35167_repository_mirror.py
@@ -6,139 +6,229 @@ Create Date: 2019-06-25 16:22:36.310532
"""
-revision = '5248ddf35167'
-down_revision = 'b918abdbee43'
+revision = "5248ddf35167"
+down_revision = "b918abdbee43"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
+
def upgrade(tables, tester, progress_reporter):
- op = ProgressWrapper(original_op, progress_reporter)
- op.create_table('repomirrorrule',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('uuid', sa.String(length=36), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('creation_date', sa.DateTime(), nullable=False),
- sa.Column('rule_type', sa.Integer(), nullable=False),
- sa.Column('rule_value', sa.Text(), nullable=False),
- sa.Column('left_child_id', sa.Integer(), nullable=True),
- sa.Column('right_child_id', sa.Integer(), nullable=True),
- sa.ForeignKeyConstraint(['left_child_id'], ['repomirrorrule.id'], name=op.f('fk_repomirrorrule_left_child_id_repomirrorrule')),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_repomirrorrule_repository_id_repository')),
- sa.ForeignKeyConstraint(['right_child_id'], ['repomirrorrule.id'], name=op.f('fk_repomirrorrule_right_child_id_repomirrorrule')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_repomirrorrule')))
- op.create_index('repomirrorrule_left_child_id', 'repomirrorrule', ['left_child_id'], unique=False)
- op.create_index('repomirrorrule_repository_id', 'repomirrorrule', ['repository_id'], unique=False)
- op.create_index('repomirrorrule_right_child_id', 'repomirrorrule', ['right_child_id'], unique=False)
- op.create_index('repomirrorrule_rule_type', 'repomirrorrule', ['rule_type'], unique=False)
- op.create_index('repomirrorrule_uuid', 'repomirrorrule', ['uuid'], unique=True)
+ op = ProgressWrapper(original_op, progress_reporter)
+ op.create_table(
+ "repomirrorrule",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("uuid", sa.String(length=36), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("creation_date", sa.DateTime(), nullable=False),
+ sa.Column("rule_type", sa.Integer(), nullable=False),
+ sa.Column("rule_value", sa.Text(), nullable=False),
+ sa.Column("left_child_id", sa.Integer(), nullable=True),
+ sa.Column("right_child_id", sa.Integer(), nullable=True),
+ sa.ForeignKeyConstraint(
+ ["left_child_id"],
+ ["repomirrorrule.id"],
+ name=op.f("fk_repomirrorrule_left_child_id_repomirrorrule"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_repomirrorrule_repository_id_repository"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["right_child_id"],
+ ["repomirrorrule.id"],
+ name=op.f("fk_repomirrorrule_right_child_id_repomirrorrule"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_repomirrorrule")),
+ )
+ op.create_index(
+ "repomirrorrule_left_child_id",
+ "repomirrorrule",
+ ["left_child_id"],
+ unique=False,
+ )
+ op.create_index(
+ "repomirrorrule_repository_id",
+ "repomirrorrule",
+ ["repository_id"],
+ unique=False,
+ )
+ op.create_index(
+ "repomirrorrule_right_child_id",
+ "repomirrorrule",
+ ["right_child_id"],
+ unique=False,
+ )
+ op.create_index(
+ "repomirrorrule_rule_type", "repomirrorrule", ["rule_type"], unique=False
+ )
+ op.create_index("repomirrorrule_uuid", "repomirrorrule", ["uuid"], unique=True)
- op.create_table('repomirrorconfig',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('creation_date', sa.DateTime(), nullable=False),
- sa.Column('is_enabled', sa.Boolean(), nullable=False),
- sa.Column('mirror_type', sa.Integer(), nullable=False),
- sa.Column('internal_robot_id', sa.Integer(), nullable=False),
- sa.Column('external_registry', sa.String(length=255), nullable=False),
- sa.Column('external_namespace', sa.String(length=255), nullable=False),
- sa.Column('external_repository', sa.String(length=255), nullable=False),
- sa.Column('external_registry_username', sa.String(length=2048), nullable=True),
- sa.Column('external_registry_password', sa.String(length=2048), nullable=True),
- sa.Column('external_registry_config', sa.Text(), nullable=False),
- sa.Column('sync_interval', sa.Integer(), nullable=False, server_default='60'),
- sa.Column('sync_start_date', sa.DateTime(), nullable=True),
- sa.Column('sync_expiration_date', sa.DateTime(), nullable=True),
- sa.Column('sync_retries_remaining', sa.Integer(), nullable=False, server_default='3'),
- sa.Column('sync_status', sa.Integer(), nullable=False),
- sa.Column('sync_transaction_id', sa.String(length=36), nullable=True),
- sa.Column('root_rule_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_repomirrorconfig_repository_id_repository')),
- sa.ForeignKeyConstraint(['root_rule_id'], ['repomirrorrule.id'], name=op.f('fk_repomirrorconfig_root_rule_id_repomirrorrule')),
- sa.ForeignKeyConstraint(['internal_robot_id'], ['user.id'], name=op.f('fk_repomirrorconfig_internal_robot_id_user')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_repomirrorconfig'))
- )
- op.create_index('repomirrorconfig_mirror_type', 'repomirrorconfig', ['mirror_type'], unique=False)
- op.create_index('repomirrorconfig_repository_id', 'repomirrorconfig', ['repository_id'], unique=True)
- op.create_index('repomirrorconfig_root_rule_id', 'repomirrorconfig', ['root_rule_id'], unique=False)
- op.create_index('repomirrorconfig_sync_status', 'repomirrorconfig', ['sync_status'], unique=False)
- op.create_index('repomirrorconfig_sync_transaction_id', 'repomirrorconfig', ['sync_transaction_id'], unique=False)
- op.create_index('repomirrorconfig_internal_robot_id', 'repomirrorconfig', ['internal_robot_id'], unique=False)
+ op.create_table(
+ "repomirrorconfig",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("creation_date", sa.DateTime(), nullable=False),
+ sa.Column("is_enabled", sa.Boolean(), nullable=False),
+ sa.Column("mirror_type", sa.Integer(), nullable=False),
+ sa.Column("internal_robot_id", sa.Integer(), nullable=False),
+ sa.Column("external_registry", sa.String(length=255), nullable=False),
+ sa.Column("external_namespace", sa.String(length=255), nullable=False),
+ sa.Column("external_repository", sa.String(length=255), nullable=False),
+ sa.Column("external_registry_username", sa.String(length=2048), nullable=True),
+ sa.Column("external_registry_password", sa.String(length=2048), nullable=True),
+ sa.Column("external_registry_config", sa.Text(), nullable=False),
+ sa.Column("sync_interval", sa.Integer(), nullable=False, server_default="60"),
+ sa.Column("sync_start_date", sa.DateTime(), nullable=True),
+ sa.Column("sync_expiration_date", sa.DateTime(), nullable=True),
+ sa.Column(
+ "sync_retries_remaining", sa.Integer(), nullable=False, server_default="3"
+ ),
+ sa.Column("sync_status", sa.Integer(), nullable=False),
+ sa.Column("sync_transaction_id", sa.String(length=36), nullable=True),
+ sa.Column("root_rule_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_repomirrorconfig_repository_id_repository"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["root_rule_id"],
+ ["repomirrorrule.id"],
+ name=op.f("fk_repomirrorconfig_root_rule_id_repomirrorrule"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["internal_robot_id"],
+ ["user.id"],
+ name=op.f("fk_repomirrorconfig_internal_robot_id_user"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_repomirrorconfig")),
+ )
+ op.create_index(
+ "repomirrorconfig_mirror_type",
+ "repomirrorconfig",
+ ["mirror_type"],
+ unique=False,
+ )
+ op.create_index(
+ "repomirrorconfig_repository_id",
+ "repomirrorconfig",
+ ["repository_id"],
+ unique=True,
+ )
+ op.create_index(
+ "repomirrorconfig_root_rule_id",
+ "repomirrorconfig",
+ ["root_rule_id"],
+ unique=False,
+ )
+ op.create_index(
+ "repomirrorconfig_sync_status",
+ "repomirrorconfig",
+ ["sync_status"],
+ unique=False,
+ )
+ op.create_index(
+ "repomirrorconfig_sync_transaction_id",
+ "repomirrorconfig",
+ ["sync_transaction_id"],
+ unique=False,
+ )
+ op.create_index(
+ "repomirrorconfig_internal_robot_id",
+ "repomirrorconfig",
+ ["internal_robot_id"],
+ unique=False,
+ )
- op.add_column(u'repository', sa.Column('state', sa.Integer(), nullable=False, server_default='0'))
- op.create_index('repository_state', 'repository', ['state'], unique=False)
+ op.add_column(
+ u"repository",
+ sa.Column("state", sa.Integer(), nullable=False, server_default="0"),
+ )
+ op.create_index("repository_state", "repository", ["state"], unique=False)
- op.bulk_insert(tables.logentrykind,
- [
- {'name': 'repo_mirror_enabled'},
- {'name': 'repo_mirror_disabled'},
- {'name': 'repo_mirror_config_changed'},
- {'name': 'repo_mirror_sync_started'},
- {'name': 'repo_mirror_sync_failed'},
- {'name': 'repo_mirror_sync_success'},
- {'name': 'repo_mirror_sync_now_requested'},
- {'name': 'repo_mirror_sync_tag_success'},
- {'name': 'repo_mirror_sync_tag_failed'},
- {'name': 'repo_mirror_sync_test_success'},
- {'name': 'repo_mirror_sync_test_failed'},
- {'name': 'repo_mirror_sync_test_started'},
- {'name': 'change_repo_state'}
- ])
+ op.bulk_insert(
+ tables.logentrykind,
+ [
+ {"name": "repo_mirror_enabled"},
+ {"name": "repo_mirror_disabled"},
+ {"name": "repo_mirror_config_changed"},
+ {"name": "repo_mirror_sync_started"},
+ {"name": "repo_mirror_sync_failed"},
+ {"name": "repo_mirror_sync_success"},
+ {"name": "repo_mirror_sync_now_requested"},
+ {"name": "repo_mirror_sync_tag_success"},
+ {"name": "repo_mirror_sync_tag_failed"},
+ {"name": "repo_mirror_sync_test_success"},
+ {"name": "repo_mirror_sync_test_failed"},
+ {"name": "repo_mirror_sync_test_started"},
+ {"name": "change_repo_state"},
+ ],
+ )
+ tester.populate_table(
+ "repomirrorrule",
+ [
+ ("uuid", tester.TestDataType.String),
+ ("repository_id", tester.TestDataType.Foreign("repository")),
+ ("creation_date", tester.TestDataType.DateTime),
+ ("rule_type", tester.TestDataType.Integer),
+ ("rule_value", tester.TestDataType.String),
+ ],
+ )
- tester.populate_table('repomirrorrule', [
- ('uuid', tester.TestDataType.String),
- ('repository_id', tester.TestDataType.Foreign('repository')),
- ('creation_date', tester.TestDataType.DateTime),
- ('rule_type', tester.TestDataType.Integer),
- ('rule_value', tester.TestDataType.String),
- ])
-
- tester.populate_table('repomirrorconfig', [
- ('repository_id', tester.TestDataType.Foreign('repository')),
- ('creation_date', tester.TestDataType.DateTime),
- ('is_enabled', tester.TestDataType.Boolean),
- ('mirror_type', tester.TestDataType.Constant(1)),
- ('internal_robot_id', tester.TestDataType.Foreign('user')),
- ('external_registry', tester.TestDataType.String),
- ('external_namespace', tester.TestDataType.String),
- ('external_repository', tester.TestDataType.String),
- ('external_registry_username', tester.TestDataType.String),
- ('external_registry_password', tester.TestDataType.String),
- ('external_registry_config', tester.TestDataType.JSON),
- ('sync_start_date', tester.TestDataType.DateTime),
- ('sync_expiration_date', tester.TestDataType.DateTime),
- ('sync_retries_remaining', tester.TestDataType.Integer),
- ('sync_status', tester.TestDataType.Constant(0)),
- ('sync_transaction_id', tester.TestDataType.String),
- ('root_rule_id', tester.TestDataType.Foreign('repomirrorrule')),
- ])
+ tester.populate_table(
+ "repomirrorconfig",
+ [
+ ("repository_id", tester.TestDataType.Foreign("repository")),
+ ("creation_date", tester.TestDataType.DateTime),
+ ("is_enabled", tester.TestDataType.Boolean),
+ ("mirror_type", tester.TestDataType.Constant(1)),
+ ("internal_robot_id", tester.TestDataType.Foreign("user")),
+ ("external_registry", tester.TestDataType.String),
+ ("external_namespace", tester.TestDataType.String),
+ ("external_repository", tester.TestDataType.String),
+ ("external_registry_username", tester.TestDataType.String),
+ ("external_registry_password", tester.TestDataType.String),
+ ("external_registry_config", tester.TestDataType.JSON),
+ ("sync_start_date", tester.TestDataType.DateTime),
+ ("sync_expiration_date", tester.TestDataType.DateTime),
+ ("sync_retries_remaining", tester.TestDataType.Integer),
+ ("sync_status", tester.TestDataType.Constant(0)),
+ ("sync_transaction_id", tester.TestDataType.String),
+ ("root_rule_id", tester.TestDataType.Foreign("repomirrorrule")),
+ ],
+ )
def downgrade(tables, tester, progress_reporter):
- op = ProgressWrapper(original_op, progress_reporter)
- op.drop_column(u'repository', 'state')
+ op = ProgressWrapper(original_op, progress_reporter)
+ op.drop_column(u"repository", "state")
- op.drop_table('repomirrorconfig')
+ op.drop_table("repomirrorconfig")
- op.drop_table('repomirrorrule')
+ op.drop_table("repomirrorrule")
- for logentrykind in [
- 'repo_mirror_enabled',
- 'repo_mirror_disabled',
- 'repo_mirror_config_changed',
- 'repo_mirror_sync_started',
- 'repo_mirror_sync_failed',
- 'repo_mirror_sync_success',
- 'repo_mirror_sync_now_requested',
- 'repo_mirror_sync_tag_success',
- 'repo_mirror_sync_tag_failed',
- 'repo_mirror_sync_test_success',
- 'repo_mirror_sync_test_failed',
- 'repo_mirror_sync_test_started',
- 'change_repo_state'
- ]:
- op.execute(tables.logentrykind.delete()
- .where(tables.logentrykind.c.name == op.inline_literal(logentrykind)))
+ for logentrykind in [
+ "repo_mirror_enabled",
+ "repo_mirror_disabled",
+ "repo_mirror_config_changed",
+ "repo_mirror_sync_started",
+ "repo_mirror_sync_failed",
+ "repo_mirror_sync_success",
+ "repo_mirror_sync_now_requested",
+ "repo_mirror_sync_tag_success",
+ "repo_mirror_sync_tag_failed",
+ "repo_mirror_sync_test_success",
+ "repo_mirror_sync_test_failed",
+ "repo_mirror_sync_test_started",
+ "change_repo_state",
+ ]:
+ op.execute(
+ tables.logentrykind.delete().where(
+ tables.logentrykind.c.name == op.inline_literal(logentrykind)
+ )
+ )
diff --git a/data/migrations/versions/53e2ac668296_remove_reference_to_subdir.py b/data/migrations/versions/53e2ac668296_remove_reference_to_subdir.py
index e0b61814b..b30ad7ff5 100644
--- a/data/migrations/versions/53e2ac668296_remove_reference_to_subdir.py
+++ b/data/migrations/versions/53e2ac668296_remove_reference_to_subdir.py
@@ -17,38 +17,44 @@ from data.migrations.progress import ProgressWrapper
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
-revision = '53e2ac668296'
-down_revision = 'ed01e313d3cb'
+revision = "53e2ac668296"
+down_revision = "ed01e313d3cb"
log = logging.getLogger(__name__)
def run_migration(migrate_function, progress_reporter):
- op = ProgressWrapper(original_op, progress_reporter)
- conn = op.get_bind()
- triggers = conn.execute("SELECT id, config FROM repositorybuildtrigger")
- for trigger in triggers:
- config = json.dumps(migrate_function(json.loads(trigger[1])))
- try:
- conn.execute("UPDATE repositorybuildtrigger SET config=%s WHERE id=%s", config, trigger[0])
- except(RevisionError, CommandError) as e:
- log.warning("Failed to update build trigger %s with exception: ", trigger[0], e)
+ op = ProgressWrapper(original_op, progress_reporter)
+ conn = op.get_bind()
+ triggers = conn.execute("SELECT id, config FROM repositorybuildtrigger")
+ for trigger in triggers:
+ config = json.dumps(migrate_function(json.loads(trigger[1])))
+ try:
+ conn.execute(
+ "UPDATE repositorybuildtrigger SET config=%s WHERE id=%s",
+ config,
+ trigger[0],
+ )
+ except (RevisionError, CommandError) as e:
+ log.warning(
+ "Failed to update build trigger %s with exception: %s", trigger[0], e
+ )
def upgrade(tables, tester, progress_reporter):
- run_migration(delete_subdir, progress_reporter)
+ run_migration(delete_subdir, progress_reporter)
def downgrade(tables, tester, progress_reporter):
- run_migration(add_subdir, progress_reporter)
+ run_migration(add_subdir, progress_reporter)
def delete_subdir(config):
""" Remove subdir from config """
if not config:
return config
- if 'subdir' in config:
- del config['subdir']
+ if "subdir" in config:
+ del config["subdir"]
return config
@@ -57,7 +63,7 @@ def add_subdir(config):
""" Add subdir back into config """
if not config:
return config
- if 'context' in config:
- config['subdir'] = config['context']
+ if "context" in config:
+ config["subdir"] = config["context"]
return config
diff --git a/data/migrations/versions/54492a68a3cf_add_namespacegeorestriction_table.py b/data/migrations/versions/54492a68a3cf_add_namespacegeorestriction_table.py
index efe900ad7..8ac76c4d6 100644
--- a/data/migrations/versions/54492a68a3cf_add_namespacegeorestriction_table.py
+++ b/data/migrations/versions/54492a68a3cf_add_namespacegeorestriction_table.py
@@ -7,43 +7,67 @@ Create Date: 2018-12-05 15:12:14.201116
"""
# revision identifiers, used by Alembic.
-revision = '54492a68a3cf'
-down_revision = 'c00a1f15968b'
+revision = "54492a68a3cf"
+down_revision = "c00a1f15968b"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.create_table('namespacegeorestriction',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('namespace_id', sa.Integer(), nullable=False),
- sa.Column('added', sa.DateTime(), nullable=False),
- sa.Column('description', sa.String(length=255), nullable=False),
- sa.Column('unstructured_json', sa.Text(), nullable=False),
- sa.Column('restricted_region_iso_code', sa.String(length=255), nullable=False),
- sa.ForeignKeyConstraint(['namespace_id'], ['user.id'], name=op.f('fk_namespacegeorestriction_namespace_id_user')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_namespacegeorestriction'))
+ op.create_table(
+ "namespacegeorestriction",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("namespace_id", sa.Integer(), nullable=False),
+ sa.Column("added", sa.DateTime(), nullable=False),
+ sa.Column("description", sa.String(length=255), nullable=False),
+ sa.Column("unstructured_json", sa.Text(), nullable=False),
+ sa.Column("restricted_region_iso_code", sa.String(length=255), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["namespace_id"],
+ ["user.id"],
+ name=op.f("fk_namespacegeorestriction_namespace_id_user"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_namespacegeorestriction")),
+ )
+ op.create_index(
+ "namespacegeorestriction_namespace_id",
+ "namespacegeorestriction",
+ ["namespace_id"],
+ unique=False,
+ )
+ op.create_index(
+ "namespacegeorestriction_namespace_id_restricted_region_iso_code",
+ "namespacegeorestriction",
+ ["namespace_id", "restricted_region_iso_code"],
+ unique=True,
+ )
+ op.create_index(
+ "namespacegeorestriction_restricted_region_iso_code",
+ "namespacegeorestriction",
+ ["restricted_region_iso_code"],
+ unique=False,
)
- op.create_index('namespacegeorestriction_namespace_id', 'namespacegeorestriction', ['namespace_id'], unique=False)
- op.create_index('namespacegeorestriction_namespace_id_restricted_region_iso_code', 'namespacegeorestriction', ['namespace_id', 'restricted_region_iso_code'], unique=True)
- op.create_index('namespacegeorestriction_restricted_region_iso_code', 'namespacegeorestriction', ['restricted_region_iso_code'], unique=False)
# ### end Alembic commands ###
- tester.populate_table('namespacegeorestriction', [
- ('namespace_id', tester.TestDataType.Foreign('user')),
- ('added', tester.TestDataType.DateTime),
- ('description', tester.TestDataType.String),
- ('unstructured_json', tester.TestDataType.JSON),
- ('restricted_region_iso_code', tester.TestDataType.String),
- ])
+ tester.populate_table(
+ "namespacegeorestriction",
+ [
+ ("namespace_id", tester.TestDataType.Foreign("user")),
+ ("added", tester.TestDataType.DateTime),
+ ("description", tester.TestDataType.String),
+ ("unstructured_json", tester.TestDataType.JSON),
+ ("restricted_region_iso_code", tester.TestDataType.String),
+ ],
+ )
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_table('namespacegeorestriction')
+ op.drop_table("namespacegeorestriction")
# ### end Alembic commands ###
diff --git a/data/migrations/versions/5b7503aada1b_cleanup_old_robots.py b/data/migrations/versions/5b7503aada1b_cleanup_old_robots.py
index 89b469d6b..b5b6923bc 100644
--- a/data/migrations/versions/5b7503aada1b_cleanup_old_robots.py
+++ b/data/migrations/versions/5b7503aada1b_cleanup_old_robots.py
@@ -7,8 +7,8 @@ Create Date: 2018-05-09 17:18:52.230504
"""
# revision identifiers, used by Alembic.
-revision = '5b7503aada1b'
-down_revision = '224ce4c72c2f'
+revision = "5b7503aada1b"
+down_revision = "224ce4c72c2f"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
@@ -16,11 +16,13 @@ import sqlalchemy as sa
from util.migrate.cleanup_old_robots import cleanup_old_robots
+
def upgrade(tables, tester, progress_reporter):
- op = ProgressWrapper(original_op, progress_reporter)
- cleanup_old_robots()
+ op = ProgressWrapper(original_op, progress_reporter)
+ cleanup_old_robots()
+
def downgrade(tables, tester, progress_reporter):
- op = ProgressWrapper(original_op, progress_reporter)
- # Nothing to do.
- pass
+ op = ProgressWrapper(original_op, progress_reporter)
+ # Nothing to do.
+ pass
diff --git a/data/migrations/versions/5cbbfc95bac7_remove_oci_tables_not_used_by_cnr_the_.py b/data/migrations/versions/5cbbfc95bac7_remove_oci_tables_not_used_by_cnr_the_.py
index 46a2c3cec..73917f139 100644
--- a/data/migrations/versions/5cbbfc95bac7_remove_oci_tables_not_used_by_cnr_the_.py
+++ b/data/migrations/versions/5cbbfc95bac7_remove_oci_tables_not_used_by_cnr_the_.py
@@ -7,8 +7,8 @@ Create Date: 2018-05-23 17:28:40.114433
"""
# revision identifiers, used by Alembic.
-revision = '5cbbfc95bac7'
-down_revision = '1783530bee68'
+revision = "5cbbfc95bac7"
+down_revision = "1783530bee68"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
@@ -16,17 +16,18 @@ import sqlalchemy as sa
from sqlalchemy.dialects import mysql
from util.migrate import UTF8LongText, UTF8CharField
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_table('derivedimage')
- op.drop_table('manifestlabel')
- op.drop_table('blobplacementlocationpreference')
- op.drop_table('blobuploading')
- op.drop_table('bittorrentpieces')
- op.drop_table('manifestlayerdockerv1')
- op.drop_table('manifestlayerscan')
- op.drop_table('manifestlayer')
+ op.drop_table("derivedimage")
+ op.drop_table("manifestlabel")
+ op.drop_table("blobplacementlocationpreference")
+ op.drop_table("blobuploading")
+ op.drop_table("bittorrentpieces")
+ op.drop_table("manifestlayerdockerv1")
+ op.drop_table("manifestlayerscan")
+ op.drop_table("manifestlayer")
# ### end Alembic commands ###
@@ -34,137 +35,277 @@ def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
- 'manifestlayer',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('blob_id', sa.Integer(), nullable=False),
- sa.Column('manifest_id', sa.Integer(), nullable=False),
- sa.Column('manifest_index', sa.BigInteger(), nullable=False),
- sa.Column('metadata_json', UTF8LongText, nullable=False),
- sa.ForeignKeyConstraint(['blob_id'], ['blob.id'], name=op.f('fk_manifestlayer_blob_id_blob')),
- sa.ForeignKeyConstraint(['manifest_id'], ['manifest.id'], name=op.f('fk_manifestlayer_manifest_id_manifest')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_manifestlayer'))
+ "manifestlayer",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("blob_id", sa.Integer(), nullable=False),
+ sa.Column("manifest_id", sa.Integer(), nullable=False),
+ sa.Column("manifest_index", sa.BigInteger(), nullable=False),
+ sa.Column("metadata_json", UTF8LongText, nullable=False),
+ sa.ForeignKeyConstraint(
+ ["blob_id"], ["blob.id"], name=op.f("fk_manifestlayer_blob_id_blob")
+ ),
+ sa.ForeignKeyConstraint(
+ ["manifest_id"],
+ ["manifest.id"],
+ name=op.f("fk_manifestlayer_manifest_id_manifest"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_manifestlayer")),
)
- op.create_index('manifestlayer_manifest_index', 'manifestlayer', ['manifest_index'], unique=False)
- op.create_index('manifestlayer_manifest_id_manifest_index', 'manifestlayer', ['manifest_id', 'manifest_index'], unique=True)
- op.create_index('manifestlayer_manifest_id', 'manifestlayer', ['manifest_id'], unique=False)
- op.create_index('manifestlayer_blob_id', 'manifestlayer', ['blob_id'], unique=False)
+ op.create_index(
+ "manifestlayer_manifest_index",
+ "manifestlayer",
+ ["manifest_index"],
+ unique=False,
+ )
+ op.create_index(
+ "manifestlayer_manifest_id_manifest_index",
+ "manifestlayer",
+ ["manifest_id", "manifest_index"],
+ unique=True,
+ )
+ op.create_index(
+ "manifestlayer_manifest_id", "manifestlayer", ["manifest_id"], unique=False
+ )
+ op.create_index("manifestlayer_blob_id", "manifestlayer", ["blob_id"], unique=False)
op.create_table(
- 'manifestlayerscan',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('layer_id', sa.Integer(), nullable=False),
- sa.Column('scannable', sa.Boolean(), nullable=False),
- sa.Column('scanned_by', UTF8CharField(length=255), nullable=False),
- sa.ForeignKeyConstraint(['layer_id'], ['manifestlayer.id'], name=op.f('fk_manifestlayerscan_layer_id_manifestlayer')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_manifestlayerscan'))
+ "manifestlayerscan",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("layer_id", sa.Integer(), nullable=False),
+ sa.Column("scannable", sa.Boolean(), nullable=False),
+ sa.Column("scanned_by", UTF8CharField(length=255), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["layer_id"],
+ ["manifestlayer.id"],
+ name=op.f("fk_manifestlayerscan_layer_id_manifestlayer"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_manifestlayerscan")),
+ )
+
+ op.create_index(
+ "manifestlayerscan_layer_id", "manifestlayerscan", ["layer_id"], unique=True
)
-
- op.create_index('manifestlayerscan_layer_id', 'manifestlayerscan', ['layer_id'], unique=True)
op.create_table(
- 'bittorrentpieces',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('blob_id', sa.Integer(), nullable=False),
- sa.Column('pieces', UTF8LongText, nullable=False),
- sa.Column('piece_length', sa.BigInteger(), nullable=False),
- sa.ForeignKeyConstraint(['blob_id'], ['blob.id'], name=op.f('fk_bittorrentpieces_blob_id_blob')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_bittorrentpieces'))
+ "bittorrentpieces",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("blob_id", sa.Integer(), nullable=False),
+ sa.Column("pieces", UTF8LongText, nullable=False),
+ sa.Column("piece_length", sa.BigInteger(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["blob_id"], ["blob.id"], name=op.f("fk_bittorrentpieces_blob_id_blob")
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_bittorrentpieces")),
)
- op.create_index('bittorrentpieces_blob_id_piece_length', 'bittorrentpieces', ['blob_id', 'piece_length'], unique=True)
- op.create_index('bittorrentpieces_blob_id', 'bittorrentpieces', ['blob_id'], unique=False)
+ op.create_index(
+ "bittorrentpieces_blob_id_piece_length",
+ "bittorrentpieces",
+ ["blob_id", "piece_length"],
+ unique=True,
+ )
+ op.create_index(
+ "bittorrentpieces_blob_id", "bittorrentpieces", ["blob_id"], unique=False
+ )
op.create_table(
- 'blobuploading',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('uuid', sa.String(length=255), nullable=False),
- sa.Column('created', sa.DateTime(), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('location_id', sa.Integer(), nullable=False),
- sa.Column('byte_count', sa.BigInteger(), nullable=False),
- sa.Column('uncompressed_byte_count', sa.BigInteger(), nullable=True),
- sa.Column('chunk_count', sa.BigInteger(), nullable=False),
- sa.Column('storage_metadata', UTF8LongText, nullable=True),
- sa.Column('sha_state', UTF8LongText, nullable=True),
- sa.Column('piece_sha_state', UTF8LongText, nullable=True),
- sa.Column('piece_hashes', UTF8LongText, nullable=True),
- sa.ForeignKeyConstraint(['location_id'], ['blobplacementlocation.id'], name=op.f('fk_blobuploading_location_id_blobplacementlocation')),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_blobuploading_repository_id_repository')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_blobuploading'))
+ "blobuploading",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("uuid", sa.String(length=255), nullable=False),
+ sa.Column("created", sa.DateTime(), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("location_id", sa.Integer(), nullable=False),
+ sa.Column("byte_count", sa.BigInteger(), nullable=False),
+ sa.Column("uncompressed_byte_count", sa.BigInteger(), nullable=True),
+ sa.Column("chunk_count", sa.BigInteger(), nullable=False),
+ sa.Column("storage_metadata", UTF8LongText, nullable=True),
+ sa.Column("sha_state", UTF8LongText, nullable=True),
+ sa.Column("piece_sha_state", UTF8LongText, nullable=True),
+ sa.Column("piece_hashes", UTF8LongText, nullable=True),
+ sa.ForeignKeyConstraint(
+ ["location_id"],
+ ["blobplacementlocation.id"],
+ name=op.f("fk_blobuploading_location_id_blobplacementlocation"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_blobuploading_repository_id_repository"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_blobuploading")),
)
-
- op.create_index('blobuploading_uuid', 'blobuploading', ['uuid'], unique=True)
- op.create_index('blobuploading_repository_id_uuid', 'blobuploading', ['repository_id', 'uuid'], unique=True)
- op.create_index('blobuploading_repository_id', 'blobuploading', ['repository_id'], unique=False)
- op.create_index('blobuploading_location_id', 'blobuploading', ['location_id'], unique=False)
- op.create_index('blobuploading_created', 'blobuploading', ['created'], unique=False)
+
+ op.create_index("blobuploading_uuid", "blobuploading", ["uuid"], unique=True)
+ op.create_index(
+ "blobuploading_repository_id_uuid",
+ "blobuploading",
+ ["repository_id", "uuid"],
+ unique=True,
+ )
+ op.create_index(
+ "blobuploading_repository_id", "blobuploading", ["repository_id"], unique=False
+ )
+ op.create_index(
+ "blobuploading_location_id", "blobuploading", ["location_id"], unique=False
+ )
+ op.create_index("blobuploading_created", "blobuploading", ["created"], unique=False)
op.create_table(
- 'manifestlayerdockerv1',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('manifest_layer_id', sa.Integer(), nullable=False),
- sa.Column('image_id', UTF8CharField(length=255), nullable=False),
- sa.Column('checksum', UTF8CharField(length=255), nullable=False),
- sa.Column('compat_json', UTF8LongText, nullable=False),
- sa.ForeignKeyConstraint(['manifest_layer_id'], ['manifestlayer.id'], name=op.f('fk_manifestlayerdockerv1_manifest_layer_id_manifestlayer')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_manifestlayerdockerv1'))
+ "manifestlayerdockerv1",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("manifest_layer_id", sa.Integer(), nullable=False),
+ sa.Column("image_id", UTF8CharField(length=255), nullable=False),
+ sa.Column("checksum", UTF8CharField(length=255), nullable=False),
+ sa.Column("compat_json", UTF8LongText, nullable=False),
+ sa.ForeignKeyConstraint(
+ ["manifest_layer_id"],
+ ["manifestlayer.id"],
+ name=op.f("fk_manifestlayerdockerv1_manifest_layer_id_manifestlayer"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_manifestlayerdockerv1")),
)
- op.create_index('manifestlayerdockerv1_manifest_layer_id', 'manifestlayerdockerv1', ['manifest_layer_id'], unique=False)
- op.create_index('manifestlayerdockerv1_image_id', 'manifestlayerdockerv1', ['image_id'], unique=False)
+ op.create_index(
+ "manifestlayerdockerv1_manifest_layer_id",
+ "manifestlayerdockerv1",
+ ["manifest_layer_id"],
+ unique=False,
+ )
+ op.create_index(
+ "manifestlayerdockerv1_image_id",
+ "manifestlayerdockerv1",
+ ["image_id"],
+ unique=False,
+ )
op.create_table(
- 'manifestlabel',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('annotated_id', sa.Integer(), nullable=False),
- sa.Column('label_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['annotated_id'], ['manifest.id'], name=op.f('fk_manifestlabel_annotated_id_manifest')),
- sa.ForeignKeyConstraint(['label_id'], ['label.id'], name=op.f('fk_manifestlabel_label_id_label')),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_manifestlabel_repository_id_repository')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_manifestlabel'))
+ "manifestlabel",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("annotated_id", sa.Integer(), nullable=False),
+ sa.Column("label_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["annotated_id"],
+ ["manifest.id"],
+ name=op.f("fk_manifestlabel_annotated_id_manifest"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["label_id"], ["label.id"], name=op.f("fk_manifestlabel_label_id_label")
+ ),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_manifestlabel_repository_id_repository"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_manifestlabel")),
)
- op.create_index('manifestlabel_repository_id_annotated_id_label_id', 'manifestlabel', ['repository_id', 'annotated_id', 'label_id'], unique=True)
- op.create_index('manifestlabel_repository_id', 'manifestlabel', ['repository_id'], unique=False)
- op.create_index('manifestlabel_label_id', 'manifestlabel', ['label_id'], unique=False)
- op.create_index('manifestlabel_annotated_id', 'manifestlabel', ['annotated_id'], unique=False)
+ op.create_index(
+ "manifestlabel_repository_id_annotated_id_label_id",
+ "manifestlabel",
+ ["repository_id", "annotated_id", "label_id"],
+ unique=True,
+ )
+ op.create_index(
+ "manifestlabel_repository_id", "manifestlabel", ["repository_id"], unique=False
+ )
+ op.create_index(
+ "manifestlabel_label_id", "manifestlabel", ["label_id"], unique=False
+ )
+ op.create_index(
+ "manifestlabel_annotated_id", "manifestlabel", ["annotated_id"], unique=False
+ )
op.create_table(
- 'blobplacementlocationpreference',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('user_id', sa.Integer(), nullable=False),
- sa.Column('location_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['location_id'], ['blobplacementlocation.id'], name=op.f('fk_blobplacementlocpref_locid_blobplacementlocation')),
- sa.ForeignKeyConstraint(['user_id'], ['user.id'], name=op.f('fk_blobplacementlocationpreference_user_id_user')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_blobplacementlocationpreference'))
+ "blobplacementlocationpreference",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("user_id", sa.Integer(), nullable=False),
+ sa.Column("location_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["location_id"],
+ ["blobplacementlocation.id"],
+ name=op.f("fk_blobplacementlocpref_locid_blobplacementlocation"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["user_id"],
+ ["user.id"],
+ name=op.f("fk_blobplacementlocationpreference_user_id_user"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_blobplacementlocationpreference")),
+ )
+ op.create_index(
+ "blobplacementlocationpreference_user_id",
+ "blobplacementlocationpreference",
+ ["user_id"],
+ unique=False,
+ )
+ op.create_index(
+ "blobplacementlocationpreference_location_id",
+ "blobplacementlocationpreference",
+ ["location_id"],
+ unique=False,
)
- op.create_index('blobplacementlocationpreference_user_id', 'blobplacementlocationpreference', ['user_id'], unique=False)
- op.create_index('blobplacementlocationpreference_location_id', 'blobplacementlocationpreference', ['location_id'], unique=False)
-
op.create_table(
- 'derivedimage',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('uuid', sa.String(length=255), nullable=False),
- sa.Column('source_manifest_id', sa.Integer(), nullable=False),
- sa.Column('derived_manifest_json', UTF8LongText, nullable=False),
- sa.Column('media_type_id', sa.Integer(), nullable=False),
- sa.Column('blob_id', sa.Integer(), nullable=False),
- sa.Column('uniqueness_hash', sa.String(length=255), nullable=False),
- sa.Column('signature_blob_id', sa.Integer(), nullable=True),
- sa.ForeignKeyConstraint(['blob_id'], ['blob.id'], name=op.f('fk_derivedimage_blob_id_blob')),
- sa.ForeignKeyConstraint(['media_type_id'], ['mediatype.id'], name=op.f('fk_derivedimage_media_type_id_mediatype')),
- sa.ForeignKeyConstraint(['signature_blob_id'], ['blob.id'], name=op.f('fk_derivedimage_signature_blob_id_blob')),
- sa.ForeignKeyConstraint(['source_manifest_id'], ['manifest.id'], name=op.f('fk_derivedimage_source_manifest_id_manifest')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_derivedimage'))
+ "derivedimage",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("uuid", sa.String(length=255), nullable=False),
+ sa.Column("source_manifest_id", sa.Integer(), nullable=False),
+ sa.Column("derived_manifest_json", UTF8LongText, nullable=False),
+ sa.Column("media_type_id", sa.Integer(), nullable=False),
+ sa.Column("blob_id", sa.Integer(), nullable=False),
+ sa.Column("uniqueness_hash", sa.String(length=255), nullable=False),
+ sa.Column("signature_blob_id", sa.Integer(), nullable=True),
+ sa.ForeignKeyConstraint(
+ ["blob_id"], ["blob.id"], name=op.f("fk_derivedimage_blob_id_blob")
+ ),
+ sa.ForeignKeyConstraint(
+ ["media_type_id"],
+ ["mediatype.id"],
+ name=op.f("fk_derivedimage_media_type_id_mediatype"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["signature_blob_id"],
+ ["blob.id"],
+ name=op.f("fk_derivedimage_signature_blob_id_blob"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["source_manifest_id"],
+ ["manifest.id"],
+ name=op.f("fk_derivedimage_source_manifest_id_manifest"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_derivedimage")),
)
- op.create_index('derivedimage_uuid', 'derivedimage', ['uuid'], unique=True)
- op.create_index('derivedimage_uniqueness_hash', 'derivedimage', ['uniqueness_hash'], unique=True)
- op.create_index('derivedimage_source_manifest_id_media_type_id_uniqueness_hash', 'derivedimage', ['source_manifest_id', 'media_type_id', 'uniqueness_hash'], unique=True)
- op.create_index('derivedimage_source_manifest_id_blob_id', 'derivedimage', ['source_manifest_id', 'blob_id'], unique=True)
- op.create_index('derivedimage_source_manifest_id', 'derivedimage', ['source_manifest_id'], unique=False)
- op.create_index('derivedimage_signature_blob_id', 'derivedimage', ['signature_blob_id'], unique=False)
- op.create_index('derivedimage_media_type_id', 'derivedimage', ['media_type_id'], unique=False)
- op.create_index('derivedimage_blob_id', 'derivedimage', ['blob_id'], unique=False)
+ op.create_index("derivedimage_uuid", "derivedimage", ["uuid"], unique=True)
+ op.create_index(
+ "derivedimage_uniqueness_hash", "derivedimage", ["uniqueness_hash"], unique=True
+ )
+ op.create_index(
+ "derivedimage_source_manifest_id_media_type_id_uniqueness_hash",
+ "derivedimage",
+ ["source_manifest_id", "media_type_id", "uniqueness_hash"],
+ unique=True,
+ )
+ op.create_index(
+ "derivedimage_source_manifest_id_blob_id",
+ "derivedimage",
+ ["source_manifest_id", "blob_id"],
+ unique=True,
+ )
+ op.create_index(
+ "derivedimage_source_manifest_id",
+ "derivedimage",
+ ["source_manifest_id"],
+ unique=False,
+ )
+ op.create_index(
+ "derivedimage_signature_blob_id",
+ "derivedimage",
+ ["signature_blob_id"],
+ unique=False,
+ )
+ op.create_index(
+ "derivedimage_media_type_id", "derivedimage", ["media_type_id"], unique=False
+ )
+ op.create_index("derivedimage_blob_id", "derivedimage", ["blob_id"], unique=False)
# ### end Alembic commands ###
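
Note: the hunks above only reflow existing DDL; the schema itself is unchanged. For reference, the pattern used throughout these migrations, condensed into a minimal standalone sketch (table and column names here are illustrative, and the ProgressWrapper/tester plumbing this repo adds is omitted):

# Minimal sketch of the create_table / create_index pattern above.
import sqlalchemy as sa
from alembic import op


def upgrade():
    op.create_table(
        "examplechild",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("parent_id", sa.Integer(), nullable=False),
        # op.f() marks the name as already following the configured naming
        # convention, keeping constraint names stable and reversible.
        sa.ForeignKeyConstraint(
            ["parent_id"],
            ["exampleparent.id"],
            name=op.f("fk_examplechild_parent_id_exampleparent"),
        ),
        sa.PrimaryKeyConstraint("id", name=op.f("pk_examplechild")),
    )
    # Indexes are created (and named) explicitly so they can be dropped or
    # rebuilt individually in later migrations.
    op.create_index(
        "examplechild_parent_id", "examplechild", ["parent_id"], unique=False
    )


def downgrade():
    # Dropping the table also drops its indexes and constraints.
    op.drop_table("examplechild")
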
diff --git a/data/migrations/versions/5d463ea1e8a8_backfill_new_appr_tables.py b/data/migrations/versions/5d463ea1e8a8_backfill_new_appr_tables.py
index a0df295dc..c357b9256 100644
--- a/data/migrations/versions/5d463ea1e8a8_backfill_new_appr_tables.py
+++ b/data/migrations/versions/5d463ea1e8a8_backfill_new_appr_tables.py
@@ -7,25 +7,27 @@ Create Date: 2018-07-08 10:01:19.756126
"""
# revision identifiers, used by Alembic.
-revision = '5d463ea1e8a8'
-down_revision = '610320e9dacf'
+revision = "5d463ea1e8a8"
+down_revision = "610320e9dacf"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
import sqlalchemy as sa
from util.migrate.table_ops import copy_table_contents
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
conn = op.get_bind()
- copy_table_contents('blob', 'apprblob', conn)
- copy_table_contents('manifest', 'apprmanifest', conn)
- copy_table_contents('manifestlist', 'apprmanifestlist', conn)
- copy_table_contents('blobplacement', 'apprblobplacement', conn)
- copy_table_contents('manifestblob', 'apprmanifestblob', conn)
- copy_table_contents('manifestlistmanifest', 'apprmanifestlistmanifest', conn)
- copy_table_contents('tag', 'apprtag', conn)
+ copy_table_contents("blob", "apprblob", conn)
+ copy_table_contents("manifest", "apprmanifest", conn)
+ copy_table_contents("manifestlist", "apprmanifestlist", conn)
+ copy_table_contents("blobplacement", "apprblobplacement", conn)
+ copy_table_contents("manifestblob", "apprmanifestblob", conn)
+ copy_table_contents("manifestlistmanifest", "apprmanifestlistmanifest", conn)
+ copy_table_contents("tag", "apprtag", conn)
+
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
diff --git a/data/migrations/versions/610320e9dacf_add_new_appr_specific_tables.py b/data/migrations/versions/610320e9dacf_add_new_appr_specific_tables.py
index 99c365260..dad746d2d 100644
--- a/data/migrations/versions/610320e9dacf_add_new_appr_specific_tables.py
+++ b/data/migrations/versions/610320e9dacf_add_new_appr_specific_tables.py
@@ -7,8 +7,8 @@ Create Date: 2018-05-24 16:46:13.514562
"""
# revision identifiers, used by Alembic.
-revision = '610320e9dacf'
-down_revision = '5cbbfc95bac7'
+revision = "610320e9dacf"
+down_revision = "5cbbfc95bac7"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
@@ -16,191 +16,356 @@ import sqlalchemy as sa
from util.migrate.table_ops import copy_table_contents
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.create_table('apprblobplacementlocation',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_apprblobplacementlocation'))
+ op.create_table(
+ "apprblobplacementlocation",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_apprblobplacementlocation")),
)
- op.create_index('apprblobplacementlocation_name', 'apprblobplacementlocation', ['name'], unique=True)
- op.create_table('apprtagkind',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_apprtagkind'))
+ op.create_index(
+ "apprblobplacementlocation_name",
+ "apprblobplacementlocation",
+ ["name"],
+ unique=True,
)
- op.create_index('apprtagkind_name', 'apprtagkind', ['name'], unique=True)
- op.create_table('apprblob',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('digest', sa.String(length=255), nullable=False),
- sa.Column('media_type_id', sa.Integer(), nullable=False),
- sa.Column('size', sa.BigInteger(), nullable=False),
- sa.Column('uncompressed_size', sa.BigInteger(), nullable=True),
- sa.ForeignKeyConstraint(['media_type_id'], ['mediatype.id'], name=op.f('fk_apprblob_media_type_id_mediatype')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_apprblob'))
+ op.create_table(
+ "apprtagkind",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_apprtagkind")),
)
- op.create_index('apprblob_digest', 'apprblob', ['digest'], unique=True)
- op.create_index('apprblob_media_type_id', 'apprblob', ['media_type_id'], unique=False)
- op.create_table('apprmanifest',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('digest', sa.String(length=255), nullable=False),
- sa.Column('media_type_id', sa.Integer(), nullable=False),
- sa.Column('manifest_json', sa.Text(), nullable=False),
- sa.ForeignKeyConstraint(['media_type_id'], ['mediatype.id'], name=op.f('fk_apprmanifest_media_type_id_mediatype')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_apprmanifest'))
+ op.create_index("apprtagkind_name", "apprtagkind", ["name"], unique=True)
+ op.create_table(
+ "apprblob",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("digest", sa.String(length=255), nullable=False),
+ sa.Column("media_type_id", sa.Integer(), nullable=False),
+ sa.Column("size", sa.BigInteger(), nullable=False),
+ sa.Column("uncompressed_size", sa.BigInteger(), nullable=True),
+ sa.ForeignKeyConstraint(
+ ["media_type_id"],
+ ["mediatype.id"],
+ name=op.f("fk_apprblob_media_type_id_mediatype"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_apprblob")),
)
- op.create_index('apprmanifest_digest', 'apprmanifest', ['digest'], unique=True)
- op.create_index('apprmanifest_media_type_id', 'apprmanifest', ['media_type_id'], unique=False)
- op.create_table('apprmanifestlist',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('digest', sa.String(length=255), nullable=False),
- sa.Column('manifest_list_json', sa.Text(), nullable=False),
- sa.Column('schema_version', sa.String(length=255), nullable=False),
- sa.Column('media_type_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['media_type_id'], ['mediatype.id'], name=op.f('fk_apprmanifestlist_media_type_id_mediatype')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_apprmanifestlist'))
+ op.create_index("apprblob_digest", "apprblob", ["digest"], unique=True)
+ op.create_index(
+ "apprblob_media_type_id", "apprblob", ["media_type_id"], unique=False
)
- op.create_index('apprmanifestlist_digest', 'apprmanifestlist', ['digest'], unique=True)
- op.create_index('apprmanifestlist_media_type_id', 'apprmanifestlist', ['media_type_id'], unique=False)
- op.create_table('apprblobplacement',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('blob_id', sa.Integer(), nullable=False),
- sa.Column('location_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['blob_id'], ['apprblob.id'], name=op.f('fk_apprblobplacement_blob_id_apprblob')),
- sa.ForeignKeyConstraint(['location_id'], ['apprblobplacementlocation.id'], name=op.f('fk_apprblobplacement_location_id_apprblobplacementlocation')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_apprblobplacement'))
+ op.create_table(
+ "apprmanifest",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("digest", sa.String(length=255), nullable=False),
+ sa.Column("media_type_id", sa.Integer(), nullable=False),
+ sa.Column("manifest_json", sa.Text(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["media_type_id"],
+ ["mediatype.id"],
+ name=op.f("fk_apprmanifest_media_type_id_mediatype"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_apprmanifest")),
)
- op.create_index('apprblobplacement_blob_id', 'apprblobplacement', ['blob_id'], unique=False)
- op.create_index('apprblobplacement_blob_id_location_id', 'apprblobplacement', ['blob_id', 'location_id'], unique=True)
- op.create_index('apprblobplacement_location_id', 'apprblobplacement', ['location_id'], unique=False)
- op.create_table('apprmanifestblob',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('manifest_id', sa.Integer(), nullable=False),
- sa.Column('blob_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['blob_id'], ['apprblob.id'], name=op.f('fk_apprmanifestblob_blob_id_apprblob')),
- sa.ForeignKeyConstraint(['manifest_id'], ['apprmanifest.id'], name=op.f('fk_apprmanifestblob_manifest_id_apprmanifest')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_apprmanifestblob'))
+ op.create_index("apprmanifest_digest", "apprmanifest", ["digest"], unique=True)
+ op.create_index(
+ "apprmanifest_media_type_id", "apprmanifest", ["media_type_id"], unique=False
)
- op.create_index('apprmanifestblob_blob_id', 'apprmanifestblob', ['blob_id'], unique=False)
- op.create_index('apprmanifestblob_manifest_id', 'apprmanifestblob', ['manifest_id'], unique=False)
- op.create_index('apprmanifestblob_manifest_id_blob_id', 'apprmanifestblob', ['manifest_id', 'blob_id'], unique=True)
- op.create_table('apprmanifestlistmanifest',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('manifest_list_id', sa.Integer(), nullable=False),
- sa.Column('manifest_id', sa.Integer(), nullable=False),
- sa.Column('operating_system', sa.String(length=255), nullable=True),
- sa.Column('architecture', sa.String(length=255), nullable=True),
- sa.Column('platform_json', sa.Text(), nullable=True),
- sa.Column('media_type_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['manifest_id'], ['apprmanifest.id'], name=op.f('fk_apprmanifestlistmanifest_manifest_id_apprmanifest')),
- sa.ForeignKeyConstraint(['manifest_list_id'], ['apprmanifestlist.id'], name=op.f('fk_apprmanifestlistmanifest_manifest_list_id_apprmanifestlist')),
- sa.ForeignKeyConstraint(['media_type_id'], ['mediatype.id'], name=op.f('fk_apprmanifestlistmanifest_media_type_id_mediatype')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_apprmanifestlistmanifest'))
+ op.create_table(
+ "apprmanifestlist",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("digest", sa.String(length=255), nullable=False),
+ sa.Column("manifest_list_json", sa.Text(), nullable=False),
+ sa.Column("schema_version", sa.String(length=255), nullable=False),
+ sa.Column("media_type_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["media_type_id"],
+ ["mediatype.id"],
+ name=op.f("fk_apprmanifestlist_media_type_id_mediatype"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_apprmanifestlist")),
)
- op.create_index('apprmanifestlistmanifest_manifest_id', 'apprmanifestlistmanifest', ['manifest_id'], unique=False)
- op.create_index('apprmanifestlistmanifest_manifest_list_id', 'apprmanifestlistmanifest', ['manifest_list_id'], unique=False)
- op.create_index('apprmanifestlistmanifest_manifest_list_id_media_type_id', 'apprmanifestlistmanifest', ['manifest_list_id', 'media_type_id'], unique=False)
- op.create_index('apprmanifestlistmanifest_manifest_list_id_operating_system_arch', 'apprmanifestlistmanifest', ['manifest_list_id', 'operating_system', 'architecture', 'media_type_id'], unique=False)
- op.create_index('apprmanifestlistmanifest_media_type_id', 'apprmanifestlistmanifest', ['media_type_id'], unique=False)
- op.create_table('apprtag',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('manifest_list_id', sa.Integer(), nullable=True),
- sa.Column('lifetime_start', sa.BigInteger(), nullable=False),
- sa.Column('lifetime_end', sa.BigInteger(), nullable=True),
- sa.Column('hidden', sa.Boolean(), nullable=False),
- sa.Column('reverted', sa.Boolean(), nullable=False),
- sa.Column('protected', sa.Boolean(), nullable=False),
- sa.Column('tag_kind_id', sa.Integer(), nullable=False),
- sa.Column('linked_tag_id', sa.Integer(), nullable=True),
- sa.ForeignKeyConstraint(['linked_tag_id'], ['apprtag.id'], name=op.f('fk_apprtag_linked_tag_id_apprtag')),
- sa.ForeignKeyConstraint(['manifest_list_id'], ['apprmanifestlist.id'], name=op.f('fk_apprtag_manifest_list_id_apprmanifestlist')),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_apprtag_repository_id_repository')),
- sa.ForeignKeyConstraint(['tag_kind_id'], ['apprtagkind.id'], name=op.f('fk_apprtag_tag_kind_id_apprtagkind')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_apprtag'))
+ op.create_index(
+ "apprmanifestlist_digest", "apprmanifestlist", ["digest"], unique=True
)
- op.create_index('apprtag_lifetime_end', 'apprtag', ['lifetime_end'], unique=False)
- op.create_index('apprtag_linked_tag_id', 'apprtag', ['linked_tag_id'], unique=False)
- op.create_index('apprtag_manifest_list_id', 'apprtag', ['manifest_list_id'], unique=False)
- op.create_index('apprtag_repository_id', 'apprtag', ['repository_id'], unique=False)
- op.create_index('apprtag_repository_id_name', 'apprtag', ['repository_id', 'name'], unique=False)
- op.create_index('apprtag_repository_id_name_hidden', 'apprtag', ['repository_id', 'name', 'hidden'], unique=False)
- op.create_index('apprtag_repository_id_name_lifetime_end', 'apprtag', ['repository_id', 'name', 'lifetime_end'], unique=True)
- op.create_index('apprtag_tag_kind_id', 'apprtag', ['tag_kind_id'], unique=False)
+ op.create_index(
+ "apprmanifestlist_media_type_id",
+ "apprmanifestlist",
+ ["media_type_id"],
+ unique=False,
+ )
+ op.create_table(
+ "apprblobplacement",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("blob_id", sa.Integer(), nullable=False),
+ sa.Column("location_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["blob_id"],
+ ["apprblob.id"],
+ name=op.f("fk_apprblobplacement_blob_id_apprblob"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["location_id"],
+ ["apprblobplacementlocation.id"],
+ name=op.f("fk_apprblobplacement_location_id_apprblobplacementlocation"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_apprblobplacement")),
+ )
+ op.create_index(
+ "apprblobplacement_blob_id", "apprblobplacement", ["blob_id"], unique=False
+ )
+ op.create_index(
+ "apprblobplacement_blob_id_location_id",
+ "apprblobplacement",
+ ["blob_id", "location_id"],
+ unique=True,
+ )
+ op.create_index(
+ "apprblobplacement_location_id",
+ "apprblobplacement",
+ ["location_id"],
+ unique=False,
+ )
+ op.create_table(
+ "apprmanifestblob",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("manifest_id", sa.Integer(), nullable=False),
+ sa.Column("blob_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["blob_id"],
+ ["apprblob.id"],
+ name=op.f("fk_apprmanifestblob_blob_id_apprblob"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["manifest_id"],
+ ["apprmanifest.id"],
+ name=op.f("fk_apprmanifestblob_manifest_id_apprmanifest"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_apprmanifestblob")),
+ )
+ op.create_index(
+ "apprmanifestblob_blob_id", "apprmanifestblob", ["blob_id"], unique=False
+ )
+ op.create_index(
+ "apprmanifestblob_manifest_id",
+ "apprmanifestblob",
+ ["manifest_id"],
+ unique=False,
+ )
+ op.create_index(
+ "apprmanifestblob_manifest_id_blob_id",
+ "apprmanifestblob",
+ ["manifest_id", "blob_id"],
+ unique=True,
+ )
+ op.create_table(
+ "apprmanifestlistmanifest",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("manifest_list_id", sa.Integer(), nullable=False),
+ sa.Column("manifest_id", sa.Integer(), nullable=False),
+ sa.Column("operating_system", sa.String(length=255), nullable=True),
+ sa.Column("architecture", sa.String(length=255), nullable=True),
+ sa.Column("platform_json", sa.Text(), nullable=True),
+ sa.Column("media_type_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["manifest_id"],
+ ["apprmanifest.id"],
+ name=op.f("fk_apprmanifestlistmanifest_manifest_id_apprmanifest"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["manifest_list_id"],
+ ["apprmanifestlist.id"],
+ name=op.f("fk_apprmanifestlistmanifest_manifest_list_id_apprmanifestlist"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["media_type_id"],
+ ["mediatype.id"],
+ name=op.f("fk_apprmanifestlistmanifest_media_type_id_mediatype"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_apprmanifestlistmanifest")),
+ )
+ op.create_index(
+ "apprmanifestlistmanifest_manifest_id",
+ "apprmanifestlistmanifest",
+ ["manifest_id"],
+ unique=False,
+ )
+ op.create_index(
+ "apprmanifestlistmanifest_manifest_list_id",
+ "apprmanifestlistmanifest",
+ ["manifest_list_id"],
+ unique=False,
+ )
+ op.create_index(
+ "apprmanifestlistmanifest_manifest_list_id_media_type_id",
+ "apprmanifestlistmanifest",
+ ["manifest_list_id", "media_type_id"],
+ unique=False,
+ )
+ op.create_index(
+ "apprmanifestlistmanifest_manifest_list_id_operating_system_arch",
+ "apprmanifestlistmanifest",
+ ["manifest_list_id", "operating_system", "architecture", "media_type_id"],
+ unique=False,
+ )
+ op.create_index(
+ "apprmanifestlistmanifest_media_type_id",
+ "apprmanifestlistmanifest",
+ ["media_type_id"],
+ unique=False,
+ )
+ op.create_table(
+ "apprtag",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("manifest_list_id", sa.Integer(), nullable=True),
+ sa.Column("lifetime_start", sa.BigInteger(), nullable=False),
+ sa.Column("lifetime_end", sa.BigInteger(), nullable=True),
+ sa.Column("hidden", sa.Boolean(), nullable=False),
+ sa.Column("reverted", sa.Boolean(), nullable=False),
+ sa.Column("protected", sa.Boolean(), nullable=False),
+ sa.Column("tag_kind_id", sa.Integer(), nullable=False),
+ sa.Column("linked_tag_id", sa.Integer(), nullable=True),
+ sa.ForeignKeyConstraint(
+ ["linked_tag_id"],
+ ["apprtag.id"],
+ name=op.f("fk_apprtag_linked_tag_id_apprtag"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["manifest_list_id"],
+ ["apprmanifestlist.id"],
+ name=op.f("fk_apprtag_manifest_list_id_apprmanifestlist"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_apprtag_repository_id_repository"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["tag_kind_id"],
+ ["apprtagkind.id"],
+ name=op.f("fk_apprtag_tag_kind_id_apprtagkind"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_apprtag")),
+ )
+ op.create_index("apprtag_lifetime_end", "apprtag", ["lifetime_end"], unique=False)
+ op.create_index("apprtag_linked_tag_id", "apprtag", ["linked_tag_id"], unique=False)
+ op.create_index(
+ "apprtag_manifest_list_id", "apprtag", ["manifest_list_id"], unique=False
+ )
+ op.create_index("apprtag_repository_id", "apprtag", ["repository_id"], unique=False)
+ op.create_index(
+ "apprtag_repository_id_name", "apprtag", ["repository_id", "name"], unique=False
+ )
+ op.create_index(
+ "apprtag_repository_id_name_hidden",
+ "apprtag",
+ ["repository_id", "name", "hidden"],
+ unique=False,
+ )
+ op.create_index(
+ "apprtag_repository_id_name_lifetime_end",
+ "apprtag",
+ ["repository_id", "name", "lifetime_end"],
+ unique=True,
+ )
+ op.create_index("apprtag_tag_kind_id", "apprtag", ["tag_kind_id"], unique=False)
# ### end Alembic commands ###
conn = op.get_bind()
- copy_table_contents('blobplacementlocation', 'apprblobplacementlocation', conn)
- copy_table_contents('tagkind', 'apprtagkind', conn)
-
+ copy_table_contents("blobplacementlocation", "apprblobplacementlocation", conn)
+ copy_table_contents("tagkind", "apprtagkind", conn)
+
# ### population of test data ### #
-
- tester.populate_table('apprmanifest', [
- ('digest', tester.TestDataType.String),
- ('media_type_id', tester.TestDataType.Foreign('mediatype')),
- ('manifest_json', tester.TestDataType.JSON),
- ])
- tester.populate_table('apprmanifestlist', [
- ('digest', tester.TestDataType.String),
- ('manifest_list_json', tester.TestDataType.JSON),
- ('schema_version', tester.TestDataType.String),
- ('media_type_id', tester.TestDataType.Foreign('mediatype')),
- ])
+ tester.populate_table(
+ "apprmanifest",
+ [
+ ("digest", tester.TestDataType.String),
+ ("media_type_id", tester.TestDataType.Foreign("mediatype")),
+ ("manifest_json", tester.TestDataType.JSON),
+ ],
+ )
- tester.populate_table('apprmanifestlistmanifest', [
- ('manifest_list_id', tester.TestDataType.Foreign('apprmanifestlist')),
- ('manifest_id', tester.TestDataType.Foreign('apprmanifest')),
- ('operating_system', tester.TestDataType.String),
- ('architecture', tester.TestDataType.String),
- ('platform_json', tester.TestDataType.JSON),
- ('media_type_id', tester.TestDataType.Foreign('mediatype')),
- ])
+ tester.populate_table(
+ "apprmanifestlist",
+ [
+ ("digest", tester.TestDataType.String),
+ ("manifest_list_json", tester.TestDataType.JSON),
+ ("schema_version", tester.TestDataType.String),
+ ("media_type_id", tester.TestDataType.Foreign("mediatype")),
+ ],
+ )
- tester.populate_table('apprblob', [
- ('digest', tester.TestDataType.String),
- ('media_type_id', tester.TestDataType.Foreign('mediatype')),
- ('size', tester.TestDataType.BigInteger),
- ('uncompressed_size', tester.TestDataType.BigInteger),
- ])
+ tester.populate_table(
+ "apprmanifestlistmanifest",
+ [
+ ("manifest_list_id", tester.TestDataType.Foreign("apprmanifestlist")),
+ ("manifest_id", tester.TestDataType.Foreign("apprmanifest")),
+ ("operating_system", tester.TestDataType.String),
+ ("architecture", tester.TestDataType.String),
+ ("platform_json", tester.TestDataType.JSON),
+ ("media_type_id", tester.TestDataType.Foreign("mediatype")),
+ ],
+ )
- tester.populate_table('apprmanifestblob', [
- ('manifest_id', tester.TestDataType.Foreign('apprmanifest')),
- ('blob_id', tester.TestDataType.Foreign('apprblob')),
- ])
+ tester.populate_table(
+ "apprblob",
+ [
+ ("digest", tester.TestDataType.String),
+ ("media_type_id", tester.TestDataType.Foreign("mediatype")),
+ ("size", tester.TestDataType.BigInteger),
+ ("uncompressed_size", tester.TestDataType.BigInteger),
+ ],
+ )
- tester.populate_table('apprtag', [
- ('name', tester.TestDataType.String),
- ('repository_id', tester.TestDataType.Foreign('repository')),
- ('manifest_list_id', tester.TestDataType.Foreign('apprmanifestlist')),
- ('lifetime_start', tester.TestDataType.Integer),
- ('hidden', tester.TestDataType.Boolean),
- ('reverted', tester.TestDataType.Boolean),
- ('protected', tester.TestDataType.Boolean),
- ('tag_kind_id', tester.TestDataType.Foreign('apprtagkind')),
- ])
+ tester.populate_table(
+ "apprmanifestblob",
+ [
+ ("manifest_id", tester.TestDataType.Foreign("apprmanifest")),
+ ("blob_id", tester.TestDataType.Foreign("apprblob")),
+ ],
+ )
- tester.populate_table('apprblobplacement', [
- ('blob_id', tester.TestDataType.Foreign('apprmanifestblob')),
- ('location_id', tester.TestDataType.Foreign('apprblobplacementlocation')),
- ])
+ tester.populate_table(
+ "apprtag",
+ [
+ ("name", tester.TestDataType.String),
+ ("repository_id", tester.TestDataType.Foreign("repository")),
+ ("manifest_list_id", tester.TestDataType.Foreign("apprmanifestlist")),
+ ("lifetime_start", tester.TestDataType.Integer),
+ ("hidden", tester.TestDataType.Boolean),
+ ("reverted", tester.TestDataType.Boolean),
+ ("protected", tester.TestDataType.Boolean),
+ ("tag_kind_id", tester.TestDataType.Foreign("apprtagkind")),
+ ],
+ )
+
+ tester.populate_table(
+ "apprblobplacement",
+ [
+ ("blob_id", tester.TestDataType.Foreign("apprmanifestblob")),
+ ("location_id", tester.TestDataType.Foreign("apprblobplacementlocation")),
+ ],
+ )
# ### end population of test data ### #
-
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_table('apprtag')
- op.drop_table('apprmanifestlistmanifest')
- op.drop_table('apprmanifestblob')
- op.drop_table('apprblobplacement')
- op.drop_table('apprmanifestlist')
- op.drop_table('apprmanifest')
- op.drop_table('apprblob')
- op.drop_table('apprtagkind')
- op.drop_table('apprblobplacementlocation')
+ op.drop_table("apprtag")
+ op.drop_table("apprmanifestlistmanifest")
+ op.drop_table("apprmanifestblob")
+ op.drop_table("apprblobplacement")
+ op.drop_table("apprmanifestlist")
+ op.drop_table("apprmanifest")
+ op.drop_table("apprblob")
+ op.drop_table("apprtagkind")
+ op.drop_table("apprblobplacementlocation")
# ### end Alembic commands ###
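
Note: the tester.populate_table(...) calls exercise the new tables with synthetic rows when the migration runs under test. The tester itself is not part of this diff, so the sketch below is only an assumption about its general shape, meant to make the column-spec tuples above easier to read (the real API uses tester.TestDataType markers rather than plain strings):

# Hypothetical stand-in for the migration tester's populate_table used above;
# the real tester is defined elsewhere in the repo and is not shown here.
import json
import random
import string

import sqlalchemy as sa


def _random_string():
    return "".join(random.choice(string.ascii_lowercase) for _ in range(12))


def populate_table(conn, table_name, columns):
    """Insert one synthetic row so the new table's constraints get exercised."""
    values = {}
    for name, kind in columns:
        if kind == "string":
            values[name] = _random_string()
        elif kind == "json":
            values[name] = json.dumps({"synthetic": True})
        elif kind == "boolean":
            values[name] = random.choice([True, False])
        else:  # integers / foreign keys: the id of a previously seeded row
            values[name] = 1
    cols = ", ".join(values)
    placeholders = ", ".join(":" + name for name in values)
    conn.execute(
        sa.text(
            "INSERT INTO {0} ({1}) VALUES ({2})".format(table_name, cols, placeholders)
        ),
        values,
    )
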
diff --git a/data/migrations/versions/61cadbacb9fc_add_ability_for_build_triggers_to_be_.py b/data/migrations/versions/61cadbacb9fc_add_ability_for_build_triggers_to_be_.py
index 1dbb1e7a4..73a590057 100644
--- a/data/migrations/versions/61cadbacb9fc_add_ability_for_build_triggers_to_be_.py
+++ b/data/migrations/versions/61cadbacb9fc_add_ability_for_build_triggers_to_be_.py
@@ -7,58 +7,88 @@ Create Date: 2017-10-18 12:07:26.190901
"""
# revision identifiers, used by Alembic.
-revision = '61cadbacb9fc'
-down_revision = 'b4c2d45bc132'
+revision = "61cadbacb9fc"
+down_revision = "b4c2d45bc132"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.create_table('disablereason',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_disablereason'))
+ op.create_table(
+ "disablereason",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_disablereason")),
)
- op.create_index('disablereason_name', 'disablereason', ['name'], unique=True)
+ op.create_index("disablereason_name", "disablereason", ["name"], unique=True)
- op.bulk_insert(
- tables.disablereason,
- [
- {'id': 1, 'name': 'user_toggled'},
- ],
+ op.bulk_insert(tables.disablereason, [{"id": 1, "name": "user_toggled"}])
+
+ op.bulk_insert(tables.logentrykind, [{"name": "toggle_repo_trigger"}])
+
+ op.add_column(
+ u"repositorybuildtrigger",
+ sa.Column("disabled_reason_id", sa.Integer(), nullable=True),
+ )
+ op.add_column(
+ u"repositorybuildtrigger",
+ sa.Column(
+ "enabled",
+ sa.Boolean(),
+ nullable=False,
+ server_default=sa.sql.expression.true(),
+ ),
+ )
+ op.create_index(
+ "repositorybuildtrigger_disabled_reason_id",
+ "repositorybuildtrigger",
+ ["disabled_reason_id"],
+ unique=False,
+ )
+ op.create_foreign_key(
+ op.f("fk_repositorybuildtrigger_disabled_reason_id_disablereason"),
+ "repositorybuildtrigger",
+ "disablereason",
+ ["disabled_reason_id"],
+ ["id"],
)
-
- op.bulk_insert(tables.logentrykind, [
- {'name': 'toggle_repo_trigger'},
- ])
-
- op.add_column(u'repositorybuildtrigger', sa.Column('disabled_reason_id', sa.Integer(), nullable=True))
- op.add_column(u'repositorybuildtrigger', sa.Column('enabled', sa.Boolean(), nullable=False, server_default=sa.sql.expression.true()))
- op.create_index('repositorybuildtrigger_disabled_reason_id', 'repositorybuildtrigger', ['disabled_reason_id'], unique=False)
- op.create_foreign_key(op.f('fk_repositorybuildtrigger_disabled_reason_id_disablereason'), 'repositorybuildtrigger', 'disablereason', ['disabled_reason_id'], ['id'])
# ### end Alembic commands ###
# ### population of test data ### #
- tester.populate_column('repositorybuildtrigger', 'disabled_reason_id', tester.TestDataType.Foreign('disablereason'))
- tester.populate_column('repositorybuildtrigger', 'enabled', tester.TestDataType.Boolean)
+ tester.populate_column(
+ "repositorybuildtrigger",
+ "disabled_reason_id",
+ tester.TestDataType.Foreign("disablereason"),
+ )
+ tester.populate_column(
+ "repositorybuildtrigger", "enabled", tester.TestDataType.Boolean
+ )
# ### end population of test data ### #
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_constraint(op.f('fk_repositorybuildtrigger_disabled_reason_id_disablereason'), 'repositorybuildtrigger', type_='foreignkey')
- op.drop_index('repositorybuildtrigger_disabled_reason_id', table_name='repositorybuildtrigger')
- op.drop_column(u'repositorybuildtrigger', 'enabled')
- op.drop_column(u'repositorybuildtrigger', 'disabled_reason_id')
- op.drop_table('disablereason')
+ op.drop_constraint(
+ op.f("fk_repositorybuildtrigger_disabled_reason_id_disablereason"),
+ "repositorybuildtrigger",
+ type_="foreignkey",
+ )
+ op.drop_index(
+ "repositorybuildtrigger_disabled_reason_id", table_name="repositorybuildtrigger"
+ )
+ op.drop_column(u"repositorybuildtrigger", "enabled")
+ op.drop_column(u"repositorybuildtrigger", "disabled_reason_id")
+ op.drop_table("disablereason")
# ### end Alembic commands ###
- op.execute(tables
- .logentrykind
- .delete()
- .where(tables.logentrykind.c.name == op.inline_literal('toggle_repo_trigger')))
+ op.execute(
+ tables.logentrykind.delete().where(
+ tables.logentrykind.c.name == op.inline_literal("toggle_repo_trigger")
+ )
+ )
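
Note: beyond the DDL reflow, this migration also seeds lookup rows (disablereason, a new logentrykind) and removes them again on downgrade. The seed/cleanup pattern, condensed (the `tables` argument stands for the reflected-table bundle these migrations receive):

# Condensed sketch of the seed-on-upgrade / delete-on-downgrade pattern above.
import sqlalchemy as sa
from alembic import op


def upgrade(tables):
    # server_default keeps existing rows valid while a NOT NULL column is added.
    op.add_column(
        "repositorybuildtrigger",
        sa.Column(
            "enabled",
            sa.Boolean(),
            nullable=False,
            server_default=sa.sql.expression.true(),
        ),
    )
    op.bulk_insert(tables.logentrykind, [{"name": "toggle_repo_trigger"}])


def downgrade(tables):
    op.drop_column("repositorybuildtrigger", "enabled")
    # inline_literal() renders the value directly into the emitted SQL, so the
    # cleanup also works in offline (--sql) migration runs.
    op.execute(
        tables.logentrykind.delete().where(
            tables.logentrykind.c.name == op.inline_literal("toggle_repo_trigger")
        )
    )
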
diff --git a/data/migrations/versions/654e6df88b71_change_manifest_bytes_to_a_utf8_text_.py b/data/migrations/versions/654e6df88b71_change_manifest_bytes_to_a_utf8_text_.py
index b7d17207f..790516349 100644
--- a/data/migrations/versions/654e6df88b71_change_manifest_bytes_to_a_utf8_text_.py
+++ b/data/migrations/versions/654e6df88b71_change_manifest_bytes_to_a_utf8_text_.py
@@ -7,8 +7,8 @@ Create Date: 2018-08-15 09:58:46.109277
"""
# revision identifiers, used by Alembic.
-revision = '654e6df88b71'
-down_revision = 'eafdeadcebc7'
+revision = "654e6df88b71"
+down_revision = "eafdeadcebc7"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
@@ -19,8 +19,13 @@ from util.migrate import UTF8LongText
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
- op.alter_column('manifest', 'manifest_bytes', existing_type=sa.Text(), type_=UTF8LongText())
+ op.alter_column(
+ "manifest", "manifest_bytes", existing_type=sa.Text(), type_=UTF8LongText()
+ )
+
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
- op.alter_column('manifest', 'manifest_bytes', existing_type=UTF8LongText(), type_=sa.Text())
+ op.alter_column(
+ "manifest", "manifest_bytes", existing_type=UTF8LongText(), type_=sa.Text()
+ )
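
Note: in the alter_column pair above, existing_type matters because MySQL (among others) re-emits the full column definition when changing a type, so Alembic needs to know what the column currently is to build the ALTER. UTF8LongText comes from util.migrate and is not defined in this diff; the generic, reversible shape looks like:

# Generic sketch of a reversible column type change; sa.UnicodeText() stands
# in for the project-specific UTF8LongText used above.
import sqlalchemy as sa
from alembic import op


def upgrade():
    op.alter_column(
        "manifest", "manifest_bytes", existing_type=sa.Text(), type_=sa.UnicodeText()
    )


def downgrade():
    op.alter_column(
        "manifest", "manifest_bytes", existing_type=sa.UnicodeText(), type_=sa.Text()
    )
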
diff --git a/data/migrations/versions/67f0abd172ae_add_tagtorepositorytag_table.py b/data/migrations/versions/67f0abd172ae_add_tagtorepositorytag_table.py
index aae5325b9..c707e2711 100644
--- a/data/migrations/versions/67f0abd172ae_add_tagtorepositorytag_table.py
+++ b/data/migrations/versions/67f0abd172ae_add_tagtorepositorytag_table.py
@@ -7,41 +7,68 @@ Create Date: 2018-10-30 11:31:06.615488
"""
# revision identifiers, used by Alembic.
-revision = '67f0abd172ae'
-down_revision = '10f45ee2310b'
+revision = "67f0abd172ae"
+down_revision = "10f45ee2310b"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.create_table('tagtorepositorytag',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('tag_id', sa.Integer(), nullable=False),
- sa.Column('repository_tag_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_tagtorepositorytag_repository_id_repository')),
- sa.ForeignKeyConstraint(['repository_tag_id'], ['repositorytag.id'], name=op.f('fk_tagtorepositorytag_repository_tag_id_repositorytag')),
- sa.ForeignKeyConstraint(['tag_id'], ['tag.id'], name=op.f('fk_tagtorepositorytag_tag_id_tag')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_tagtorepositorytag'))
+ op.create_table(
+ "tagtorepositorytag",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("tag_id", sa.Integer(), nullable=False),
+ sa.Column("repository_tag_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_tagtorepositorytag_repository_id_repository"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["repository_tag_id"],
+ ["repositorytag.id"],
+ name=op.f("fk_tagtorepositorytag_repository_tag_id_repositorytag"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["tag_id"], ["tag.id"], name=op.f("fk_tagtorepositorytag_tag_id_tag")
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_tagtorepositorytag")),
+ )
+ op.create_index(
+ "tagtorepositorytag_repository_id",
+ "tagtorepositorytag",
+ ["repository_id"],
+ unique=False,
+ )
+ op.create_index(
+ "tagtorepositorytag_repository_tag_id",
+ "tagtorepositorytag",
+ ["repository_tag_id"],
+ unique=True,
+ )
+ op.create_index(
+ "tagtorepositorytag_tag_id", "tagtorepositorytag", ["tag_id"], unique=True
)
- op.create_index('tagtorepositorytag_repository_id', 'tagtorepositorytag', ['repository_id'], unique=False)
- op.create_index('tagtorepositorytag_repository_tag_id', 'tagtorepositorytag', ['repository_tag_id'], unique=True)
- op.create_index('tagtorepositorytag_tag_id', 'tagtorepositorytag', ['tag_id'], unique=True)
# ### end Alembic commands ###
- tester.populate_table('tagtorepositorytag', [
- ('repository_id', tester.TestDataType.Foreign('repository')),
- ('tag_id', tester.TestDataType.Foreign('tag')),
- ('repository_tag_id', tester.TestDataType.Foreign('repositorytag')),
- ])
+ tester.populate_table(
+ "tagtorepositorytag",
+ [
+ ("repository_id", tester.TestDataType.Foreign("repository")),
+ ("tag_id", tester.TestDataType.Foreign("tag")),
+ ("repository_tag_id", tester.TestDataType.Foreign("repositorytag")),
+ ],
+ )
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_table('tagtorepositorytag')
+ op.drop_table("tagtorepositorytag")
# ### end Alembic commands ###
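
Note: the two unique indexes above (on repository_tag_id and on tag_id) make tagtorepositorytag a strict one-to-one bridge between old RepositoryTag rows and new Tag rows, so either id can be resolved from the other with a single indexed lookup. A small illustration (SQLAlchemy 1.4+ core syntax, not code from this PR):

# Illustrative lookup across the 1:1 bridge table created above.
import sqlalchemy as sa

tagtorepositorytag = sa.table(
    "tagtorepositorytag",
    sa.column("tag_id", sa.Integer),
    sa.column("repository_tag_id", sa.Integer),
)


def new_tag_id_for(conn, repository_tag_id):
    # Served by the unique index tagtorepositorytag_repository_tag_id.
    stmt = sa.select(tagtorepositorytag.c.tag_id).where(
        tagtorepositorytag.c.repository_tag_id == repository_tag_id
    )
    return conn.execute(stmt).scalar()
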
diff --git a/data/migrations/versions/6c21e2cfb8b6_change_logentry_to_use_a_biginteger_as_.py b/data/migrations/versions/6c21e2cfb8b6_change_logentry_to_use_a_biginteger_as_.py
index 789ba4fa4..53e0a497d 100644
--- a/data/migrations/versions/6c21e2cfb8b6_change_logentry_to_use_a_biginteger_as_.py
+++ b/data/migrations/versions/6c21e2cfb8b6_change_logentry_to_use_a_biginteger_as_.py
@@ -7,8 +7,8 @@ Create Date: 2018-07-27 16:30:02.877346
"""
# revision identifiers, used by Alembic.
-revision = '6c21e2cfb8b6'
-down_revision = 'd17c695859ea'
+revision = "6c21e2cfb8b6"
+down_revision = "d17c695859ea"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
@@ -18,18 +18,19 @@ import sqlalchemy as sa
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
op.alter_column(
- table_name='logentry',
- column_name='id',
+ table_name="logentry",
+ column_name="id",
nullable=False,
autoincrement=True,
type_=sa.BigInteger(),
)
+
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
op.alter_column(
- table_name='logentry',
- column_name='id',
+ table_name="logentry",
+ column_name="id",
nullable=False,
autoincrement=True,
type_=sa.Integer(),
diff --git a/data/migrations/versions/6c7014e84a5e_add_user_prompt_support.py b/data/migrations/versions/6c7014e84a5e_add_user_prompt_support.py
index 99ee1e77c..4f040d432 100644
--- a/data/migrations/versions/6c7014e84a5e_add_user_prompt_support.py
+++ b/data/migrations/versions/6c7014e84a5e_add_user_prompt_support.py
@@ -7,50 +7,62 @@ Create Date: 2016-10-31 16:26:31.447705
"""
# revision identifiers, used by Alembic.
-revision = '6c7014e84a5e'
-down_revision = 'c156deb8845d'
+revision = "6c7014e84a5e"
+down_revision = "c156deb8845d"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
import sqlalchemy as sa
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
### commands auto generated by Alembic - please adjust! ###
- op.create_table('userpromptkind',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_userpromptkind'))
+ op.create_table(
+ "userpromptkind",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_userpromptkind")),
)
- op.create_index('userpromptkind_name', 'userpromptkind', ['name'], unique=False)
- op.create_table('userprompt',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('user_id', sa.Integer(), nullable=False),
- sa.Column('kind_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['kind_id'], ['userpromptkind.id'], name=op.f('fk_userprompt_kind_id_userpromptkind')),
- sa.ForeignKeyConstraint(['user_id'], ['user.id'], name=op.f('fk_userprompt_user_id_user')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_userprompt'))
+ op.create_index("userpromptkind_name", "userpromptkind", ["name"], unique=False)
+ op.create_table(
+ "userprompt",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("user_id", sa.Integer(), nullable=False),
+ sa.Column("kind_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["kind_id"],
+ ["userpromptkind.id"],
+ name=op.f("fk_userprompt_kind_id_userpromptkind"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["user_id"], ["user.id"], name=op.f("fk_userprompt_user_id_user")
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_userprompt")),
+ )
+ op.create_index("userprompt_kind_id", "userprompt", ["kind_id"], unique=False)
+ op.create_index("userprompt_user_id", "userprompt", ["user_id"], unique=False)
+ op.create_index(
+ "userprompt_user_id_kind_id", "userprompt", ["user_id", "kind_id"], unique=True
)
- op.create_index('userprompt_kind_id', 'userprompt', ['kind_id'], unique=False)
- op.create_index('userprompt_user_id', 'userprompt', ['user_id'], unique=False)
- op.create_index('userprompt_user_id_kind_id', 'userprompt', ['user_id', 'kind_id'], unique=True)
### end Alembic commands ###
- op.bulk_insert(tables.userpromptkind,
- [
- {'name':'confirm_username'},
- ])
+ op.bulk_insert(tables.userpromptkind, [{"name": "confirm_username"}])
# ### population of test data ### #
- tester.populate_table('userprompt', [
- ('user_id', tester.TestDataType.Foreign('user')),
- ('kind_id', tester.TestDataType.Foreign('userpromptkind')),
- ])
+ tester.populate_table(
+ "userprompt",
+ [
+ ("user_id", tester.TestDataType.Foreign("user")),
+ ("kind_id", tester.TestDataType.Foreign("userpromptkind")),
+ ],
+ )
# ### end population of test data ### #
+
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
### commands auto generated by Alembic - please adjust! ###
- op.drop_table('userprompt')
- op.drop_table('userpromptkind')
+ op.drop_table("userprompt")
+ op.drop_table("userpromptkind")
### end Alembic commands ###
diff --git a/data/migrations/versions/6ec8726c0ace_add_logentry3_table.py b/data/migrations/versions/6ec8726c0ace_add_logentry3_table.py
index 47ecf1cb1..59568e4bc 100644
--- a/data/migrations/versions/6ec8726c0ace_add_logentry3_table.py
+++ b/data/migrations/versions/6ec8726c0ace_add_logentry3_table.py
@@ -7,37 +7,54 @@ Create Date: 2019-01-03 13:41:02.897957
"""
# revision identifiers, used by Alembic.
-revision = '6ec8726c0ace'
-down_revision = '54492a68a3cf'
+revision = "6ec8726c0ace"
+down_revision = "54492a68a3cf"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.create_table('logentry3',
- sa.Column('id', sa.BigInteger(), nullable=False),
- sa.Column('kind_id', sa.Integer(), nullable=False),
- sa.Column('account_id', sa.Integer(), nullable=False),
- sa.Column('performer_id', sa.Integer(), nullable=True),
- sa.Column('repository_id', sa.Integer(), nullable=True),
- sa.Column('datetime', sa.DateTime(), nullable=False),
- sa.Column('ip', sa.String(length=255), nullable=True),
- sa.Column('metadata_json', sa.Text(), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_logentry3'))
+ op.create_table(
+ "logentry3",
+ sa.Column("id", sa.BigInteger(), nullable=False),
+ sa.Column("kind_id", sa.Integer(), nullable=False),
+ sa.Column("account_id", sa.Integer(), nullable=False),
+ sa.Column("performer_id", sa.Integer(), nullable=True),
+ sa.Column("repository_id", sa.Integer(), nullable=True),
+ sa.Column("datetime", sa.DateTime(), nullable=False),
+ sa.Column("ip", sa.String(length=255), nullable=True),
+ sa.Column("metadata_json", sa.Text(), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_logentry3")),
+ )
+ op.create_index(
+ "logentry3_account_id_datetime",
+ "logentry3",
+ ["account_id", "datetime"],
+ unique=False,
+ )
+ op.create_index("logentry3_datetime", "logentry3", ["datetime"], unique=False)
+ op.create_index(
+ "logentry3_performer_id_datetime",
+ "logentry3",
+ ["performer_id", "datetime"],
+ unique=False,
+ )
+ op.create_index(
+ "logentry3_repository_id_datetime_kind_id",
+ "logentry3",
+ ["repository_id", "datetime", "kind_id"],
+ unique=False,
)
- op.create_index('logentry3_account_id_datetime', 'logentry3', ['account_id', 'datetime'], unique=False)
- op.create_index('logentry3_datetime', 'logentry3', ['datetime'], unique=False)
- op.create_index('logentry3_performer_id_datetime', 'logentry3', ['performer_id', 'datetime'], unique=False)
- op.create_index('logentry3_repository_id_datetime_kind_id', 'logentry3', ['repository_id', 'datetime', 'kind_id'], unique=False)
# ### end Alembic commands ###
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_table('logentry3')
+ op.drop_table("logentry3")
# ### end Alembic commands ###
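
Note: logentry3 carries the same composite indexes as its predecessors, and their column order matters: each index leads with the equality-filtered columns and ends with datetime (or kind_id), so ranged log queries are served by a leftmost prefix of the index. Illustrative query shape only (SQLAlchemy 1.4+ core syntax, not code from this PR):

# The query shape served by logentry3_repository_id_datetime_kind_id
# (repository_id, datetime, kind_id).
import sqlalchemy as sa

logentry3 = sa.table(
    "logentry3",
    sa.column("id", sa.BigInteger),
    sa.column("repository_id", sa.Integer),
    sa.column("kind_id", sa.Integer),
    sa.column("datetime", sa.DateTime),
)


def logs_for_repository(conn, repository_id, since):
    # Equality on repository_id plus a range on datetime matches the
    # leftmost prefix of the composite index above.
    stmt = (
        sa.select(logentry3)
        .where(logentry3.c.repository_id == repository_id)
        .where(logentry3.c.datetime >= since)
        .order_by(logentry3.c.datetime.desc())
    )
    return conn.execute(stmt).fetchall()
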
diff --git a/data/migrations/versions/703298a825c2_backfill_new_encrypted_fields.py b/data/migrations/versions/703298a825c2_backfill_new_encrypted_fields.py
index 43459af40..575bc3429 100644
--- a/data/migrations/versions/703298a825c2_backfill_new_encrypted_fields.py
+++ b/data/migrations/versions/703298a825c2_backfill_new_encrypted_fields.py
@@ -6,282 +6,474 @@ Create Date: 2019-08-19 16:07:48.109889
"""
# revision identifiers, used by Alembic.
-revision = '703298a825c2'
-down_revision = 'c13c8052f7a6'
+revision = "703298a825c2"
+down_revision = "c13c8052f7a6"
import logging
import uuid
from datetime import datetime
-from peewee import (JOIN, IntegrityError, DateTimeField, CharField, ForeignKeyField,
- BooleanField, TextField, IntegerField)
+from peewee import (
+ JOIN,
+ IntegrityError,
+ DateTimeField,
+ CharField,
+ ForeignKeyField,
+ BooleanField,
+ TextField,
+ IntegerField,
+)
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
import sqlalchemy as sa
-from data.database import (BaseModel, User, Repository, AccessTokenKind, Role,
- random_string_generator, QuayUserField, BuildTriggerService,
- uuid_generator, DisableReason)
-from data.fields import Credential, DecryptedValue, EncryptedCharField, EncryptedTextField, EnumField, CredentialField
+from data.database import (
+ BaseModel,
+ User,
+ Repository,
+ AccessTokenKind,
+ Role,
+ random_string_generator,
+ QuayUserField,
+ BuildTriggerService,
+ uuid_generator,
+ DisableReason,
+)
+from data.fields import (
+ Credential,
+ DecryptedValue,
+ EncryptedCharField,
+ EncryptedTextField,
+ EnumField,
+ CredentialField,
+)
from data.model.token import ACCESS_TOKEN_NAME_PREFIX_LENGTH
-from data.model.appspecifictoken import TOKEN_NAME_PREFIX_LENGTH as AST_TOKEN_NAME_PREFIX_LENGTH
-from data.model.oauth import ACCESS_TOKEN_PREFIX_LENGTH as OAUTH_ACCESS_TOKEN_PREFIX_LENGTH
+from data.model.appspecifictoken import (
+ TOKEN_NAME_PREFIX_LENGTH as AST_TOKEN_NAME_PREFIX_LENGTH,
+)
+from data.model.oauth import (
+ ACCESS_TOKEN_PREFIX_LENGTH as OAUTH_ACCESS_TOKEN_PREFIX_LENGTH,
+)
from data.model.oauth import AUTHORIZATION_CODE_PREFIX_LENGTH
BATCH_SIZE = 10
logger = logging.getLogger(__name__)
-def _iterate(model_class, clause):
- while True:
- has_rows = False
- for row in list(model_class.select().where(clause).limit(BATCH_SIZE)):
- has_rows = True
- yield row
- if not has_rows:
- break
+def _iterate(model_class, clause):
+ while True:
+ has_rows = False
+ for row in list(model_class.select().where(clause).limit(BATCH_SIZE)):
+ has_rows = True
+ yield row
+
+ if not has_rows:
+ break
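
Note: the reflowed _iterate above is the backbone of every backfill in this file: it re-queries in small batches until the WHERE clause stops matching, which happens because each yielded row is updated (or flagged fully_migrated) by the caller. A self-contained illustration of that loop shape, with a plain list standing in for the query:

# Self-contained illustration of the batched-backfill loop used by _iterate:
# keep fetching up to BATCH_SIZE pending rows and stop once a pass returns
# nothing, which happens because processing removes rows from the pending set.
BATCH_SIZE = 10


def iterate_batches(fetch_batch):
    """fetch_batch() returns a list of at most BATCH_SIZE pending rows."""
    while True:
        rows = fetch_batch()
        if not rows:
            break
        for row in rows:
            yield row


# Example: rows disappear from the pending set as they are processed,
# mirroring fully_migrated flipping to True in the real migration.
pending = list(range(25))
for item in iterate_batches(lambda: pending[:BATCH_SIZE]):
    pending.remove(item)
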
def _decrypted(value):
- if value is None:
- return None
+ if value is None:
+ return None
- assert isinstance(value, basestring)
- return DecryptedValue(value)
+ assert isinstance(value, basestring)
+ return DecryptedValue(value)
# NOTE: As per standard migrations involving Peewee models, we copy them here, as they will change
# after this call.
class AccessToken(BaseModel):
- code = CharField(default=random_string_generator(length=64), unique=True, index=True)
- token_name = CharField(default=random_string_generator(length=32), unique=True, index=True)
- token_code = EncryptedCharField(default_token_length=32)
+ code = CharField(
+ default=random_string_generator(length=64), unique=True, index=True
+ )
+ token_name = CharField(
+ default=random_string_generator(length=32), unique=True, index=True
+ )
+ token_code = EncryptedCharField(default_token_length=32)
+
class RobotAccountToken(BaseModel):
- robot_account = QuayUserField(index=True, allows_robots=True, unique=True)
- token = EncryptedCharField(default_token_length=64)
- fully_migrated = BooleanField(default=False)
+ robot_account = QuayUserField(index=True, allows_robots=True, unique=True)
+ token = EncryptedCharField(default_token_length=64)
+ fully_migrated = BooleanField(default=False)
+
class RepositoryBuildTrigger(BaseModel):
- uuid = CharField(default=uuid_generator, index=True)
- auth_token = CharField(null=True)
- private_key = TextField(null=True)
+ uuid = CharField(default=uuid_generator, index=True)
+ auth_token = CharField(null=True)
+ private_key = TextField(null=True)
+
+ secure_auth_token = EncryptedCharField(null=True)
+ secure_private_key = EncryptedTextField(null=True)
+ fully_migrated = BooleanField(default=False)
- secure_auth_token = EncryptedCharField(null=True)
- secure_private_key = EncryptedTextField(null=True)
- fully_migrated = BooleanField(default=False)
class AppSpecificAuthToken(BaseModel):
- token_name = CharField(index=True, unique=True, default=random_string_generator(60))
- token_secret = EncryptedCharField(default_token_length=60)
- token_code = CharField(default=random_string_generator(length=120), unique=True, index=True)
+ token_name = CharField(index=True, unique=True, default=random_string_generator(60))
+ token_secret = EncryptedCharField(default_token_length=60)
+ token_code = CharField(
+ default=random_string_generator(length=120), unique=True, index=True
+ )
+
class OAuthAccessToken(BaseModel):
- token_name = CharField(index=True, unique=True)
- token_code = CredentialField()
- access_token = CharField(index=True)
+ token_name = CharField(index=True, unique=True)
+ token_code = CredentialField()
+ access_token = CharField(index=True)
+
class OAuthAuthorizationCode(BaseModel):
- code = CharField(index=True, unique=True, null=True)
- code_name = CharField(index=True, unique=True)
- code_credential = CredentialField()
+ code = CharField(index=True, unique=True, null=True)
+ code_name = CharField(index=True, unique=True)
+ code_credential = CredentialField()
+
class OAuthApplication(BaseModel):
- secure_client_secret = EncryptedCharField(default_token_length=40, null=True)
- fully_migrated = BooleanField(default=False)
- client_secret = CharField(default=random_string_generator(length=40))
+ secure_client_secret = EncryptedCharField(default_token_length=40, null=True)
+ fully_migrated = BooleanField(default=False)
+ client_secret = CharField(default=random_string_generator(length=40))
def upgrade(tables, tester, progress_reporter):
- op = ProgressWrapper(original_op, progress_reporter)
+ op = ProgressWrapper(original_op, progress_reporter)
- from app import app
- if app.config.get('SETUP_COMPLETE', False) or tester.is_testing:
- # Empty all access token names to fix the bug where we put the wrong name and code
- # in for some tokens.
- AccessToken.update(token_name=None).where(AccessToken.token_name >> None).execute()
+ from app import app
- # AccessToken.
- logger.info('Backfilling encrypted credentials for access tokens')
- for access_token in _iterate(AccessToken, ((AccessToken.token_name >> None) |
- (AccessToken.token_name == ''))):
- logger.info('Backfilling encrypted credentials for access token %s', access_token.id)
- assert access_token.code is not None
- assert access_token.code[:ACCESS_TOKEN_NAME_PREFIX_LENGTH]
- assert access_token.code[ACCESS_TOKEN_NAME_PREFIX_LENGTH:]
+ if app.config.get("SETUP_COMPLETE", False) or tester.is_testing:
+ # Empty all access token names to fix the bug where we put the wrong name and code
+ # in for some tokens.
+ AccessToken.update(token_name=None).where(
+ AccessToken.token_name >> None
+ ).execute()
- token_name = access_token.code[:ACCESS_TOKEN_NAME_PREFIX_LENGTH]
- token_code = _decrypted(access_token.code[ACCESS_TOKEN_NAME_PREFIX_LENGTH:])
+ # AccessToken.
+ logger.info("Backfilling encrypted credentials for access tokens")
+ for access_token in _iterate(
+ AccessToken,
+ ((AccessToken.token_name >> None) | (AccessToken.token_name == "")),
+ ):
+ logger.info(
+ "Backfilling encrypted credentials for access token %s", access_token.id
+ )
+ assert access_token.code is not None
+ assert access_token.code[:ACCESS_TOKEN_NAME_PREFIX_LENGTH]
+ assert access_token.code[ACCESS_TOKEN_NAME_PREFIX_LENGTH:]
- (AccessToken
- .update(token_name=token_name, token_code=token_code)
- .where(AccessToken.id == access_token.id, AccessToken.code == access_token.code)
- .execute())
+ token_name = access_token.code[:ACCESS_TOKEN_NAME_PREFIX_LENGTH]
+ token_code = _decrypted(access_token.code[ACCESS_TOKEN_NAME_PREFIX_LENGTH:])
- assert AccessToken.select().where(AccessToken.token_name >> None).count() == 0
+ (
+ AccessToken.update(token_name=token_name, token_code=token_code)
+ .where(
+ AccessToken.id == access_token.id,
+ AccessToken.code == access_token.code,
+ )
+ .execute()
+ )
- # Robots.
- logger.info('Backfilling encrypted credentials for robots')
- while True:
- has_row = False
- query = (User
- .select()
- .join(RobotAccountToken, JOIN.LEFT_OUTER)
- .where(User.robot == True, RobotAccountToken.id >> None)
- .limit(BATCH_SIZE))
+ assert AccessToken.select().where(AccessToken.token_name >> None).count() == 0
- for robot_user in query:
- logger.info('Backfilling encrypted credentials for robot %s', robot_user.id)
- has_row = True
- try:
- RobotAccountToken.create(robot_account=robot_user,
- token=_decrypted(robot_user.email),
- fully_migrated=False)
- except IntegrityError:
- break
+ # Robots.
+ logger.info("Backfilling encrypted credentials for robots")
+ while True:
+ has_row = False
+ query = (
+ User.select()
+ .join(RobotAccountToken, JOIN.LEFT_OUTER)
+ .where(User.robot == True, RobotAccountToken.id >> None)
+ .limit(BATCH_SIZE)
+ )
- if not has_row:
- break
+ for robot_user in query:
+ logger.info(
+ "Backfilling encrypted credentials for robot %s", robot_user.id
+ )
+ has_row = True
+ try:
+ RobotAccountToken.create(
+ robot_account=robot_user,
+ token=_decrypted(robot_user.email),
+ fully_migrated=False,
+ )
+ except IntegrityError:
+ break
- # RepositoryBuildTrigger
- logger.info('Backfilling encrypted credentials for repo build triggers')
- for repo_build_trigger in _iterate(RepositoryBuildTrigger,
- (RepositoryBuildTrigger.fully_migrated == False)):
- logger.info('Backfilling encrypted credentials for repo build trigger %s',
- repo_build_trigger.id)
+ if not has_row:
+ break
- (RepositoryBuildTrigger
- .update(secure_auth_token=_decrypted(repo_build_trigger.auth_token),
- secure_private_key=_decrypted(repo_build_trigger.private_key),
- fully_migrated=True)
- .where(RepositoryBuildTrigger.id == repo_build_trigger.id,
- RepositoryBuildTrigger.uuid == repo_build_trigger.uuid)
- .execute())
+ # RepositoryBuildTrigger
+ logger.info("Backfilling encrypted credentials for repo build triggers")
+ for repo_build_trigger in _iterate(
+ RepositoryBuildTrigger, (RepositoryBuildTrigger.fully_migrated == False)
+ ):
+ logger.info(
+ "Backfilling encrypted credentials for repo build trigger %s",
+ repo_build_trigger.id,
+ )
- assert (RepositoryBuildTrigger
- .select()
+ (
+ RepositoryBuildTrigger.update(
+ secure_auth_token=_decrypted(repo_build_trigger.auth_token),
+ secure_private_key=_decrypted(repo_build_trigger.private_key),
+ fully_migrated=True,
+ )
+ .where(
+ RepositoryBuildTrigger.id == repo_build_trigger.id,
+ RepositoryBuildTrigger.uuid == repo_build_trigger.uuid,
+ )
+ .execute()
+ )
+
+ assert (
+ RepositoryBuildTrigger.select()
.where(RepositoryBuildTrigger.fully_migrated == False)
- .count()) == 0
+ .count()
+ ) == 0
- # AppSpecificAuthToken
- logger.info('Backfilling encrypted credentials for app specific auth tokens')
- for token in _iterate(AppSpecificAuthToken, ((AppSpecificAuthToken.token_name >> None) |
- (AppSpecificAuthToken.token_name == '') |
- (AppSpecificAuthToken.token_secret >> None))):
- logger.info('Backfilling encrypted credentials for app specific auth %s',
- token.id)
- assert token.token_code[AST_TOKEN_NAME_PREFIX_LENGTH:]
+ # AppSpecificAuthToken
+ logger.info("Backfilling encrypted credentials for app specific auth tokens")
+ for token in _iterate(
+ AppSpecificAuthToken,
+ (
+ (AppSpecificAuthToken.token_name >> None)
+ | (AppSpecificAuthToken.token_name == "")
+ | (AppSpecificAuthToken.token_secret >> None)
+ ),
+ ):
+ logger.info(
+ "Backfilling encrypted credentials for app specific auth %s", token.id
+ )
+ assert token.token_code[AST_TOKEN_NAME_PREFIX_LENGTH:]
- token_name = token.token_code[:AST_TOKEN_NAME_PREFIX_LENGTH]
- token_secret = _decrypted(token.token_code[AST_TOKEN_NAME_PREFIX_LENGTH:])
- assert token_name
- assert token_secret
+ token_name = token.token_code[:AST_TOKEN_NAME_PREFIX_LENGTH]
+ token_secret = _decrypted(token.token_code[AST_TOKEN_NAME_PREFIX_LENGTH:])
+ assert token_name
+ assert token_secret
- (AppSpecificAuthToken
- .update(token_name=token_name,
- token_secret=token_secret)
- .where(AppSpecificAuthToken.id == token.id,
- AppSpecificAuthToken.token_code == token.token_code)
- .execute())
+ (
+ AppSpecificAuthToken.update(
+ token_name=token_name, token_secret=token_secret
+ )
+ .where(
+ AppSpecificAuthToken.id == token.id,
+ AppSpecificAuthToken.token_code == token.token_code,
+ )
+ .execute()
+ )
- assert (AppSpecificAuthToken
- .select()
+ assert (
+ AppSpecificAuthToken.select()
.where(AppSpecificAuthToken.token_name >> None)
- .count()) == 0
+ .count()
+ ) == 0
- # OAuthAccessToken
- logger.info('Backfilling credentials for OAuth access tokens')
- for token in _iterate(OAuthAccessToken, ((OAuthAccessToken.token_name >> None) |
- (OAuthAccessToken.token_name == ''))):
- logger.info('Backfilling credentials for OAuth access token %s', token.id)
- token_name = token.access_token[:OAUTH_ACCESS_TOKEN_PREFIX_LENGTH]
- token_code = Credential.from_string(token.access_token[OAUTH_ACCESS_TOKEN_PREFIX_LENGTH:])
- assert token_name
- assert token.access_token[OAUTH_ACCESS_TOKEN_PREFIX_LENGTH:]
+ # OAuthAccessToken
+ logger.info("Backfilling credentials for OAuth access tokens")
+ for token in _iterate(
+ OAuthAccessToken,
+ (
+ (OAuthAccessToken.token_name >> None)
+ | (OAuthAccessToken.token_name == "")
+ ),
+ ):
+ logger.info("Backfilling credentials for OAuth access token %s", token.id)
+ token_name = token.access_token[:OAUTH_ACCESS_TOKEN_PREFIX_LENGTH]
+ token_code = Credential.from_string(
+ token.access_token[OAUTH_ACCESS_TOKEN_PREFIX_LENGTH:]
+ )
+ assert token_name
+ assert token.access_token[OAUTH_ACCESS_TOKEN_PREFIX_LENGTH:]
- (OAuthAccessToken
- .update(token_name=token_name,
- token_code=token_code)
- .where(OAuthAccessToken.id == token.id,
- OAuthAccessToken.access_token == token.access_token)
- .execute())
+ (
+ OAuthAccessToken.update(token_name=token_name, token_code=token_code)
+ .where(
+ OAuthAccessToken.id == token.id,
+ OAuthAccessToken.access_token == token.access_token,
+ )
+ .execute()
+ )
- assert (OAuthAccessToken
- .select()
- .where(OAuthAccessToken.token_name >> None)
- .count()) == 0
+ assert (
+ OAuthAccessToken.select().where(OAuthAccessToken.token_name >> None).count()
+ ) == 0
- # OAuthAuthorizationCode
- logger.info('Backfilling credentials for OAuth auth code')
- for code in _iterate(OAuthAuthorizationCode, ((OAuthAuthorizationCode.code_name >> None) |
- (OAuthAuthorizationCode.code_name == ''))):
- logger.info('Backfilling credentials for OAuth auth code %s', code.id)
- user_code = code.code or random_string_generator(AUTHORIZATION_CODE_PREFIX_LENGTH * 2)()
- code_name = user_code[:AUTHORIZATION_CODE_PREFIX_LENGTH]
- code_credential = Credential.from_string(user_code[AUTHORIZATION_CODE_PREFIX_LENGTH:])
- assert code_name
- assert user_code[AUTHORIZATION_CODE_PREFIX_LENGTH:]
+ # OAuthAuthorizationCode
+ logger.info("Backfilling credentials for OAuth auth code")
+ for code in _iterate(
+ OAuthAuthorizationCode,
+ (
+ (OAuthAuthorizationCode.code_name >> None)
+ | (OAuthAuthorizationCode.code_name == "")
+ ),
+ ):
+ logger.info("Backfilling credentials for OAuth auth code %s", code.id)
+ user_code = (
+ code.code
+ or random_string_generator(AUTHORIZATION_CODE_PREFIX_LENGTH * 2)()
+ )
+ code_name = user_code[:AUTHORIZATION_CODE_PREFIX_LENGTH]
+ code_credential = Credential.from_string(
+ user_code[AUTHORIZATION_CODE_PREFIX_LENGTH:]
+ )
+ assert code_name
+ assert user_code[AUTHORIZATION_CODE_PREFIX_LENGTH:]
- (OAuthAuthorizationCode
- .update(code_name=code_name, code_credential=code_credential)
- .where(OAuthAuthorizationCode.id == code.id)
- .execute())
+ (
+ OAuthAuthorizationCode.update(
+ code_name=code_name, code_credential=code_credential
+ )
+ .where(OAuthAuthorizationCode.id == code.id)
+ .execute()
+ )
- assert (OAuthAuthorizationCode
- .select()
+ assert (
+ OAuthAuthorizationCode.select()
.where(OAuthAuthorizationCode.code_name >> None)
- .count()) == 0
+ .count()
+ ) == 0
- # OAuthApplication
- logger.info('Backfilling secret for OAuth applications')
- for app in _iterate(OAuthApplication, OAuthApplication.fully_migrated == False):
- logger.info('Backfilling secret for OAuth application %s', app.id)
- client_secret = app.client_secret or str(uuid.uuid4())
- secure_client_secret = _decrypted(client_secret)
+ # OAuthApplication
+ logger.info("Backfilling secret for OAuth applications")
+ for app in _iterate(OAuthApplication, OAuthApplication.fully_migrated == False):
+ logger.info("Backfilling secret for OAuth application %s", app.id)
+ client_secret = app.client_secret or str(uuid.uuid4())
+ secure_client_secret = _decrypted(client_secret)
- (OAuthApplication
- .update(secure_client_secret=secure_client_secret, fully_migrated=True)
- .where(OAuthApplication.id == app.id, OAuthApplication.fully_migrated == False)
- .execute())
+ (
+ OAuthApplication.update(
+ secure_client_secret=secure_client_secret, fully_migrated=True
+ )
+ .where(
+ OAuthApplication.id == app.id,
+ OAuthApplication.fully_migrated == False,
+ )
+ .execute()
+ )
- assert (OAuthApplication
- .select()
+ assert (
+ OAuthApplication.select()
.where(OAuthApplication.fully_migrated == False)
- .count()) == 0
+ .count()
+ ) == 0
- # Adjust existing fields to be nullable.
- op.alter_column('accesstoken', 'code', nullable=True, existing_type=sa.String(length=255))
- op.alter_column('oauthaccesstoken', 'access_token', nullable=True, existing_type=sa.String(length=255))
- op.alter_column('oauthauthorizationcode', 'code', nullable=True, existing_type=sa.String(length=255))
- op.alter_column('appspecificauthtoken', 'token_code', nullable=True, existing_type=sa.String(length=255))
+ # Adjust existing fields to be nullable.
+ op.alter_column(
+ "accesstoken", "code", nullable=True, existing_type=sa.String(length=255)
+ )
+ op.alter_column(
+ "oauthaccesstoken",
+ "access_token",
+ nullable=True,
+ existing_type=sa.String(length=255),
+ )
+ op.alter_column(
+ "oauthauthorizationcode",
+ "code",
+ nullable=True,
+ existing_type=sa.String(length=255),
+ )
+ op.alter_column(
+ "appspecificauthtoken",
+ "token_code",
+ nullable=True,
+ existing_type=sa.String(length=255),
+ )
- # Adjust new fields to be non-nullable.
- op.alter_column('accesstoken', 'token_name', nullable=False, existing_type=sa.String(length=255))
- op.alter_column('accesstoken', 'token_code', nullable=False, existing_type=sa.String(length=255))
+ # Adjust new fields to be non-nullable.
+ op.alter_column(
+ "accesstoken", "token_name", nullable=False, existing_type=sa.String(length=255)
+ )
+ op.alter_column(
+ "accesstoken", "token_code", nullable=False, existing_type=sa.String(length=255)
+ )
- op.alter_column('appspecificauthtoken', 'token_name', nullable=False, existing_type=sa.String(length=255))
- op.alter_column('appspecificauthtoken', 'token_secret', nullable=False, existing_type=sa.String(length=255))
+ op.alter_column(
+ "appspecificauthtoken",
+ "token_name",
+ nullable=False,
+ existing_type=sa.String(length=255),
+ )
+ op.alter_column(
+ "appspecificauthtoken",
+ "token_secret",
+ nullable=False,
+ existing_type=sa.String(length=255),
+ )
- op.alter_column('oauthaccesstoken', 'token_name', nullable=False, existing_type=sa.String(length=255))
- op.alter_column('oauthaccesstoken', 'token_code', nullable=False, existing_type=sa.String(length=255))
+ op.alter_column(
+ "oauthaccesstoken",
+ "token_name",
+ nullable=False,
+ existing_type=sa.String(length=255),
+ )
+ op.alter_column(
+ "oauthaccesstoken",
+ "token_code",
+ nullable=False,
+ existing_type=sa.String(length=255),
+ )
+
+ op.alter_column(
+ "oauthauthorizationcode",
+ "code_name",
+ nullable=False,
+ existing_type=sa.String(length=255),
+ )
+ op.alter_column(
+ "oauthauthorizationcode",
+ "code_credential",
+ nullable=False,
+ existing_type=sa.String(length=255),
+ )
- op.alter_column('oauthauthorizationcode', 'code_name', nullable=False, existing_type=sa.String(length=255))
- op.alter_column('oauthauthorizationcode', 'code_credential', nullable=False, existing_type=sa.String(length=255))
def downgrade(tables, tester, progress_reporter):
- op = ProgressWrapper(original_op, progress_reporter)
+ op = ProgressWrapper(original_op, progress_reporter)
- op.alter_column('accesstoken', 'token_name', nullable=True, existing_type=sa.String(length=255))
- op.alter_column('accesstoken', 'token_code', nullable=True, existing_type=sa.String(length=255))
+ op.alter_column(
+ "accesstoken", "token_name", nullable=True, existing_type=sa.String(length=255)
+ )
+ op.alter_column(
+ "accesstoken", "token_code", nullable=True, existing_type=sa.String(length=255)
+ )
- op.alter_column('appspecificauthtoken', 'token_name', nullable=True, existing_type=sa.String(length=255))
- op.alter_column('appspecificauthtoken', 'token_secret', nullable=True, existing_type=sa.String(length=255))
+ op.alter_column(
+ "appspecificauthtoken",
+ "token_name",
+ nullable=True,
+ existing_type=sa.String(length=255),
+ )
+ op.alter_column(
+ "appspecificauthtoken",
+ "token_secret",
+ nullable=True,
+ existing_type=sa.String(length=255),
+ )
- op.alter_column('oauthaccesstoken', 'token_name', nullable=True, existing_type=sa.String(length=255))
- op.alter_column('oauthaccesstoken', 'token_code', nullable=True, existing_type=sa.String(length=255))
+ op.alter_column(
+ "oauthaccesstoken",
+ "token_name",
+ nullable=True,
+ existing_type=sa.String(length=255),
+ )
+ op.alter_column(
+ "oauthaccesstoken",
+ "token_code",
+ nullable=True,
+ existing_type=sa.String(length=255),
+ )
- op.alter_column('oauthauthorizationcode', 'code_name', nullable=True, existing_type=sa.String(length=255))
- op.alter_column('oauthauthorizationcode', 'code_credential', nullable=True, existing_type=sa.String(length=255))
+ op.alter_column(
+ "oauthauthorizationcode",
+ "code_name",
+ nullable=True,
+ existing_type=sa.String(length=255),
+ )
+ op.alter_column(
+ "oauthauthorizationcode",
+ "code_credential",
+ nullable=True,
+ existing_type=sa.String(length=255),
+ )
diff --git a/data/migrations/versions/7367229b38d9_add_support_for_app_specific_tokens.py b/data/migrations/versions/7367229b38d9_add_support_for_app_specific_tokens.py
index b5fb97d63..3f2e41fd1 100644
--- a/data/migrations/versions/7367229b38d9_add_support_for_app_specific_tokens.py
+++ b/data/migrations/versions/7367229b38d9_add_support_for_app_specific_tokens.py
@@ -7,8 +7,8 @@ Create Date: 2017-12-12 13:15:42.419764
"""
# revision identifiers, used by Alembic.
-revision = '7367229b38d9'
-down_revision = 'd8989249f8f6'
+revision = "7367229b38d9"
+down_revision = "d8989249f8f6"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
@@ -16,59 +16,83 @@ import sqlalchemy as sa
from sqlalchemy.dialects import mysql
from util.migrate import UTF8CharField
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.create_table('appspecificauthtoken',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('user_id', sa.Integer(), nullable=False),
- sa.Column('uuid', sa.String(length=36), nullable=False),
- sa.Column('title', UTF8CharField(length=255), nullable=False),
- sa.Column('token_code', sa.String(length=255), nullable=False),
- sa.Column('created', sa.DateTime(), nullable=False),
- sa.Column('expiration', sa.DateTime(), nullable=True),
- sa.Column('last_accessed', sa.DateTime(), nullable=True),
- sa.ForeignKeyConstraint(['user_id'], ['user.id'], name=op.f('fk_appspecificauthtoken_user_id_user')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_appspecificauthtoken'))
+ op.create_table(
+ "appspecificauthtoken",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("user_id", sa.Integer(), nullable=False),
+ sa.Column("uuid", sa.String(length=36), nullable=False),
+ sa.Column("title", UTF8CharField(length=255), nullable=False),
+ sa.Column("token_code", sa.String(length=255), nullable=False),
+ sa.Column("created", sa.DateTime(), nullable=False),
+ sa.Column("expiration", sa.DateTime(), nullable=True),
+ sa.Column("last_accessed", sa.DateTime(), nullable=True),
+ sa.ForeignKeyConstraint(
+ ["user_id"], ["user.id"], name=op.f("fk_appspecificauthtoken_user_id_user")
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_appspecificauthtoken")),
+ )
+ op.create_index(
+ "appspecificauthtoken_token_code",
+ "appspecificauthtoken",
+ ["token_code"],
+ unique=True,
+ )
+ op.create_index(
+ "appspecificauthtoken_user_id",
+ "appspecificauthtoken",
+ ["user_id"],
+ unique=False,
+ )
+ op.create_index(
+ "appspecificauthtoken_user_id_expiration",
+ "appspecificauthtoken",
+ ["user_id", "expiration"],
+ unique=False,
+ )
+ op.create_index(
+ "appspecificauthtoken_uuid", "appspecificauthtoken", ["uuid"], unique=False
)
- op.create_index('appspecificauthtoken_token_code', 'appspecificauthtoken', ['token_code'], unique=True)
- op.create_index('appspecificauthtoken_user_id', 'appspecificauthtoken', ['user_id'], unique=False)
- op.create_index('appspecificauthtoken_user_id_expiration', 'appspecificauthtoken', ['user_id', 'expiration'], unique=False)
- op.create_index('appspecificauthtoken_uuid', 'appspecificauthtoken', ['uuid'], unique=False)
# ### end Alembic commands ###
- op.bulk_insert(tables.logentrykind, [
- {'name': 'create_app_specific_token'},
- {'name': 'revoke_app_specific_token'},
- ])
+ op.bulk_insert(
+ tables.logentrykind,
+ [{"name": "create_app_specific_token"}, {"name": "revoke_app_specific_token"}],
+ )
# ### population of test data ### #
- tester.populate_table('appspecificauthtoken', [
- ('user_id', tester.TestDataType.Foreign('user')),
- ('uuid', tester.TestDataType.UUID),
- ('title', tester.TestDataType.UTF8Char),
- ('token_code', tester.TestDataType.String),
- ('created', tester.TestDataType.DateTime),
- ('expiration', tester.TestDataType.DateTime),
- ('last_accessed', tester.TestDataType.DateTime),
- ])
+ tester.populate_table(
+ "appspecificauthtoken",
+ [
+ ("user_id", tester.TestDataType.Foreign("user")),
+ ("uuid", tester.TestDataType.UUID),
+ ("title", tester.TestDataType.UTF8Char),
+ ("token_code", tester.TestDataType.String),
+ ("created", tester.TestDataType.DateTime),
+ ("expiration", tester.TestDataType.DateTime),
+ ("last_accessed", tester.TestDataType.DateTime),
+ ],
+ )
# ### end population of test data ### #
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_table('appspecificauthtoken')
+ op.drop_table("appspecificauthtoken")
# ### end Alembic commands ###
- op.execute(tables
- .logentrykind
- .delete()
- .where(tables.
- logentrykind.name == op.inline_literal('create_app_specific_token')))
+ op.execute(
+ tables.logentrykind.delete().where(
+ tables.logentrykind.name == op.inline_literal("create_app_specific_token")
+ )
+ )
- op.execute(tables
- .logentrykind
- .delete()
- .where(tables.
- logentrykind.name == op.inline_literal('revoke_app_specific_token')))
+ op.execute(
+ tables.logentrykind.delete().where(
+ tables.logentrykind.name == op.inline_literal("revoke_app_specific_token")
+ )
+ )
diff --git a/data/migrations/versions/7a525c68eb13_add_oci_app_models.py b/data/migrations/versions/7a525c68eb13_add_oci_app_models.py
index 7cade6854..903440477 100644
--- a/data/migrations/versions/7a525c68eb13_add_oci_app_models.py
+++ b/data/migrations/versions/7a525c68eb13_add_oci_app_models.py
@@ -7,8 +7,8 @@ Create Date: 2017-01-24 16:25:52.170277
"""
# revision identifiers, used by Alembic.
-revision = '7a525c68eb13'
-down_revision = 'e2894a3a3c19'
+revision = "7a525c68eb13"
+down_revision = "e2894a3a3c19"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
@@ -19,322 +19,563 @@ from util.migrate import UTF8LongText, UTF8CharField
def upgrade(tables, tester, progress_reporter):
- op = ProgressWrapper(original_op, progress_reporter)
- op.create_table(
- 'tagkind',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_tagkind'))
- )
- op.create_index('tagkind_name', 'tagkind', ['name'], unique=True)
+ op = ProgressWrapper(original_op, progress_reporter)
+ op.create_table(
+ "tagkind",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_tagkind")),
+ )
+ op.create_index("tagkind_name", "tagkind", ["name"], unique=True)
- op.create_table(
- 'blobplacementlocation',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_blobplacementlocation'))
- )
- op.create_index('blobplacementlocation_name', 'blobplacementlocation', ['name'], unique=True)
+ op.create_table(
+ "blobplacementlocation",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_blobplacementlocation")),
+ )
+ op.create_index(
+ "blobplacementlocation_name", "blobplacementlocation", ["name"], unique=True
+ )
- op.create_table(
- 'blob',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('digest', sa.String(length=255), nullable=False),
- sa.Column('media_type_id', sa.Integer(), nullable=False),
- sa.Column('size', sa.BigInteger(), nullable=False),
- sa.Column('uncompressed_size', sa.BigInteger(), nullable=True),
- sa.ForeignKeyConstraint(['media_type_id'], ['mediatype.id'], name=op.f('fk_blob_media_type_id_mediatype')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_blob'))
- )
- op.create_index('blob_digest', 'blob', ['digest'], unique=True)
- op.create_index('blob_media_type_id', 'blob', ['media_type_id'], unique=False)
+ op.create_table(
+ "blob",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("digest", sa.String(length=255), nullable=False),
+ sa.Column("media_type_id", sa.Integer(), nullable=False),
+ sa.Column("size", sa.BigInteger(), nullable=False),
+ sa.Column("uncompressed_size", sa.BigInteger(), nullable=True),
+ sa.ForeignKeyConstraint(
+ ["media_type_id"],
+ ["mediatype.id"],
+ name=op.f("fk_blob_media_type_id_mediatype"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_blob")),
+ )
+ op.create_index("blob_digest", "blob", ["digest"], unique=True)
+ op.create_index("blob_media_type_id", "blob", ["media_type_id"], unique=False)
- op.create_table(
- 'blobplacementlocationpreference',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('user_id', sa.Integer(), nullable=False),
- sa.Column('location_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['location_id'], ['blobplacementlocation.id'], name=op.f('fk_blobplacementlocpref_locid_blobplacementlocation')),
- sa.ForeignKeyConstraint(['user_id'], ['user.id'], name=op.f('fk_blobplacementlocationpreference_user_id_user')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_blobplacementlocationpreference'))
- )
- op.create_index('blobplacementlocationpreference_location_id', 'blobplacementlocationpreference', ['location_id'], unique=False)
- op.create_index('blobplacementlocationpreference_user_id', 'blobplacementlocationpreference', ['user_id'], unique=False)
+ op.create_table(
+ "blobplacementlocationpreference",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("user_id", sa.Integer(), nullable=False),
+ sa.Column("location_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["location_id"],
+ ["blobplacementlocation.id"],
+ name=op.f("fk_blobplacementlocpref_locid_blobplacementlocation"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["user_id"],
+ ["user.id"],
+ name=op.f("fk_blobplacementlocationpreference_user_id_user"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_blobplacementlocationpreference")),
+ )
+ op.create_index(
+ "blobplacementlocationpreference_location_id",
+ "blobplacementlocationpreference",
+ ["location_id"],
+ unique=False,
+ )
+ op.create_index(
+ "blobplacementlocationpreference_user_id",
+ "blobplacementlocationpreference",
+ ["user_id"],
+ unique=False,
+ )
- op.create_table(
- 'manifest',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('digest', sa.String(length=255), nullable=False),
- sa.Column('media_type_id', sa.Integer(), nullable=False),
- sa.Column('manifest_json', UTF8LongText, nullable=False),
- sa.ForeignKeyConstraint(['media_type_id'], ['mediatype.id'], name=op.f('fk_manifest_media_type_id_mediatype')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_manifest'))
- )
- op.create_index('manifest_digest', 'manifest', ['digest'], unique=True)
- op.create_index('manifest_media_type_id', 'manifest', ['media_type_id'], unique=False)
+ op.create_table(
+ "manifest",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("digest", sa.String(length=255), nullable=False),
+ sa.Column("media_type_id", sa.Integer(), nullable=False),
+ sa.Column("manifest_json", UTF8LongText, nullable=False),
+ sa.ForeignKeyConstraint(
+ ["media_type_id"],
+ ["mediatype.id"],
+ name=op.f("fk_manifest_media_type_id_mediatype"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_manifest")),
+ )
+ op.create_index("manifest_digest", "manifest", ["digest"], unique=True)
+ op.create_index(
+ "manifest_media_type_id", "manifest", ["media_type_id"], unique=False
+ )
- op.create_table(
- 'manifestlist',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('digest', sa.String(length=255), nullable=False),
- sa.Column('manifest_list_json', UTF8LongText, nullable=False),
- sa.Column('schema_version', UTF8CharField(length=255), nullable=False),
- sa.Column('media_type_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['media_type_id'], ['mediatype.id'], name=op.f('fk_manifestlist_media_type_id_mediatype')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_manifestlist'))
- )
- op.create_index('manifestlist_digest', 'manifestlist', ['digest'], unique=True)
- op.create_index('manifestlist_media_type_id', 'manifestlist', ['media_type_id'], unique=False)
+ op.create_table(
+ "manifestlist",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("digest", sa.String(length=255), nullable=False),
+ sa.Column("manifest_list_json", UTF8LongText, nullable=False),
+ sa.Column("schema_version", UTF8CharField(length=255), nullable=False),
+ sa.Column("media_type_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["media_type_id"],
+ ["mediatype.id"],
+ name=op.f("fk_manifestlist_media_type_id_mediatype"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_manifestlist")),
+ )
+ op.create_index("manifestlist_digest", "manifestlist", ["digest"], unique=True)
+ op.create_index(
+ "manifestlist_media_type_id", "manifestlist", ["media_type_id"], unique=False
+ )
- op.create_table(
- 'bittorrentpieces',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('blob_id', sa.Integer(), nullable=False),
- sa.Column('pieces', UTF8LongText, nullable=False),
- sa.Column('piece_length', sa.BigInteger(), nullable=False),
- sa.ForeignKeyConstraint(['blob_id'], ['blob.id'], name=op.f('fk_bittorrentpieces_blob_id_blob')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_bittorrentpieces'))
- )
- op.create_index('bittorrentpieces_blob_id', 'bittorrentpieces', ['blob_id'], unique=False)
- op.create_index('bittorrentpieces_blob_id_piece_length', 'bittorrentpieces', ['blob_id', 'piece_length'], unique=True)
+ op.create_table(
+ "bittorrentpieces",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("blob_id", sa.Integer(), nullable=False),
+ sa.Column("pieces", UTF8LongText, nullable=False),
+ sa.Column("piece_length", sa.BigInteger(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["blob_id"], ["blob.id"], name=op.f("fk_bittorrentpieces_blob_id_blob")
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_bittorrentpieces")),
+ )
+ op.create_index(
+ "bittorrentpieces_blob_id", "bittorrentpieces", ["blob_id"], unique=False
+ )
+ op.create_index(
+ "bittorrentpieces_blob_id_piece_length",
+ "bittorrentpieces",
+ ["blob_id", "piece_length"],
+ unique=True,
+ )
- op.create_table(
- 'blobplacement',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('blob_id', sa.Integer(), nullable=False),
- sa.Column('location_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['blob_id'], ['blob.id'], name=op.f('fk_blobplacement_blob_id_blob')),
- sa.ForeignKeyConstraint(['location_id'], ['blobplacementlocation.id'], name=op.f('fk_blobplacement_location_id_blobplacementlocation')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_blobplacement'))
- )
- op.create_index('blobplacement_blob_id', 'blobplacement', ['blob_id'], unique=False)
- op.create_index('blobplacement_blob_id_location_id', 'blobplacement', ['blob_id', 'location_id'], unique=True)
- op.create_index('blobplacement_location_id', 'blobplacement', ['location_id'], unique=False)
+ op.create_table(
+ "blobplacement",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("blob_id", sa.Integer(), nullable=False),
+ sa.Column("location_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["blob_id"], ["blob.id"], name=op.f("fk_blobplacement_blob_id_blob")
+ ),
+ sa.ForeignKeyConstraint(
+ ["location_id"],
+ ["blobplacementlocation.id"],
+ name=op.f("fk_blobplacement_location_id_blobplacementlocation"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_blobplacement")),
+ )
+ op.create_index("blobplacement_blob_id", "blobplacement", ["blob_id"], unique=False)
+ op.create_index(
+ "blobplacement_blob_id_location_id",
+ "blobplacement",
+ ["blob_id", "location_id"],
+ unique=True,
+ )
+ op.create_index(
+ "blobplacement_location_id", "blobplacement", ["location_id"], unique=False
+ )
- op.create_table(
- 'blobuploading',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('uuid', sa.String(length=255), nullable=False),
- sa.Column('created', sa.DateTime(), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('location_id', sa.Integer(), nullable=False),
- sa.Column('byte_count', sa.BigInteger(), nullable=False),
- sa.Column('uncompressed_byte_count', sa.BigInteger(), nullable=True),
- sa.Column('chunk_count', sa.BigInteger(), nullable=False),
- sa.Column('storage_metadata', UTF8LongText, nullable=True),
- sa.Column('sha_state', UTF8LongText, nullable=True),
- sa.Column('piece_sha_state', UTF8LongText, nullable=True),
- sa.Column('piece_hashes', UTF8LongText, nullable=True),
- sa.ForeignKeyConstraint(['location_id'], ['blobplacementlocation.id'], name=op.f('fk_blobuploading_location_id_blobplacementlocation')),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_blobuploading_repository_id_repository')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_blobuploading'))
- )
- op.create_index('blobuploading_created', 'blobuploading', ['created'], unique=False)
- op.create_index('blobuploading_location_id', 'blobuploading', ['location_id'], unique=False)
- op.create_index('blobuploading_repository_id', 'blobuploading', ['repository_id'], unique=False)
- op.create_index('blobuploading_repository_id_uuid', 'blobuploading', ['repository_id', 'uuid'], unique=True)
- op.create_index('blobuploading_uuid', 'blobuploading', ['uuid'], unique=True)
+ op.create_table(
+ "blobuploading",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("uuid", sa.String(length=255), nullable=False),
+ sa.Column("created", sa.DateTime(), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("location_id", sa.Integer(), nullable=False),
+ sa.Column("byte_count", sa.BigInteger(), nullable=False),
+ sa.Column("uncompressed_byte_count", sa.BigInteger(), nullable=True),
+ sa.Column("chunk_count", sa.BigInteger(), nullable=False),
+ sa.Column("storage_metadata", UTF8LongText, nullable=True),
+ sa.Column("sha_state", UTF8LongText, nullable=True),
+ sa.Column("piece_sha_state", UTF8LongText, nullable=True),
+ sa.Column("piece_hashes", UTF8LongText, nullable=True),
+ sa.ForeignKeyConstraint(
+ ["location_id"],
+ ["blobplacementlocation.id"],
+ name=op.f("fk_blobuploading_location_id_blobplacementlocation"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_blobuploading_repository_id_repository"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_blobuploading")),
+ )
+ op.create_index("blobuploading_created", "blobuploading", ["created"], unique=False)
+ op.create_index(
+ "blobuploading_location_id", "blobuploading", ["location_id"], unique=False
+ )
+ op.create_index(
+ "blobuploading_repository_id", "blobuploading", ["repository_id"], unique=False
+ )
+ op.create_index(
+ "blobuploading_repository_id_uuid",
+ "blobuploading",
+ ["repository_id", "uuid"],
+ unique=True,
+ )
+ op.create_index("blobuploading_uuid", "blobuploading", ["uuid"], unique=True)
- op.create_table(
- 'derivedimage',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('uuid', sa.String(length=255), nullable=False),
- sa.Column('source_manifest_id', sa.Integer(), nullable=False),
- sa.Column('derived_manifest_json', UTF8LongText, nullable=False),
- sa.Column('media_type_id', sa.Integer(), nullable=False),
- sa.Column('blob_id', sa.Integer(), nullable=False),
- sa.Column('uniqueness_hash', sa.String(length=255), nullable=False),
- sa.Column('signature_blob_id', sa.Integer(), nullable=True),
- sa.ForeignKeyConstraint(['blob_id'], ['blob.id'], name=op.f('fk_derivedimage_blob_id_blob')),
- sa.ForeignKeyConstraint(['media_type_id'], ['mediatype.id'], name=op.f('fk_derivedimage_media_type_id_mediatype')),
- sa.ForeignKeyConstraint(['signature_blob_id'], ['blob.id'], name=op.f('fk_derivedimage_signature_blob_id_blob')),
- sa.ForeignKeyConstraint(['source_manifest_id'], ['manifest.id'], name=op.f('fk_derivedimage_source_manifest_id_manifest')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_derivedimage'))
- )
- op.create_index('derivedimage_blob_id', 'derivedimage', ['blob_id'], unique=False)
- op.create_index('derivedimage_media_type_id', 'derivedimage', ['media_type_id'], unique=False)
- op.create_index('derivedimage_signature_blob_id', 'derivedimage', ['signature_blob_id'], unique=False)
- op.create_index('derivedimage_source_manifest_id', 'derivedimage', ['source_manifest_id'], unique=False)
- op.create_index('derivedimage_source_manifest_id_blob_id', 'derivedimage', ['source_manifest_id', 'blob_id'], unique=True)
- op.create_index('derivedimage_source_manifest_id_media_type_id_uniqueness_hash', 'derivedimage', ['source_manifest_id', 'media_type_id', 'uniqueness_hash'], unique=True)
- op.create_index('derivedimage_uniqueness_hash', 'derivedimage', ['uniqueness_hash'], unique=True)
- op.create_index('derivedimage_uuid', 'derivedimage', ['uuid'], unique=True)
+ op.create_table(
+ "derivedimage",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("uuid", sa.String(length=255), nullable=False),
+ sa.Column("source_manifest_id", sa.Integer(), nullable=False),
+ sa.Column("derived_manifest_json", UTF8LongText, nullable=False),
+ sa.Column("media_type_id", sa.Integer(), nullable=False),
+ sa.Column("blob_id", sa.Integer(), nullable=False),
+ sa.Column("uniqueness_hash", sa.String(length=255), nullable=False),
+ sa.Column("signature_blob_id", sa.Integer(), nullable=True),
+ sa.ForeignKeyConstraint(
+ ["blob_id"], ["blob.id"], name=op.f("fk_derivedimage_blob_id_blob")
+ ),
+ sa.ForeignKeyConstraint(
+ ["media_type_id"],
+ ["mediatype.id"],
+ name=op.f("fk_derivedimage_media_type_id_mediatype"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["signature_blob_id"],
+ ["blob.id"],
+ name=op.f("fk_derivedimage_signature_blob_id_blob"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["source_manifest_id"],
+ ["manifest.id"],
+ name=op.f("fk_derivedimage_source_manifest_id_manifest"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_derivedimage")),
+ )
+ op.create_index("derivedimage_blob_id", "derivedimage", ["blob_id"], unique=False)
+ op.create_index(
+ "derivedimage_media_type_id", "derivedimage", ["media_type_id"], unique=False
+ )
+ op.create_index(
+ "derivedimage_signature_blob_id",
+ "derivedimage",
+ ["signature_blob_id"],
+ unique=False,
+ )
+ op.create_index(
+ "derivedimage_source_manifest_id",
+ "derivedimage",
+ ["source_manifest_id"],
+ unique=False,
+ )
+ op.create_index(
+ "derivedimage_source_manifest_id_blob_id",
+ "derivedimage",
+ ["source_manifest_id", "blob_id"],
+ unique=True,
+ )
+ op.create_index(
+ "derivedimage_source_manifest_id_media_type_id_uniqueness_hash",
+ "derivedimage",
+ ["source_manifest_id", "media_type_id", "uniqueness_hash"],
+ unique=True,
+ )
+ op.create_index(
+ "derivedimage_uniqueness_hash", "derivedimage", ["uniqueness_hash"], unique=True
+ )
+ op.create_index("derivedimage_uuid", "derivedimage", ["uuid"], unique=True)
- op.create_table(
- 'manifestblob',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('manifest_id', sa.Integer(), nullable=False),
- sa.Column('blob_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['blob_id'], ['blob.id'], name=op.f('fk_manifestblob_blob_id_blob')),
- sa.ForeignKeyConstraint(['manifest_id'], ['manifest.id'], name=op.f('fk_manifestblob_manifest_id_manifest')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_manifestblob'))
- )
- op.create_index('manifestblob_blob_id', 'manifestblob', ['blob_id'], unique=False)
- op.create_index('manifestblob_manifest_id', 'manifestblob', ['manifest_id'], unique=False)
- op.create_index('manifestblob_manifest_id_blob_id', 'manifestblob', ['manifest_id', 'blob_id'], unique=True)
+ op.create_table(
+ "manifestblob",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("manifest_id", sa.Integer(), nullable=False),
+ sa.Column("blob_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["blob_id"], ["blob.id"], name=op.f("fk_manifestblob_blob_id_blob")
+ ),
+ sa.ForeignKeyConstraint(
+ ["manifest_id"],
+ ["manifest.id"],
+ name=op.f("fk_manifestblob_manifest_id_manifest"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_manifestblob")),
+ )
+ op.create_index("manifestblob_blob_id", "manifestblob", ["blob_id"], unique=False)
+ op.create_index(
+ "manifestblob_manifest_id", "manifestblob", ["manifest_id"], unique=False
+ )
+ op.create_index(
+ "manifestblob_manifest_id_blob_id",
+ "manifestblob",
+ ["manifest_id", "blob_id"],
+ unique=True,
+ )
- op.create_table(
- 'manifestlabel',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('annotated_id', sa.Integer(), nullable=False),
- sa.Column('label_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['annotated_id'], ['manifest.id'], name=op.f('fk_manifestlabel_annotated_id_manifest')),
- sa.ForeignKeyConstraint(['label_id'], ['label.id'], name=op.f('fk_manifestlabel_label_id_label')),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_manifestlabel_repository_id_repository')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_manifestlabel'))
- )
- op.create_index('manifestlabel_annotated_id', 'manifestlabel', ['annotated_id'], unique=False)
- op.create_index('manifestlabel_label_id', 'manifestlabel', ['label_id'], unique=False)
- op.create_index('manifestlabel_repository_id', 'manifestlabel', ['repository_id'], unique=False)
- op.create_index('manifestlabel_repository_id_annotated_id_label_id', 'manifestlabel', ['repository_id', 'annotated_id', 'label_id'], unique=True)
+ op.create_table(
+ "manifestlabel",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("annotated_id", sa.Integer(), nullable=False),
+ sa.Column("label_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["annotated_id"],
+ ["manifest.id"],
+ name=op.f("fk_manifestlabel_annotated_id_manifest"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["label_id"], ["label.id"], name=op.f("fk_manifestlabel_label_id_label")
+ ),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_manifestlabel_repository_id_repository"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_manifestlabel")),
+ )
+ op.create_index(
+ "manifestlabel_annotated_id", "manifestlabel", ["annotated_id"], unique=False
+ )
+ op.create_index(
+ "manifestlabel_label_id", "manifestlabel", ["label_id"], unique=False
+ )
+ op.create_index(
+ "manifestlabel_repository_id", "manifestlabel", ["repository_id"], unique=False
+ )
+ op.create_index(
+ "manifestlabel_repository_id_annotated_id_label_id",
+ "manifestlabel",
+ ["repository_id", "annotated_id", "label_id"],
+ unique=True,
+ )
- op.create_table(
- 'manifestlayer',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('blob_id', sa.Integer(), nullable=False),
- sa.Column('manifest_id', sa.Integer(), nullable=False),
- sa.Column('manifest_index', sa.BigInteger(), nullable=False),
- sa.Column('metadata_json', UTF8LongText, nullable=False),
- sa.ForeignKeyConstraint(['blob_id'], ['blob.id'], name=op.f('fk_manifestlayer_blob_id_blob')),
- sa.ForeignKeyConstraint(['manifest_id'], ['manifest.id'], name=op.f('fk_manifestlayer_manifest_id_manifest')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_manifestlayer'))
- )
- op.create_index('manifestlayer_blob_id', 'manifestlayer', ['blob_id'], unique=False)
- op.create_index('manifestlayer_manifest_id', 'manifestlayer', ['manifest_id'], unique=False)
- op.create_index('manifestlayer_manifest_id_manifest_index', 'manifestlayer', ['manifest_id', 'manifest_index'], unique=True)
- op.create_index('manifestlayer_manifest_index', 'manifestlayer', ['manifest_index'], unique=False)
+ op.create_table(
+ "manifestlayer",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("blob_id", sa.Integer(), nullable=False),
+ sa.Column("manifest_id", sa.Integer(), nullable=False),
+ sa.Column("manifest_index", sa.BigInteger(), nullable=False),
+ sa.Column("metadata_json", UTF8LongText, nullable=False),
+ sa.ForeignKeyConstraint(
+ ["blob_id"], ["blob.id"], name=op.f("fk_manifestlayer_blob_id_blob")
+ ),
+ sa.ForeignKeyConstraint(
+ ["manifest_id"],
+ ["manifest.id"],
+ name=op.f("fk_manifestlayer_manifest_id_manifest"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_manifestlayer")),
+ )
+ op.create_index("manifestlayer_blob_id", "manifestlayer", ["blob_id"], unique=False)
+ op.create_index(
+ "manifestlayer_manifest_id", "manifestlayer", ["manifest_id"], unique=False
+ )
+ op.create_index(
+ "manifestlayer_manifest_id_manifest_index",
+ "manifestlayer",
+ ["manifest_id", "manifest_index"],
+ unique=True,
+ )
+ op.create_index(
+ "manifestlayer_manifest_index",
+ "manifestlayer",
+ ["manifest_index"],
+ unique=False,
+ )
- op.create_table(
- 'manifestlistmanifest',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('manifest_list_id', sa.Integer(), nullable=False),
- sa.Column('manifest_id', sa.Integer(), nullable=False),
- sa.Column('operating_system', UTF8CharField(length=255), nullable=True),
- sa.Column('architecture', UTF8CharField(length=255), nullable=True),
- sa.Column('platform_json', UTF8LongText, nullable=True),
- sa.Column('media_type_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['manifest_id'], ['manifest.id'], name=op.f('fk_manifestlistmanifest_manifest_id_manifest')),
- sa.ForeignKeyConstraint(['manifest_list_id'], ['manifestlist.id'], name=op.f('fk_manifestlistmanifest_manifest_list_id_manifestlist')),
- sa.ForeignKeyConstraint(['media_type_id'], ['mediatype.id'], name=op.f('fk_manifestlistmanifest_media_type_id_mediatype')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_manifestlistmanifest'))
- )
- op.create_index('manifestlistmanifest_manifest_id', 'manifestlistmanifest', ['manifest_id'], unique=False)
- op.create_index('manifestlistmanifest_manifest_list_id', 'manifestlistmanifest', ['manifest_list_id'], unique=False)
- op.create_index('manifestlistmanifest_manifest_listid_os_arch_mtid', 'manifestlistmanifest', ['manifest_list_id', 'operating_system', 'architecture', 'media_type_id'], unique=False)
- op.create_index('manifestlistmanifest_manifest_listid_mtid', 'manifestlistmanifest', ['manifest_list_id', 'media_type_id'], unique=False)
- op.create_index('manifestlistmanifest_media_type_id', 'manifestlistmanifest', ['media_type_id'], unique=False)
+ op.create_table(
+ "manifestlistmanifest",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("manifest_list_id", sa.Integer(), nullable=False),
+ sa.Column("manifest_id", sa.Integer(), nullable=False),
+ sa.Column("operating_system", UTF8CharField(length=255), nullable=True),
+ sa.Column("architecture", UTF8CharField(length=255), nullable=True),
+ sa.Column("platform_json", UTF8LongText, nullable=True),
+ sa.Column("media_type_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["manifest_id"],
+ ["manifest.id"],
+ name=op.f("fk_manifestlistmanifest_manifest_id_manifest"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["manifest_list_id"],
+ ["manifestlist.id"],
+ name=op.f("fk_manifestlistmanifest_manifest_list_id_manifestlist"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["media_type_id"],
+ ["mediatype.id"],
+ name=op.f("fk_manifestlistmanifest_media_type_id_mediatype"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_manifestlistmanifest")),
+ )
+ op.create_index(
+ "manifestlistmanifest_manifest_id",
+ "manifestlistmanifest",
+ ["manifest_id"],
+ unique=False,
+ )
+ op.create_index(
+ "manifestlistmanifest_manifest_list_id",
+ "manifestlistmanifest",
+ ["manifest_list_id"],
+ unique=False,
+ )
+ op.create_index(
+ "manifestlistmanifest_manifest_listid_os_arch_mtid",
+ "manifestlistmanifest",
+ ["manifest_list_id", "operating_system", "architecture", "media_type_id"],
+ unique=False,
+ )
+ op.create_index(
+ "manifestlistmanifest_manifest_listid_mtid",
+ "manifestlistmanifest",
+ ["manifest_list_id", "media_type_id"],
+ unique=False,
+ )
+ op.create_index(
+ "manifestlistmanifest_media_type_id",
+ "manifestlistmanifest",
+ ["media_type_id"],
+ unique=False,
+ )
- op.create_table(
- 'tag',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', UTF8CharField(length=190), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('manifest_list_id', sa.Integer(), nullable=True),
- sa.Column('lifetime_start', sa.BigInteger(), nullable=False),
- sa.Column('lifetime_end', sa.BigInteger(), nullable=True),
- sa.Column('hidden', sa.Boolean(), nullable=False),
- sa.Column('reverted', sa.Boolean(), nullable=False),
- sa.Column('protected', sa.Boolean(), nullable=False),
- sa.Column('tag_kind_id', sa.Integer(), nullable=False),
- sa.Column('linked_tag_id', sa.Integer(), nullable=True),
- sa.ForeignKeyConstraint(['linked_tag_id'], ['tag.id'], name=op.f('fk_tag_linked_tag_id_tag')),
- sa.ForeignKeyConstraint(['manifest_list_id'], ['manifestlist.id'], name=op.f('fk_tag_manifest_list_id_manifestlist')),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_tag_repository_id_repository')),
- sa.ForeignKeyConstraint(['tag_kind_id'], ['tagkind.id'], name=op.f('fk_tag_tag_kind_id_tagkind')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_tag'))
- )
- op.create_index('tag_lifetime_end', 'tag', ['lifetime_end'], unique=False)
- op.create_index('tag_linked_tag_id', 'tag', ['linked_tag_id'], unique=False)
- op.create_index('tag_manifest_list_id', 'tag', ['manifest_list_id'], unique=False)
- op.create_index('tag_repository_id', 'tag', ['repository_id'], unique=False)
- op.create_index('tag_repository_id_name_hidden', 'tag', ['repository_id', 'name', 'hidden'], unique=False)
- op.create_index('tag_repository_id_name_lifetime_end', 'tag', ['repository_id', 'name', 'lifetime_end'], unique=True)
- op.create_index('tag_repository_id_name', 'tag', ['repository_id', 'name'], unique=False)
- op.create_index('tag_tag_kind_id', 'tag', ['tag_kind_id'], unique=False)
+ op.create_table(
+ "tag",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", UTF8CharField(length=190), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("manifest_list_id", sa.Integer(), nullable=True),
+ sa.Column("lifetime_start", sa.BigInteger(), nullable=False),
+ sa.Column("lifetime_end", sa.BigInteger(), nullable=True),
+ sa.Column("hidden", sa.Boolean(), nullable=False),
+ sa.Column("reverted", sa.Boolean(), nullable=False),
+ sa.Column("protected", sa.Boolean(), nullable=False),
+ sa.Column("tag_kind_id", sa.Integer(), nullable=False),
+ sa.Column("linked_tag_id", sa.Integer(), nullable=True),
+ sa.ForeignKeyConstraint(
+ ["linked_tag_id"], ["tag.id"], name=op.f("fk_tag_linked_tag_id_tag")
+ ),
+ sa.ForeignKeyConstraint(
+ ["manifest_list_id"],
+ ["manifestlist.id"],
+ name=op.f("fk_tag_manifest_list_id_manifestlist"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_tag_repository_id_repository"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["tag_kind_id"], ["tagkind.id"], name=op.f("fk_tag_tag_kind_id_tagkind")
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_tag")),
+ )
+ op.create_index("tag_lifetime_end", "tag", ["lifetime_end"], unique=False)
+ op.create_index("tag_linked_tag_id", "tag", ["linked_tag_id"], unique=False)
+ op.create_index("tag_manifest_list_id", "tag", ["manifest_list_id"], unique=False)
+ op.create_index("tag_repository_id", "tag", ["repository_id"], unique=False)
+ op.create_index(
+ "tag_repository_id_name_hidden",
+ "tag",
+ ["repository_id", "name", "hidden"],
+ unique=False,
+ )
+ op.create_index(
+ "tag_repository_id_name_lifetime_end",
+ "tag",
+ ["repository_id", "name", "lifetime_end"],
+ unique=True,
+ )
+ op.create_index(
+ "tag_repository_id_name", "tag", ["repository_id", "name"], unique=False
+ )
+ op.create_index("tag_tag_kind_id", "tag", ["tag_kind_id"], unique=False)
- op.create_table(
- 'manifestlayerdockerv1',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('manifest_layer_id', sa.Integer(), nullable=False),
- sa.Column('image_id', UTF8CharField(length=255), nullable=False),
- sa.Column('checksum', UTF8CharField(length=255), nullable=False),
- sa.Column('compat_json', UTF8LongText, nullable=False),
- sa.ForeignKeyConstraint(['manifest_layer_id'], ['manifestlayer.id'], name=op.f('fk_manifestlayerdockerv1_manifest_layer_id_manifestlayer')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_manifestlayerdockerv1'))
- )
- op.create_index('manifestlayerdockerv1_image_id', 'manifestlayerdockerv1', ['image_id'], unique=False)
- op.create_index('manifestlayerdockerv1_manifest_layer_id', 'manifestlayerdockerv1', ['manifest_layer_id'], unique=False)
+ op.create_table(
+ "manifestlayerdockerv1",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("manifest_layer_id", sa.Integer(), nullable=False),
+ sa.Column("image_id", UTF8CharField(length=255), nullable=False),
+ sa.Column("checksum", UTF8CharField(length=255), nullable=False),
+ sa.Column("compat_json", UTF8LongText, nullable=False),
+ sa.ForeignKeyConstraint(
+ ["manifest_layer_id"],
+ ["manifestlayer.id"],
+ name=op.f("fk_manifestlayerdockerv1_manifest_layer_id_manifestlayer"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_manifestlayerdockerv1")),
+ )
+ op.create_index(
+ "manifestlayerdockerv1_image_id",
+ "manifestlayerdockerv1",
+ ["image_id"],
+ unique=False,
+ )
+ op.create_index(
+ "manifestlayerdockerv1_manifest_layer_id",
+ "manifestlayerdockerv1",
+ ["manifest_layer_id"],
+ unique=False,
+ )
- op.create_table(
- 'manifestlayerscan',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('layer_id', sa.Integer(), nullable=False),
- sa.Column('scannable', sa.Boolean(), nullable=False),
- sa.Column('scanned_by', UTF8CharField(length=255), nullable=False),
- sa.ForeignKeyConstraint(['layer_id'], ['manifestlayer.id'], name=op.f('fk_manifestlayerscan_layer_id_manifestlayer')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_manifestlayerscan'))
- )
- op.create_index('manifestlayerscan_layer_id', 'manifestlayerscan', ['layer_id'], unique=True)
+ op.create_table(
+ "manifestlayerscan",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("layer_id", sa.Integer(), nullable=False),
+ sa.Column("scannable", sa.Boolean(), nullable=False),
+ sa.Column("scanned_by", UTF8CharField(length=255), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["layer_id"],
+ ["manifestlayer.id"],
+ name=op.f("fk_manifestlayerscan_layer_id_manifestlayer"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_manifestlayerscan")),
+ )
+ op.create_index(
+ "manifestlayerscan_layer_id", "manifestlayerscan", ["layer_id"], unique=True
+ )
- blobplacementlocation_table = table('blobplacementlocation',
- column('id', sa.Integer()),
- column('name', sa.String()),
- )
+ blobplacementlocation_table = table(
+ "blobplacementlocation", column("id", sa.Integer()), column("name", sa.String())
+ )
- op.bulk_insert(
- blobplacementlocation_table,
- [
- {'name': 'local_eu'},
- {'name': 'local_us'},
- ],
- )
+ op.bulk_insert(
+ blobplacementlocation_table, [{"name": "local_eu"}, {"name": "local_us"}]
+ )
- op.bulk_insert(
- tables.mediatype,
- [
- {'name': 'application/vnd.cnr.blob.v0.tar+gzip'},
- {'name': 'application/vnd.cnr.package-manifest.helm.v0.json'},
- {'name': 'application/vnd.cnr.package-manifest.kpm.v0.json'},
- {'name': 'application/vnd.cnr.package-manifest.docker-compose.v0.json'},
- {'name': 'application/vnd.cnr.package.kpm.v0.tar+gzip'},
- {'name': 'application/vnd.cnr.package.helm.v0.tar+gzip'},
- {'name': 'application/vnd.cnr.package.docker-compose.v0.tar+gzip'},
- {'name': 'application/vnd.cnr.manifests.v0.json'},
- {'name': 'application/vnd.cnr.manifest.list.v0.json'},
- ],
- )
+ op.bulk_insert(
+ tables.mediatype,
+ [
+ {"name": "application/vnd.cnr.blob.v0.tar+gzip"},
+ {"name": "application/vnd.cnr.package-manifest.helm.v0.json"},
+ {"name": "application/vnd.cnr.package-manifest.kpm.v0.json"},
+ {"name": "application/vnd.cnr.package-manifest.docker-compose.v0.json"},
+ {"name": "application/vnd.cnr.package.kpm.v0.tar+gzip"},
+ {"name": "application/vnd.cnr.package.helm.v0.tar+gzip"},
+ {"name": "application/vnd.cnr.package.docker-compose.v0.tar+gzip"},
+ {"name": "application/vnd.cnr.manifests.v0.json"},
+ {"name": "application/vnd.cnr.manifest.list.v0.json"},
+ ],
+ )
- tagkind_table = table('tagkind',
- column('id', sa.Integer()),
- column('name', sa.String()),
- )
+ tagkind_table = table(
+ "tagkind", column("id", sa.Integer()), column("name", sa.String())
+ )
+
+ op.bulk_insert(
+ tagkind_table,
+ [
+ {"id": 1, "name": "tag"},
+ {"id": 2, "name": "release"},
+ {"id": 3, "name": "channel"},
+ ],
+ )
- op.bulk_insert(
- tagkind_table,
- [
- {'id': 1, 'name': 'tag'},
- {'id': 2, 'name': 'release'},
- {'id': 3, 'name': 'channel'},
- ]
- )
def downgrade(tables, tester, progress_reporter):
- op = ProgressWrapper(original_op, progress_reporter)
- op.drop_table('manifestlayerscan')
- op.drop_table('manifestlayerdockerv1')
- op.drop_table('tag')
- op.drop_table('manifestlistmanifest')
- op.drop_table('manifestlayer')
- op.drop_table('manifestlabel')
- op.drop_table('manifestblob')
- op.drop_table('derivedimage')
- op.drop_table('blobuploading')
- op.drop_table('blobplacement')
- op.drop_table('bittorrentpieces')
- op.drop_table('manifestlist')
- op.drop_table('manifest')
- op.drop_table('blobplacementlocationpreference')
- op.drop_table('blob')
- op.drop_table('tagkind')
- op.drop_table('blobplacementlocation')
+ op = ProgressWrapper(original_op, progress_reporter)
+ op.drop_table("manifestlayerscan")
+ op.drop_table("manifestlayerdockerv1")
+ op.drop_table("tag")
+ op.drop_table("manifestlistmanifest")
+ op.drop_table("manifestlayer")
+ op.drop_table("manifestlabel")
+ op.drop_table("manifestblob")
+ op.drop_table("derivedimage")
+ op.drop_table("blobuploading")
+ op.drop_table("blobplacement")
+ op.drop_table("bittorrentpieces")
+ op.drop_table("manifestlist")
+ op.drop_table("manifest")
+ op.drop_table("blobplacementlocationpreference")
+ op.drop_table("blob")
+ op.drop_table("tagkind")
+ op.drop_table("blobplacementlocation")
diff --git a/data/migrations/versions/87fbbc224f10_add_disabled_datetime_to_trigger.py b/data/migrations/versions/87fbbc224f10_add_disabled_datetime_to_trigger.py
index ac177cd9f..a244005b2 100644
--- a/data/migrations/versions/87fbbc224f10_add_disabled_datetime_to_trigger.py
+++ b/data/migrations/versions/87fbbc224f10_add_disabled_datetime_to_trigger.py
@@ -7,29 +7,42 @@ Create Date: 2017-10-24 14:06:37.658705
"""
# revision identifiers, used by Alembic.
-revision = '87fbbc224f10'
-down_revision = '17aff2e1354e'
+revision = "87fbbc224f10"
+down_revision = "17aff2e1354e"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.add_column('repositorybuildtrigger', sa.Column('disabled_datetime', sa.DateTime(), nullable=True))
- op.create_index('repositorybuildtrigger_disabled_datetime', 'repositorybuildtrigger', ['disabled_datetime'], unique=False)
+ op.add_column(
+ "repositorybuildtrigger",
+ sa.Column("disabled_datetime", sa.DateTime(), nullable=True),
+ )
+ op.create_index(
+ "repositorybuildtrigger_disabled_datetime",
+ "repositorybuildtrigger",
+ ["disabled_datetime"],
+ unique=False,
+ )
# ### end Alembic commands ###
# ### population of test data ### #
- tester.populate_column('repositorybuildtrigger', 'disabled_datetime', tester.TestDataType.DateTime)
+ tester.populate_column(
+ "repositorybuildtrigger", "disabled_datetime", tester.TestDataType.DateTime
+ )
# ### end population of test data ### #
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_index('repositorybuildtrigger_disabled_datetime', table_name='repositorybuildtrigger')
- op.drop_column('repositorybuildtrigger', 'disabled_datetime')
+ op.drop_index(
+ "repositorybuildtrigger_disabled_datetime", table_name="repositorybuildtrigger"
+ )
+ op.drop_column("repositorybuildtrigger", "disabled_datetime")
# ### end Alembic commands ###
diff --git a/data/migrations/versions/9093adccc784_add_v2_2_data_models_for_manifest_.py b/data/migrations/versions/9093adccc784_add_v2_2_data_models_for_manifest_.py
index 49797c6ae..94a4f11d7 100644
--- a/data/migrations/versions/9093adccc784_add_v2_2_data_models_for_manifest_.py
+++ b/data/migrations/versions/9093adccc784_add_v2_2_data_models_for_manifest_.py
@@ -7,8 +7,8 @@ Create Date: 2018-08-06 16:07:50.222749
"""
# revision identifiers, used by Alembic.
-revision = '9093adccc784'
-down_revision = '6c21e2cfb8b6'
+revision = "9093adccc784"
+down_revision = "6c21e2cfb8b6"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
@@ -19,144 +19,344 @@ from image.docker.schema1 import DOCKER_SCHEMA1_CONTENT_TYPES
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.create_table('manifest',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('digest', sa.String(length=255), nullable=False),
- sa.Column('media_type_id', sa.Integer(), nullable=False),
- sa.Column('manifest_bytes', sa.Text(), nullable=False),
- sa.ForeignKeyConstraint(['media_type_id'], ['mediatype.id'], name=op.f('fk_manifest_media_type_id_mediatype')),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_manifest_repository_id_repository')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_manifest'))
+ op.create_table(
+ "manifest",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("digest", sa.String(length=255), nullable=False),
+ sa.Column("media_type_id", sa.Integer(), nullable=False),
+ sa.Column("manifest_bytes", sa.Text(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["media_type_id"],
+ ["mediatype.id"],
+ name=op.f("fk_manifest_media_type_id_mediatype"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_manifest_repository_id_repository"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_manifest")),
)
- op.create_index('manifest_digest', 'manifest', ['digest'], unique=False)
- op.create_index('manifest_media_type_id', 'manifest', ['media_type_id'], unique=False)
- op.create_index('manifest_repository_id', 'manifest', ['repository_id'], unique=False)
- op.create_index('manifest_repository_id_digest', 'manifest', ['repository_id', 'digest'], unique=True)
- op.create_index('manifest_repository_id_media_type_id', 'manifest', ['repository_id', 'media_type_id'], unique=False)
- op.create_table('manifestblob',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('manifest_id', sa.Integer(), nullable=False),
- sa.Column('blob_id', sa.Integer(), nullable=False),
- sa.Column('blob_index', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['blob_id'], ['imagestorage.id'], name=op.f('fk_manifestblob_blob_id_imagestorage')),
- sa.ForeignKeyConstraint(['manifest_id'], ['manifest.id'], name=op.f('fk_manifestblob_manifest_id_manifest')),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_manifestblob_repository_id_repository')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_manifestblob'))
+ op.create_index("manifest_digest", "manifest", ["digest"], unique=False)
+ op.create_index(
+ "manifest_media_type_id", "manifest", ["media_type_id"], unique=False
)
- op.create_index('manifestblob_blob_id', 'manifestblob', ['blob_id'], unique=False)
- op.create_index('manifestblob_manifest_id', 'manifestblob', ['manifest_id'], unique=False)
- op.create_index('manifestblob_manifest_id_blob_id', 'manifestblob', ['manifest_id', 'blob_id'], unique=True)
- op.create_index('manifestblob_manifest_id_blob_index', 'manifestblob', ['manifest_id', 'blob_index'], unique=True)
- op.create_index('manifestblob_repository_id', 'manifestblob', ['repository_id'], unique=False)
- op.create_table('manifestlabel',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('manifest_id', sa.Integer(), nullable=False),
- sa.Column('label_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['label_id'], ['label.id'], name=op.f('fk_manifestlabel_label_id_label')),
- sa.ForeignKeyConstraint(['manifest_id'], ['manifest.id'], name=op.f('fk_manifestlabel_manifest_id_manifest')),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_manifestlabel_repository_id_repository')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_manifestlabel'))
+ op.create_index(
+ "manifest_repository_id", "manifest", ["repository_id"], unique=False
)
- op.create_index('manifestlabel_label_id', 'manifestlabel', ['label_id'], unique=False)
- op.create_index('manifestlabel_manifest_id', 'manifestlabel', ['manifest_id'], unique=False)
- op.create_index('manifestlabel_manifest_id_label_id', 'manifestlabel', ['manifest_id', 'label_id'], unique=True)
- op.create_index('manifestlabel_repository_id', 'manifestlabel', ['repository_id'], unique=False)
- op.create_table('manifestlegacyimage',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('manifest_id', sa.Integer(), nullable=False),
- sa.Column('image_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['image_id'], ['image.id'], name=op.f('fk_manifestlegacyimage_image_id_image')),
- sa.ForeignKeyConstraint(['manifest_id'], ['manifest.id'], name=op.f('fk_manifestlegacyimage_manifest_id_manifest')),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_manifestlegacyimage_repository_id_repository')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_manifestlegacyimage'))
+ op.create_index(
+ "manifest_repository_id_digest",
+ "manifest",
+ ["repository_id", "digest"],
+ unique=True,
)
- op.create_index('manifestlegacyimage_image_id', 'manifestlegacyimage', ['image_id'], unique=False)
- op.create_index('manifestlegacyimage_manifest_id', 'manifestlegacyimage', ['manifest_id'], unique=True)
- op.create_index('manifestlegacyimage_repository_id', 'manifestlegacyimage', ['repository_id'], unique=False)
- op.create_table('tagmanifesttomanifest',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('tag_manifest_id', sa.Integer(), nullable=False),
- sa.Column('manifest_id', sa.Integer(), nullable=False),
- sa.Column('broken', sa.Boolean(), nullable=False, server_default=sa.sql.expression.false()),
- sa.ForeignKeyConstraint(['manifest_id'], ['manifest.id'], name=op.f('fk_tagmanifesttomanifest_manifest_id_manifest')),
- sa.ForeignKeyConstraint(['tag_manifest_id'], ['tagmanifest.id'], name=op.f('fk_tagmanifesttomanifest_tag_manifest_id_tagmanifest')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_tagmanifesttomanifest'))
+ op.create_index(
+ "manifest_repository_id_media_type_id",
+ "manifest",
+ ["repository_id", "media_type_id"],
+ unique=False,
)
- op.create_index('tagmanifesttomanifest_broken', 'tagmanifesttomanifest', ['broken'], unique=False)
- op.create_index('tagmanifesttomanifest_manifest_id', 'tagmanifesttomanifest', ['manifest_id'], unique=True)
- op.create_index('tagmanifesttomanifest_tag_manifest_id', 'tagmanifesttomanifest', ['tag_manifest_id'], unique=True)
- op.create_table('tagmanifestlabelmap',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('tag_manifest_id', sa.Integer(), nullable=False),
- sa.Column('manifest_id', sa.Integer(), nullable=True),
- sa.Column('label_id', sa.Integer(), nullable=False),
- sa.Column('tag_manifest_label_id', sa.Integer(), nullable=False),
- sa.Column('manifest_label_id', sa.Integer(), nullable=True),
- sa.Column('broken_manifest', sa.Boolean(), nullable=False, server_default=sa.sql.expression.false()),
- sa.ForeignKeyConstraint(['label_id'], ['label.id'], name=op.f('fk_tagmanifestlabelmap_label_id_label')),
- sa.ForeignKeyConstraint(['manifest_id'], ['manifest.id'], name=op.f('fk_tagmanifestlabelmap_manifest_id_manifest')),
- sa.ForeignKeyConstraint(['manifest_label_id'], ['manifestlabel.id'], name=op.f('fk_tagmanifestlabelmap_manifest_label_id_manifestlabel')),
- sa.ForeignKeyConstraint(['tag_manifest_id'], ['tagmanifest.id'], name=op.f('fk_tagmanifestlabelmap_tag_manifest_id_tagmanifest')),
- sa.ForeignKeyConstraint(['tag_manifest_label_id'], ['tagmanifestlabel.id'], name=op.f('fk_tagmanifestlabelmap_tag_manifest_label_id_tagmanifestlabel')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_tagmanifestlabelmap'))
+ op.create_table(
+ "manifestblob",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("manifest_id", sa.Integer(), nullable=False),
+ sa.Column("blob_id", sa.Integer(), nullable=False),
+ sa.Column("blob_index", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["blob_id"],
+ ["imagestorage.id"],
+ name=op.f("fk_manifestblob_blob_id_imagestorage"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["manifest_id"],
+ ["manifest.id"],
+ name=op.f("fk_manifestblob_manifest_id_manifest"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_manifestblob_repository_id_repository"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_manifestblob")),
+ )
+ op.create_index("manifestblob_blob_id", "manifestblob", ["blob_id"], unique=False)
+ op.create_index(
+ "manifestblob_manifest_id", "manifestblob", ["manifest_id"], unique=False
+ )
+ op.create_index(
+ "manifestblob_manifest_id_blob_id",
+ "manifestblob",
+ ["manifest_id", "blob_id"],
+ unique=True,
+ )
+ op.create_index(
+ "manifestblob_manifest_id_blob_index",
+ "manifestblob",
+ ["manifest_id", "blob_index"],
+ unique=True,
+ )
+ op.create_index(
+ "manifestblob_repository_id", "manifestblob", ["repository_id"], unique=False
+ )
+ op.create_table(
+ "manifestlabel",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("manifest_id", sa.Integer(), nullable=False),
+ sa.Column("label_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["label_id"], ["label.id"], name=op.f("fk_manifestlabel_label_id_label")
+ ),
+ sa.ForeignKeyConstraint(
+ ["manifest_id"],
+ ["manifest.id"],
+ name=op.f("fk_manifestlabel_manifest_id_manifest"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_manifestlabel_repository_id_repository"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_manifestlabel")),
+ )
+ op.create_index(
+ "manifestlabel_label_id", "manifestlabel", ["label_id"], unique=False
+ )
+ op.create_index(
+ "manifestlabel_manifest_id", "manifestlabel", ["manifest_id"], unique=False
+ )
+ op.create_index(
+ "manifestlabel_manifest_id_label_id",
+ "manifestlabel",
+ ["manifest_id", "label_id"],
+ unique=True,
+ )
+ op.create_index(
+ "manifestlabel_repository_id", "manifestlabel", ["repository_id"], unique=False
+ )
+ op.create_table(
+ "manifestlegacyimage",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("manifest_id", sa.Integer(), nullable=False),
+ sa.Column("image_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["image_id"],
+ ["image.id"],
+ name=op.f("fk_manifestlegacyimage_image_id_image"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["manifest_id"],
+ ["manifest.id"],
+ name=op.f("fk_manifestlegacyimage_manifest_id_manifest"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_manifestlegacyimage_repository_id_repository"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_manifestlegacyimage")),
+ )
+ op.create_index(
+ "manifestlegacyimage_image_id",
+ "manifestlegacyimage",
+ ["image_id"],
+ unique=False,
+ )
+ op.create_index(
+ "manifestlegacyimage_manifest_id",
+ "manifestlegacyimage",
+ ["manifest_id"],
+ unique=True,
+ )
+ op.create_index(
+ "manifestlegacyimage_repository_id",
+ "manifestlegacyimage",
+ ["repository_id"],
+ unique=False,
+ )
+ op.create_table(
+ "tagmanifesttomanifest",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("tag_manifest_id", sa.Integer(), nullable=False),
+ sa.Column("manifest_id", sa.Integer(), nullable=False),
+ sa.Column(
+ "broken",
+ sa.Boolean(),
+ nullable=False,
+ server_default=sa.sql.expression.false(),
+ ),
+ sa.ForeignKeyConstraint(
+ ["manifest_id"],
+ ["manifest.id"],
+ name=op.f("fk_tagmanifesttomanifest_manifest_id_manifest"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["tag_manifest_id"],
+ ["tagmanifest.id"],
+ name=op.f("fk_tagmanifesttomanifest_tag_manifest_id_tagmanifest"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_tagmanifesttomanifest")),
+ )
+ op.create_index(
+ "tagmanifesttomanifest_broken",
+ "tagmanifesttomanifest",
+ ["broken"],
+ unique=False,
+ )
+ op.create_index(
+ "tagmanifesttomanifest_manifest_id",
+ "tagmanifesttomanifest",
+ ["manifest_id"],
+ unique=True,
+ )
+ op.create_index(
+ "tagmanifesttomanifest_tag_manifest_id",
+ "tagmanifesttomanifest",
+ ["tag_manifest_id"],
+ unique=True,
+ )
+ op.create_table(
+ "tagmanifestlabelmap",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("tag_manifest_id", sa.Integer(), nullable=False),
+ sa.Column("manifest_id", sa.Integer(), nullable=True),
+ sa.Column("label_id", sa.Integer(), nullable=False),
+ sa.Column("tag_manifest_label_id", sa.Integer(), nullable=False),
+ sa.Column("manifest_label_id", sa.Integer(), nullable=True),
+ sa.Column(
+ "broken_manifest",
+ sa.Boolean(),
+ nullable=False,
+ server_default=sa.sql.expression.false(),
+ ),
+ sa.ForeignKeyConstraint(
+ ["label_id"],
+ ["label.id"],
+ name=op.f("fk_tagmanifestlabelmap_label_id_label"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["manifest_id"],
+ ["manifest.id"],
+ name=op.f("fk_tagmanifestlabelmap_manifest_id_manifest"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["manifest_label_id"],
+ ["manifestlabel.id"],
+ name=op.f("fk_tagmanifestlabelmap_manifest_label_id_manifestlabel"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["tag_manifest_id"],
+ ["tagmanifest.id"],
+ name=op.f("fk_tagmanifestlabelmap_tag_manifest_id_tagmanifest"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["tag_manifest_label_id"],
+ ["tagmanifestlabel.id"],
+ name=op.f("fk_tagmanifestlabelmap_tag_manifest_label_id_tagmanifestlabel"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_tagmanifestlabelmap")),
+ )
+ op.create_index(
+ "tagmanifestlabelmap_broken_manifest",
+ "tagmanifestlabelmap",
+ ["broken_manifest"],
+ unique=False,
+ )
+ op.create_index(
+ "tagmanifestlabelmap_label_id",
+ "tagmanifestlabelmap",
+ ["label_id"],
+ unique=False,
+ )
+ op.create_index(
+ "tagmanifestlabelmap_manifest_id",
+ "tagmanifestlabelmap",
+ ["manifest_id"],
+ unique=False,
+ )
+ op.create_index(
+ "tagmanifestlabelmap_manifest_label_id",
+ "tagmanifestlabelmap",
+ ["manifest_label_id"],
+ unique=False,
+ )
+ op.create_index(
+ "tagmanifestlabelmap_tag_manifest_id",
+ "tagmanifestlabelmap",
+ ["tag_manifest_id"],
+ unique=False,
+ )
+ op.create_index(
+ "tagmanifestlabelmap_tag_manifest_label_id",
+ "tagmanifestlabelmap",
+ ["tag_manifest_label_id"],
+ unique=False,
)
- op.create_index('tagmanifestlabelmap_broken_manifest', 'tagmanifestlabelmap', ['broken_manifest'], unique=False)
- op.create_index('tagmanifestlabelmap_label_id', 'tagmanifestlabelmap', ['label_id'], unique=False)
- op.create_index('tagmanifestlabelmap_manifest_id', 'tagmanifestlabelmap', ['manifest_id'], unique=False)
- op.create_index('tagmanifestlabelmap_manifest_label_id', 'tagmanifestlabelmap', ['manifest_label_id'], unique=False)
- op.create_index('tagmanifestlabelmap_tag_manifest_id', 'tagmanifestlabelmap', ['tag_manifest_id'], unique=False)
- op.create_index('tagmanifestlabelmap_tag_manifest_label_id', 'tagmanifestlabelmap', ['tag_manifest_label_id'], unique=False)
# ### end Alembic commands ###
for media_type in DOCKER_SCHEMA1_CONTENT_TYPES:
- op.bulk_insert(tables.mediatype,
- [
- {'name': media_type},
- ])
+ op.bulk_insert(tables.mediatype, [{"name": media_type}])
# ### population of test data ### #
- tester.populate_table('manifest', [
- ('digest', tester.TestDataType.String),
- ('manifest_bytes', tester.TestDataType.JSON),
- ('media_type_id', tester.TestDataType.Foreign('mediatype')),
- ('repository_id', tester.TestDataType.Foreign('repository')),
- ])
+ tester.populate_table(
+ "manifest",
+ [
+ ("digest", tester.TestDataType.String),
+ ("manifest_bytes", tester.TestDataType.JSON),
+ ("media_type_id", tester.TestDataType.Foreign("mediatype")),
+ ("repository_id", tester.TestDataType.Foreign("repository")),
+ ],
+ )
- tester.populate_table('manifestblob', [
- ('manifest_id', tester.TestDataType.Foreign('manifest')),
- ('repository_id', tester.TestDataType.Foreign('repository')),
- ('blob_id', tester.TestDataType.Foreign('imagestorage')),
- ('blob_index', tester.TestDataType.Integer),
- ])
+ tester.populate_table(
+ "manifestblob",
+ [
+ ("manifest_id", tester.TestDataType.Foreign("manifest")),
+ ("repository_id", tester.TestDataType.Foreign("repository")),
+ ("blob_id", tester.TestDataType.Foreign("imagestorage")),
+ ("blob_index", tester.TestDataType.Integer),
+ ],
+ )
- tester.populate_table('manifestlabel', [
- ('manifest_id', tester.TestDataType.Foreign('manifest')),
- ('label_id', tester.TestDataType.Foreign('label')),
- ('repository_id', tester.TestDataType.Foreign('repository')),
- ])
+ tester.populate_table(
+ "manifestlabel",
+ [
+ ("manifest_id", tester.TestDataType.Foreign("manifest")),
+ ("label_id", tester.TestDataType.Foreign("label")),
+ ("repository_id", tester.TestDataType.Foreign("repository")),
+ ],
+ )
- tester.populate_table('manifestlegacyimage', [
- ('manifest_id', tester.TestDataType.Foreign('manifest')),
- ('image_id', tester.TestDataType.Foreign('image')),
- ('repository_id', tester.TestDataType.Foreign('repository')),
- ])
+ tester.populate_table(
+ "manifestlegacyimage",
+ [
+ ("manifest_id", tester.TestDataType.Foreign("manifest")),
+ ("image_id", tester.TestDataType.Foreign("image")),
+ ("repository_id", tester.TestDataType.Foreign("repository")),
+ ],
+ )
- tester.populate_table('tagmanifesttomanifest', [
- ('manifest_id', tester.TestDataType.Foreign('manifest')),
- ('tag_manifest_id', tester.TestDataType.Foreign('tagmanifest')),
- ])
+ tester.populate_table(
+ "tagmanifesttomanifest",
+ [
+ ("manifest_id", tester.TestDataType.Foreign("manifest")),
+ ("tag_manifest_id", tester.TestDataType.Foreign("tagmanifest")),
+ ],
+ )
- tester.populate_table('tagmanifestlabelmap', [
- ('manifest_id', tester.TestDataType.Foreign('manifest')),
- ('tag_manifest_id', tester.TestDataType.Foreign('tagmanifest')),
- ('tag_manifest_label_id', tester.TestDataType.Foreign('tagmanifestlabel')),
- ('manifest_label_id', tester.TestDataType.Foreign('manifestlabel')),
- ('label_id', tester.TestDataType.Foreign('label')),
- ])
+ tester.populate_table(
+ "tagmanifestlabelmap",
+ [
+ ("manifest_id", tester.TestDataType.Foreign("manifest")),
+ ("tag_manifest_id", tester.TestDataType.Foreign("tagmanifest")),
+ ("tag_manifest_label_id", tester.TestDataType.Foreign("tagmanifestlabel")),
+ ("manifest_label_id", tester.TestDataType.Foreign("manifestlabel")),
+ ("label_id", tester.TestDataType.Foreign("label")),
+ ],
+ )
# ### end population of test data ### #
@@ -164,17 +364,17 @@ def upgrade(tables, tester, progress_reporter):
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
for media_type in DOCKER_SCHEMA1_CONTENT_TYPES:
- op.execute(tables
- .mediatype
- .delete()
- .where(tables.
- mediatype.c.name == op.inline_literal(media_type)))
+ op.execute(
+ tables.mediatype.delete().where(
+ tables.mediatype.c.name == op.inline_literal(media_type)
+ )
+ )
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_table('tagmanifestlabelmap')
- op.drop_table('tagmanifesttomanifest')
- op.drop_table('manifestlegacyimage')
- op.drop_table('manifestlabel')
- op.drop_table('manifestblob')
- op.drop_table('manifest')
+ op.drop_table("tagmanifestlabelmap")
+ op.drop_table("tagmanifesttomanifest")
+ op.drop_table("manifestlegacyimage")
+ op.drop_table("manifestlabel")
+ op.drop_table("manifestblob")
+ op.drop_table("manifest")
# ### end Alembic commands ###
diff --git a/data/migrations/versions/94836b099894_create_new_notification_type.py b/data/migrations/versions/94836b099894_create_new_notification_type.py
index 6bc780d01..f46cf08ae 100644
--- a/data/migrations/versions/94836b099894_create_new_notification_type.py
+++ b/data/migrations/versions/94836b099894_create_new_notification_type.py
@@ -7,8 +7,8 @@ Create Date: 2016-11-30 10:29:51.519278
"""
# revision identifiers, used by Alembic.
-revision = '94836b099894'
-down_revision = 'faf752bd2e0a'
+revision = "94836b099894"
+down_revision = "faf752bd2e0a"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
@@ -16,16 +16,14 @@ from data.migrations.progress import ProgressWrapper
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
- op.bulk_insert(tables.externalnotificationevent,
- [
- {'name': 'build_cancelled'},
- ])
+ op.bulk_insert(tables.externalnotificationevent, [{"name": "build_cancelled"}])
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
- op.execute(tables
- .externalnotificationevent
- .delete()
- .where(tables.
- externalnotificationevent.c.name == op.inline_literal('build_cancelled')))
+ op.execute(
+ tables.externalnotificationevent.delete().where(
+ tables.externalnotificationevent.c.name
+ == op.inline_literal("build_cancelled")
+ )
+ )
diff --git a/data/migrations/versions/a6c463dfb9fe_back_fill_build_expand_config.py b/data/migrations/versions/a6c463dfb9fe_back_fill_build_expand_config.py
index c4c6b3f33..927f6952b 100644
--- a/data/migrations/versions/a6c463dfb9fe_back_fill_build_expand_config.py
+++ b/data/migrations/versions/a6c463dfb9fe_back_fill_build_expand_config.py
@@ -14,88 +14,89 @@ from app import app
from peewee import *
from data.database import BaseModel
-revision = 'a6c463dfb9fe'
-down_revision = 'b4df55dea4b3'
+revision = "a6c463dfb9fe"
+down_revision = "b4df55dea4b3"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
class RepositoryBuildTrigger(BaseModel):
- config = TextField(default='{}')
+ config = TextField(default="{}")
+
def upgrade(tables, tester, progress_reporter):
- op = ProgressWrapper(original_op, progress_reporter)
- if not app.config.get('SETUP_COMPLETE', False):
- return
+ op = ProgressWrapper(original_op, progress_reporter)
+ if not app.config.get("SETUP_COMPLETE", False):
+ return
- repostioryBuildTriggers = RepositoryBuildTrigger.select()
- for repositoryBuildTrigger in repostioryBuildTriggers:
- config = json.loads(repositoryBuildTrigger.config)
- repositoryBuildTrigger.config = json.dumps(get_config_expand(config))
- repositoryBuildTrigger.save()
+ repostioryBuildTriggers = RepositoryBuildTrigger.select()
+ for repositoryBuildTrigger in repostioryBuildTriggers:
+ config = json.loads(repositoryBuildTrigger.config)
+ repositoryBuildTrigger.config = json.dumps(get_config_expand(config))
+ repositoryBuildTrigger.save()
def downgrade(tables, tester, progress_reporter):
- op = ProgressWrapper(original_op, progress_reporter)
- if not app.config.get('SETUP_COMPLETE', False):
- return
+ op = ProgressWrapper(original_op, progress_reporter)
+ if not app.config.get("SETUP_COMPLETE", False):
+ return
- repostioryBuildTriggers = RepositoryBuildTrigger.select()
- for repositoryBuildTrigger in repostioryBuildTriggers:
- config = json.loads(repositoryBuildTrigger.config)
- repositoryBuildTrigger.config = json.dumps(get_config_expand(config))
- repositoryBuildTrigger.save()
+ repostioryBuildTriggers = RepositoryBuildTrigger.select()
+ for repositoryBuildTrigger in repostioryBuildTriggers:
+ config = json.loads(repositoryBuildTrigger.config)
+ repositoryBuildTrigger.config = json.dumps(get_config_expand(config))
+ repositoryBuildTrigger.save()
def create_context(current_subdir):
- if current_subdir == "":
- current_subdir = os.path.sep + current_subdir
+ if current_subdir == "":
+ current_subdir = os.path.sep + current_subdir
- if current_subdir[len(current_subdir) - 1] != os.path.sep:
- current_subdir += os.path.sep
+ if current_subdir[len(current_subdir) - 1] != os.path.sep:
+ current_subdir += os.path.sep
- context, _ = os.path.split(current_subdir)
- return context
+ context, _ = os.path.split(current_subdir)
+ return context
def create_dockerfile_path(current_subdir):
- if current_subdir == "":
- current_subdir = os.path.sep + current_subdir
+ if current_subdir == "":
+ current_subdir = os.path.sep + current_subdir
- if current_subdir[len(current_subdir) - 1] != os.path.sep:
- current_subdir += os.path.sep
+ if current_subdir[len(current_subdir) - 1] != os.path.sep:
+ current_subdir += os.path.sep
- return current_subdir + "Dockerfile"
+ return current_subdir + "Dockerfile"
def get_config_expand(config):
- """ A function to transform old records into new records """
- if not config:
- return config
+ """ A function to transform old records into new records """
+ if not config:
+ return config
- # skip records that have been updated
- if "context" in config or "dockerfile_path" in config:
- return config
+ # skip records that have been updated
+ if "context" in config or "dockerfile_path" in config:
+ return config
- config_expand = {}
- if "subdir" in config:
- config_expand = dict(config)
- config_expand["context"] = create_context(config["subdir"])
- config_expand["dockerfile_path"] = create_dockerfile_path(config["subdir"])
+ config_expand = {}
+ if "subdir" in config:
+ config_expand = dict(config)
+ config_expand["context"] = create_context(config["subdir"])
+ config_expand["dockerfile_path"] = create_dockerfile_path(config["subdir"])
- return config_expand
+ return config_expand
def get_config_contract(config):
- """ A function to delete context and dockerfile_path from config """
- if not config:
+ """ A function to delete context and dockerfile_path from config """
+ if not config:
+ return config
+
+ if "context" in config:
+ del config["context"]
+
+ if "dockerfile_path" in config:
+ del config["dockerfile_path"]
+
return config
-
- if "context" in config:
- del config["context"]
-
- if "dockerfile_path" in config:
- del config["dockerfile_path"]
-
- return config
diff --git a/data/migrations/versions/b4c2d45bc132_add_deleted_namespace_table.py b/data/migrations/versions/b4c2d45bc132_add_deleted_namespace_table.py
index d9c53f10c..a24c7d6bc 100644
--- a/data/migrations/versions/b4c2d45bc132_add_deleted_namespace_table.py
+++ b/data/migrations/versions/b4c2d45bc132_add_deleted_namespace_table.py
@@ -7,8 +7,8 @@ Create Date: 2018-02-27 11:43:02.329941
"""
# revision identifiers, used by Alembic.
-revision = 'b4c2d45bc132'
-down_revision = '152edccba18c'
+revision = "b4c2d45bc132"
+down_revision = "152edccba18c"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
@@ -18,36 +18,60 @@ import sqlalchemy as sa
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.create_table('deletednamespace',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('namespace_id', sa.Integer(), nullable=False),
- sa.Column('marked', sa.DateTime(), nullable=False),
- sa.Column('original_username', sa.String(length=255), nullable=False),
- sa.Column('original_email', sa.String(length=255), nullable=False),
- sa.Column('queue_id', sa.String(length=255), nullable=True),
- sa.ForeignKeyConstraint(['namespace_id'], ['user.id'], name=op.f('fk_deletednamespace_namespace_id_user')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_deletednamespace'))
+ op.create_table(
+ "deletednamespace",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("namespace_id", sa.Integer(), nullable=False),
+ sa.Column("marked", sa.DateTime(), nullable=False),
+ sa.Column("original_username", sa.String(length=255), nullable=False),
+ sa.Column("original_email", sa.String(length=255), nullable=False),
+ sa.Column("queue_id", sa.String(length=255), nullable=True),
+ sa.ForeignKeyConstraint(
+ ["namespace_id"],
+ ["user.id"],
+ name=op.f("fk_deletednamespace_namespace_id_user"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_deletednamespace")),
+ )
+ op.create_index(
+ "deletednamespace_namespace_id",
+ "deletednamespace",
+ ["namespace_id"],
+ unique=True,
+ )
+ op.create_index(
+ "deletednamespace_original_email",
+ "deletednamespace",
+ ["original_email"],
+ unique=False,
+ )
+ op.create_index(
+ "deletednamespace_original_username",
+ "deletednamespace",
+ ["original_username"],
+ unique=False,
+ )
+ op.create_index(
+ "deletednamespace_queue_id", "deletednamespace", ["queue_id"], unique=False
)
- op.create_index('deletednamespace_namespace_id', 'deletednamespace', ['namespace_id'], unique=True)
- op.create_index('deletednamespace_original_email', 'deletednamespace', ['original_email'], unique=False)
- op.create_index('deletednamespace_original_username', 'deletednamespace', ['original_username'], unique=False)
- op.create_index('deletednamespace_queue_id', 'deletednamespace', ['queue_id'], unique=False)
# ### end Alembic commands ###
# ### population of test data ### #
- tester.populate_table('deletednamespace', [
- ('namespace_id', tester.TestDataType.Foreign('user')),
- ('marked', tester.TestDataType.DateTime),
- ('original_username', tester.TestDataType.UTF8Char),
- ('original_email', tester.TestDataType.String),
- ('queue_id', tester.TestDataType.Foreign('queueitem')),
- ])
+ tester.populate_table(
+ "deletednamespace",
+ [
+ ("namespace_id", tester.TestDataType.Foreign("user")),
+ ("marked", tester.TestDataType.DateTime),
+ ("original_username", tester.TestDataType.UTF8Char),
+ ("original_email", tester.TestDataType.String),
+ ("queue_id", tester.TestDataType.Foreign("queueitem")),
+ ],
+ )
# ### end population of test data ### #
-
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_table('deletednamespace')
+ op.drop_table("deletednamespace")
# ### end Alembic commands ###
diff --git a/data/migrations/versions/b4df55dea4b3_add_repository_kind.py b/data/migrations/versions/b4df55dea4b3_add_repository_kind.py
index d96dd8c43..47836bc50 100644
--- a/data/migrations/versions/b4df55dea4b3_add_repository_kind.py
+++ b/data/migrations/versions/b4df55dea4b3_add_repository_kind.py
@@ -7,8 +7,8 @@ Create Date: 2017-03-19 12:59:41.484430
"""
# revision identifiers, used by Alembic.
-revision = 'b4df55dea4b3'
-down_revision = 'b8ae68ad3e52'
+revision = "b4df55dea4b3"
+down_revision = "b8ae68ad3e52"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
@@ -17,35 +17,45 @@ from sqlalchemy.dialects import mysql
def upgrade(tables, tester, progress_reporter):
- op = ProgressWrapper(original_op, progress_reporter)
- op.create_table(
- 'repositorykind',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_repositorykind'))
- )
- op.create_index('repositorykind_name', 'repositorykind', ['name'], unique=True)
+ op = ProgressWrapper(original_op, progress_reporter)
+ op.create_table(
+ "repositorykind",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_repositorykind")),
+ )
+ op.create_index("repositorykind_name", "repositorykind", ["name"], unique=True)
- op.bulk_insert(
- tables.repositorykind,
- [
- {'id': 1, 'name': 'image'},
- {'id': 2, 'name': 'application'},
- ],
- )
+ op.bulk_insert(
+ tables.repositorykind,
+ [{"id": 1, "name": "image"}, {"id": 2, "name": "application"}],
+ )
- op.add_column(u'repository', sa.Column('kind_id', sa.Integer(), nullable=False, server_default='1'))
- op.create_index('repository_kind_id', 'repository', ['kind_id'], unique=False)
- op.create_foreign_key(op.f('fk_repository_kind_id_repositorykind'), 'repository', 'repositorykind', ['kind_id'], ['id'])
+ op.add_column(
+ u"repository",
+ sa.Column("kind_id", sa.Integer(), nullable=False, server_default="1"),
+ )
+ op.create_index("repository_kind_id", "repository", ["kind_id"], unique=False)
+ op.create_foreign_key(
+ op.f("fk_repository_kind_id_repositorykind"),
+ "repository",
+ "repositorykind",
+ ["kind_id"],
+ ["id"],
+ )
- # ### population of test data ### #
- tester.populate_column('repository', 'kind_id', tester.TestDataType.Foreign('repositorykind'))
- # ### end population of test data ### #
+ # ### population of test data ### #
+ tester.populate_column(
+ "repository", "kind_id", tester.TestDataType.Foreign("repositorykind")
+ )
+ # ### end population of test data ### #
def downgrade(tables, tester, progress_reporter):
- op = ProgressWrapper(original_op, progress_reporter)
- op.drop_constraint(op.f('fk_repository_kind_id_repositorykind'), 'repository', type_='foreignkey')
- op.drop_index('repository_kind_id', table_name='repository')
- op.drop_column(u'repository', 'kind_id')
- op.drop_table('repositorykind')
+ op = ProgressWrapper(original_op, progress_reporter)
+ op.drop_constraint(
+ op.f("fk_repository_kind_id_repositorykind"), "repository", type_="foreignkey"
+ )
+ op.drop_index("repository_kind_id", table_name="repository")
+ op.drop_column(u"repository", "kind_id")
+ op.drop_table("repositorykind")
diff --git a/data/migrations/versions/b547bc139ad8_add_robotaccountmetadata_table.py b/data/migrations/versions/b547bc139ad8_add_robotaccountmetadata_table.py
index 1d26fa2d9..529a95cb3 100644
--- a/data/migrations/versions/b547bc139ad8_add_robotaccountmetadata_table.py
+++ b/data/migrations/versions/b547bc139ad8_add_robotaccountmetadata_table.py
@@ -7,8 +7,8 @@ Create Date: 2018-03-09 15:50:48.298880
"""
# revision identifiers, used by Alembic.
-revision = 'b547bc139ad8'
-down_revision = '0cf50323c78b'
+revision = "b547bc139ad8"
+down_revision = "0cf50323c78b"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
@@ -19,28 +19,41 @@ from util.migrate import UTF8CharField
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.create_table('robotaccountmetadata',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('robot_account_id', sa.Integer(), nullable=False),
- sa.Column('description', UTF8CharField(length=255), nullable=False),
- sa.Column('unstructured_json', sa.Text(), nullable=False),
- sa.ForeignKeyConstraint(['robot_account_id'], ['user.id'], name=op.f('fk_robotaccountmetadata_robot_account_id_user')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_robotaccountmetadata'))
+ op.create_table(
+ "robotaccountmetadata",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("robot_account_id", sa.Integer(), nullable=False),
+ sa.Column("description", UTF8CharField(length=255), nullable=False),
+ sa.Column("unstructured_json", sa.Text(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["robot_account_id"],
+ ["user.id"],
+ name=op.f("fk_robotaccountmetadata_robot_account_id_user"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_robotaccountmetadata")),
+ )
+ op.create_index(
+ "robotaccountmetadata_robot_account_id",
+ "robotaccountmetadata",
+ ["robot_account_id"],
+ unique=True,
)
- op.create_index('robotaccountmetadata_robot_account_id', 'robotaccountmetadata', ['robot_account_id'], unique=True)
# ### end Alembic commands ###
# ### population of test data ### #
- tester.populate_table('robotaccountmetadata', [
- ('robot_account_id', tester.TestDataType.Foreign('user')),
- ('description', tester.TestDataType.UTF8Char),
- ('unstructured_json', tester.TestDataType.JSON),
- ])
+ tester.populate_table(
+ "robotaccountmetadata",
+ [
+ ("robot_account_id", tester.TestDataType.Foreign("user")),
+ ("description", tester.TestDataType.UTF8Char),
+ ("unstructured_json", tester.TestDataType.JSON),
+ ],
+ )
# ### end population of test data ### #
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_table('robotaccountmetadata')
+ op.drop_table("robotaccountmetadata")
# ### end Alembic commands ###
diff --git a/data/migrations/versions/b8ae68ad3e52_change_blobupload_fields_to_bigintegers_.py b/data/migrations/versions/b8ae68ad3e52_change_blobupload_fields_to_bigintegers_.py
index d76c8e018..64278a27d 100644
--- a/data/migrations/versions/b8ae68ad3e52_change_blobupload_fields_to_bigintegers_.py
+++ b/data/migrations/versions/b8ae68ad3e52_change_blobupload_fields_to_bigintegers_.py
@@ -7,31 +7,50 @@ Create Date: 2017-02-27 11:26:49.182349
"""
# revision identifiers, used by Alembic.
-revision = 'b8ae68ad3e52'
-down_revision = '7a525c68eb13'
+revision = "b8ae68ad3e52"
+down_revision = "7a525c68eb13"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
- op.alter_column('blobupload', 'byte_count', existing_type=sa.Integer(), type_=sa.BigInteger())
- op.alter_column('blobupload', 'uncompressed_byte_count', existing_type=sa.Integer(), type_=sa.BigInteger())
+ op.alter_column(
+ "blobupload", "byte_count", existing_type=sa.Integer(), type_=sa.BigInteger()
+ )
+ op.alter_column(
+ "blobupload",
+ "uncompressed_byte_count",
+ existing_type=sa.Integer(),
+ type_=sa.BigInteger(),
+ )
# ### population of test data ### #
- tester.populate_column('blobupload', 'byte_count', tester.TestDataType.BigInteger)
- tester.populate_column('blobupload', 'uncompressed_byte_count', tester.TestDataType.BigInteger)
+ tester.populate_column("blobupload", "byte_count", tester.TestDataType.BigInteger)
+ tester.populate_column(
+ "blobupload", "uncompressed_byte_count", tester.TestDataType.BigInteger
+ )
# ### end population of test data ### #
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### population of test data ### #
- tester.populate_column('blobupload', 'byte_count', tester.TestDataType.Integer)
- tester.populate_column('blobupload', 'uncompressed_byte_count', tester.TestDataType.Integer)
+ tester.populate_column("blobupload", "byte_count", tester.TestDataType.Integer)
+ tester.populate_column(
+ "blobupload", "uncompressed_byte_count", tester.TestDataType.Integer
+ )
# ### end population of test data ### #
- op.alter_column('blobupload', 'byte_count', existing_type=sa.BigInteger(), type_=sa.Integer())
- op.alter_column('blobupload', 'uncompressed_byte_count', existing_type=sa.BigInteger(), type_=sa.Integer())
+ op.alter_column(
+ "blobupload", "byte_count", existing_type=sa.BigInteger(), type_=sa.Integer()
+ )
+ op.alter_column(
+ "blobupload",
+ "uncompressed_byte_count",
+ existing_type=sa.BigInteger(),
+ type_=sa.Integer(),
+ )
diff --git a/data/migrations/versions/b9045731c4de_add_lifetime_indexes_to_tag_tables.py b/data/migrations/versions/b9045731c4de_add_lifetime_indexes_to_tag_tables.py
index b85ae3514..59448f2e8 100644
--- a/data/migrations/versions/b9045731c4de_add_lifetime_indexes_to_tag_tables.py
+++ b/data/migrations/versions/b9045731c4de_add_lifetime_indexes_to_tag_tables.py
@@ -7,29 +7,54 @@ Create Date: 2019-02-14 17:18:40.474310
"""
# revision identifiers, used by Alembic.
-revision = 'b9045731c4de'
-down_revision = 'e184af42242d'
+revision = "b9045731c4de"
+down_revision = "e184af42242d"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.create_index('repositorytag_repository_id_lifetime_end_ts', 'repositorytag', ['repository_id', 'lifetime_end_ts'], unique=False)
- op.create_index('tag_repository_id_lifetime_end_ms', 'tag', ['repository_id', 'lifetime_end_ms'], unique=False)
+ op.create_index(
+ "repositorytag_repository_id_lifetime_end_ts",
+ "repositorytag",
+ ["repository_id", "lifetime_end_ts"],
+ unique=False,
+ )
+ op.create_index(
+ "tag_repository_id_lifetime_end_ms",
+ "tag",
+ ["repository_id", "lifetime_end_ms"],
+ unique=False,
+ )
- op.create_index('repositorytag_repository_id_lifetime_start_ts', 'repositorytag', ['repository_id', 'lifetime_start_ts'], unique=False)
- op.create_index('tag_repository_id_lifetime_start_ms', 'tag', ['repository_id', 'lifetime_start_ms'], unique=False)
+ op.create_index(
+ "repositorytag_repository_id_lifetime_start_ts",
+ "repositorytag",
+ ["repository_id", "lifetime_start_ts"],
+ unique=False,
+ )
+ op.create_index(
+ "tag_repository_id_lifetime_start_ms",
+ "tag",
+ ["repository_id", "lifetime_start_ms"],
+ unique=False,
+ )
# ### end Alembic commands ###
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_index('tag_repository_id_lifetime_end_ms', table_name='tag')
- op.drop_index('repositorytag_repository_id_lifetime_end_ts', table_name='repositorytag')
+ op.drop_index("tag_repository_id_lifetime_end_ms", table_name="tag")
+ op.drop_index(
+ "repositorytag_repository_id_lifetime_end_ts", table_name="repositorytag"
+ )
- op.drop_index('tag_repository_id_lifetime_start_ms', table_name='tag')
- op.drop_index('repositorytag_repository_id_lifetime_start_ts', table_name='repositorytag')
+ op.drop_index("tag_repository_id_lifetime_start_ms", table_name="tag")
+ op.drop_index(
+ "repositorytag_repository_id_lifetime_start_ts", table_name="repositorytag"
+ )
# ### end Alembic commands ###
diff --git a/data/migrations/versions/b918abdbee43_run_full_tag_backfill.py b/data/migrations/versions/b918abdbee43_run_full_tag_backfill.py
index 3968abd32..267301ff3 100644
--- a/data/migrations/versions/b918abdbee43_run_full_tag_backfill.py
+++ b/data/migrations/versions/b918abdbee43_run_full_tag_backfill.py
@@ -7,8 +7,8 @@ Create Date: 2019-03-14 13:38:03.411609
"""
# revision identifiers, used by Alembic.
-revision = 'b918abdbee43'
-down_revision = '481623ba00ba'
+revision = "b918abdbee43"
+down_revision = "481623ba00ba"
import logging.config
@@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)
def upgrade(tables, tester, progress_reporter):
- if not app.config.get('SETUP_COMPLETE', False):
+ if not app.config.get("SETUP_COMPLETE", False):
return
start_id = 0
@@ -40,16 +40,19 @@ def upgrade(tables, tester, progress_reporter):
if start_id > max_id:
break
- logger.info('Checking tag range %s - %s', start_id, end_id)
- r = list(RepositoryTag
- .select()
- .join(Repository)
- .switch(RepositoryTag)
- .join(TagToRepositoryTag, JOIN.LEFT_OUTER)
- .where(TagToRepositoryTag.id >> None)
- .where(RepositoryTag.hidden == False,
- RepositoryTag.id >= start_id,
- RepositoryTag.id < end_id))
+ logger.info("Checking tag range %s - %s", start_id, end_id)
+ r = list(
+ RepositoryTag.select()
+ .join(Repository)
+ .switch(RepositoryTag)
+ .join(TagToRepositoryTag, JOIN.LEFT_OUTER)
+ .where(TagToRepositoryTag.id >> None)
+ .where(
+ RepositoryTag.hidden == False,
+ RepositoryTag.id >= start_id,
+ RepositoryTag.id < end_id,
+ )
+ )
if len(r) < 1000 and size < 100000:
size *= 2
@@ -60,7 +63,7 @@ def upgrade(tables, tester, progress_reporter):
if not len(r):
continue
- logger.info('Found %s tags to backfill', len(r))
+ logger.info("Found %s tags to backfill", len(r))
for index, t in enumerate(r):
logger.info("Backfilling tag %s of %s", index, len(r))
backfill_tag(t)
diff --git a/data/migrations/versions/be8d1c402ce0_add_teamsync_table.py b/data/migrations/versions/be8d1c402ce0_add_teamsync_table.py
index 62c0aba44..52c829d89 100644
--- a/data/migrations/versions/be8d1c402ce0_add_teamsync_table.py
+++ b/data/migrations/versions/be8d1c402ce0_add_teamsync_table.py
@@ -7,46 +7,57 @@ Create Date: 2017-02-23 13:34:52.356812
"""
# revision identifiers, used by Alembic.
-revision = 'be8d1c402ce0'
-down_revision = 'a6c463dfb9fe'
+revision = "be8d1c402ce0"
+down_revision = "a6c463dfb9fe"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
import sqlalchemy as sa
from util.migrate import UTF8LongText
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
### commands auto generated by Alembic - please adjust! ###
- op.create_table('teamsync',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('team_id', sa.Integer(), nullable=False),
- sa.Column('transaction_id', sa.String(length=255), nullable=False),
- sa.Column('last_updated', sa.DateTime(), nullable=True),
- sa.Column('service_id', sa.Integer(), nullable=False),
- sa.Column('config', UTF8LongText(), nullable=False),
- sa.ForeignKeyConstraint(['service_id'], ['loginservice.id'], name=op.f('fk_teamsync_service_id_loginservice')),
- sa.ForeignKeyConstraint(['team_id'], ['team.id'], name=op.f('fk_teamsync_team_id_team')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_teamsync'))
+ op.create_table(
+ "teamsync",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("team_id", sa.Integer(), nullable=False),
+ sa.Column("transaction_id", sa.String(length=255), nullable=False),
+ sa.Column("last_updated", sa.DateTime(), nullable=True),
+ sa.Column("service_id", sa.Integer(), nullable=False),
+ sa.Column("config", UTF8LongText(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["service_id"],
+ ["loginservice.id"],
+ name=op.f("fk_teamsync_service_id_loginservice"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["team_id"], ["team.id"], name=op.f("fk_teamsync_team_id_team")
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_teamsync")),
)
- op.create_index('teamsync_last_updated', 'teamsync', ['last_updated'], unique=False)
- op.create_index('teamsync_service_id', 'teamsync', ['service_id'], unique=False)
- op.create_index('teamsync_team_id', 'teamsync', ['team_id'], unique=True)
+ op.create_index("teamsync_last_updated", "teamsync", ["last_updated"], unique=False)
+ op.create_index("teamsync_service_id", "teamsync", ["service_id"], unique=False)
+ op.create_index("teamsync_team_id", "teamsync", ["team_id"], unique=True)
### end Alembic commands ###
# ### population of test data ### #
- tester.populate_table('teamsync', [
- ('team_id', tester.TestDataType.Foreign('team')),
- ('transaction_id', tester.TestDataType.String),
- ('last_updated', tester.TestDataType.DateTime),
- ('service_id', tester.TestDataType.Foreign('loginservice')),
- ('config', tester.TestDataType.JSON),
- ])
+ tester.populate_table(
+ "teamsync",
+ [
+ ("team_id", tester.TestDataType.Foreign("team")),
+ ("transaction_id", tester.TestDataType.String),
+ ("last_updated", tester.TestDataType.DateTime),
+ ("service_id", tester.TestDataType.Foreign("loginservice")),
+ ("config", tester.TestDataType.JSON),
+ ],
+ )
# ### end population of test data ### #
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
### commands auto generated by Alembic - please adjust! ###
- op.drop_table('teamsync')
+ op.drop_table("teamsync")
### end Alembic commands ###
diff --git a/data/migrations/versions/c00a1f15968b_add_schema2_media_types.py b/data/migrations/versions/c00a1f15968b_add_schema2_media_types.py
index 2d2a050df..71e2b6d3b 100644
--- a/data/migrations/versions/c00a1f15968b_add_schema2_media_types.py
+++ b/data/migrations/versions/c00a1f15968b_add_schema2_media_types.py
@@ -9,26 +9,24 @@ Create Date: 2018-11-13 09:20:21.968503
"""
# revision identifiers, used by Alembic.
-revision = 'c00a1f15968b'
-down_revision = '67f0abd172ae'
+revision = "c00a1f15968b"
+down_revision = "67f0abd172ae"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
for media_type in DOCKER_SCHEMA2_CONTENT_TYPES:
- op.bulk_insert(tables.mediatype,
- [
- {'name': media_type},
- ])
+ op.bulk_insert(tables.mediatype, [{"name": media_type}])
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
for media_type in DOCKER_SCHEMA2_CONTENT_TYPES:
- op.execute(tables
- .mediatype
- .delete()
- .where(tables.
- mediatype.c.name == op.inline_literal(media_type)))
+ op.execute(
+ tables.mediatype.delete().where(
+ tables.mediatype.c.name == op.inline_literal(media_type)
+ )
+ )
diff --git a/data/migrations/versions/c059b952ed76_remove_unencrypted_fields_and_data.py b/data/migrations/versions/c059b952ed76_remove_unencrypted_fields_and_data.py
index 4854630bf..c348842aa 100644
--- a/data/migrations/versions/c059b952ed76_remove_unencrypted_fields_and_data.py
+++ b/data/migrations/versions/c059b952ed76_remove_unencrypted_fields_and_data.py
@@ -7,8 +7,8 @@ Create Date: 2019-08-19 16:31:00.952773
"""
# revision identifiers, used by Alembic.
-revision = 'c059b952ed76'
-down_revision = '703298a825c2'
+revision = "c059b952ed76"
+down_revision = "703298a825c2"
import uuid
@@ -22,25 +22,26 @@ from data.database import FederatedLogin, User, RobotAccountToken
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_index('oauthaccesstoken_refresh_token', table_name='oauthaccesstoken')
- op.drop_column(u'oauthaccesstoken', 'refresh_token')
+ op.drop_index("oauthaccesstoken_refresh_token", table_name="oauthaccesstoken")
+ op.drop_column(u"oauthaccesstoken", "refresh_token")
- op.drop_column('accesstoken', 'code')
+ op.drop_column("accesstoken", "code")
- op.drop_column('appspecificauthtoken', 'token_code')
+ op.drop_column("appspecificauthtoken", "token_code")
- op.drop_column('oauthaccesstoken', 'access_token')
- op.drop_column('oauthapplication', 'client_secret')
+ op.drop_column("oauthaccesstoken", "access_token")
+ op.drop_column("oauthapplication", "client_secret")
- op.drop_column('oauthauthorizationcode', 'code')
+ op.drop_column("oauthauthorizationcode", "code")
- op.drop_column('repositorybuildtrigger', 'private_key')
- op.drop_column('repositorybuildtrigger', 'auth_token')
+ op.drop_column("repositorybuildtrigger", "private_key")
+ op.drop_column("repositorybuildtrigger", "auth_token")
# ### end Alembic commands ###
# Overwrite all plaintext robot credentials.
from app import app
- if app.config.get('SETUP_COMPLETE', False) or tester.is_testing:
+
+ if app.config.get("SETUP_COMPLETE", False) or tester.is_testing:
while True:
try:
robot_account_token = RobotAccountToken.get(fully_migrated=False)
@@ -50,7 +51,7 @@ def upgrade(tables, tester, progress_reporter):
robot_account.save()
federated_login = FederatedLogin.get(user=robot_account)
- federated_login.service_ident = 'robot:%s' % robot_account.id
+ federated_login.service_ident = "robot:%s" % robot_account.id
federated_login.save()
robot_account_token.fully_migrated = True
@@ -62,23 +63,62 @@ def upgrade(tables, tester, progress_reporter):
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.add_column(u'oauthaccesstoken', sa.Column('refresh_token', sa.String(length=255), nullable=True))
- op.create_index('oauthaccesstoken_refresh_token', 'oauthaccesstoken', ['refresh_token'], unique=False)
+ op.add_column(
+ u"oauthaccesstoken",
+ sa.Column("refresh_token", sa.String(length=255), nullable=True),
+ )
+ op.create_index(
+ "oauthaccesstoken_refresh_token",
+ "oauthaccesstoken",
+ ["refresh_token"],
+ unique=False,
+ )
- op.add_column('repositorybuildtrigger', sa.Column('auth_token', sa.String(length=255), nullable=True))
- op.add_column('repositorybuildtrigger', sa.Column('private_key', sa.Text(), nullable=True))
+ op.add_column(
+ "repositorybuildtrigger",
+ sa.Column("auth_token", sa.String(length=255), nullable=True),
+ )
+ op.add_column(
+ "repositorybuildtrigger", sa.Column("private_key", sa.Text(), nullable=True)
+ )
- op.add_column('oauthauthorizationcode', sa.Column('code', sa.String(length=255), nullable=True))
- op.create_index('oauthauthorizationcode_code', 'oauthauthorizationcode', ['code'], unique=True)
+ op.add_column(
+ "oauthauthorizationcode",
+ sa.Column("code", sa.String(length=255), nullable=True),
+ )
+ op.create_index(
+ "oauthauthorizationcode_code", "oauthauthorizationcode", ["code"], unique=True
+ )
- op.add_column('oauthapplication', sa.Column('client_secret', sa.String(length=255), nullable=True))
- op.add_column('oauthaccesstoken', sa.Column('access_token', sa.String(length=255), nullable=True))
+ op.add_column(
+ "oauthapplication",
+ sa.Column("client_secret", sa.String(length=255), nullable=True),
+ )
+ op.add_column(
+ "oauthaccesstoken",
+ sa.Column("access_token", sa.String(length=255), nullable=True),
+ )
- op.create_index('oauthaccesstoken_access_token', 'oauthaccesstoken', ['access_token'], unique=False)
+ op.create_index(
+ "oauthaccesstoken_access_token",
+ "oauthaccesstoken",
+ ["access_token"],
+ unique=False,
+ )
- op.add_column('appspecificauthtoken', sa.Column('token_code', sa.String(length=255), nullable=True))
- op.create_index('appspecificauthtoken_token_code', 'appspecificauthtoken', ['token_code'], unique=True)
+ op.add_column(
+ "appspecificauthtoken",
+ sa.Column("token_code", sa.String(length=255), nullable=True),
+ )
+ op.create_index(
+ "appspecificauthtoken_token_code",
+ "appspecificauthtoken",
+ ["token_code"],
+ unique=True,
+ )
- op.add_column('accesstoken', sa.Column('code', sa.String(length=255), nullable=True))
- op.create_index('accesstoken_code', 'accesstoken', ['code'], unique=True)
+ op.add_column(
+ "accesstoken", sa.Column("code", sa.String(length=255), nullable=True)
+ )
+ op.create_index("accesstoken_code", "accesstoken", ["code"], unique=True)
# ### end Alembic commands ###
diff --git a/data/migrations/versions/c13c8052f7a6_add_new_fields_and_tables_for_encrypted_.py b/data/migrations/versions/c13c8052f7a6_add_new_fields_and_tables_for_encrypted_.py
index 15ecabd00..2e7d5c8d3 100644
--- a/data/migrations/versions/c13c8052f7a6_add_new_fields_and_tables_for_encrypted_.py
+++ b/data/migrations/versions/c13c8052f7a6_add_new_fields_and_tables_for_encrypted_.py
@@ -7,98 +7,176 @@ Create Date: 2019-08-19 15:59:36.269155
"""
# revision identifiers, used by Alembic.
-revision = 'c13c8052f7a6'
-down_revision = '5248ddf35167'
+revision = "c13c8052f7a6"
+down_revision = "5248ddf35167"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
import sqlalchemy as sa
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.create_table('robotaccounttoken',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('robot_account_id', sa.Integer(), nullable=False),
- sa.Column('token', sa.String(length=255), nullable=False),
- sa.Column('fully_migrated', sa.Boolean(), nullable=False, server_default='0'),
- sa.ForeignKeyConstraint(['robot_account_id'], ['user.id'], name=op.f('fk_robotaccounttoken_robot_account_id_user')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_robotaccounttoken'))
+ op.create_table(
+ "robotaccounttoken",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("robot_account_id", sa.Integer(), nullable=False),
+ sa.Column("token", sa.String(length=255), nullable=False),
+ sa.Column("fully_migrated", sa.Boolean(), nullable=False, server_default="0"),
+ sa.ForeignKeyConstraint(
+ ["robot_account_id"],
+ ["user.id"],
+ name=op.f("fk_robotaccounttoken_robot_account_id_user"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_robotaccounttoken")),
+ )
+ op.create_index(
+ "robotaccounttoken_robot_account_id",
+ "robotaccounttoken",
+ ["robot_account_id"],
+ unique=True,
)
- op.create_index('robotaccounttoken_robot_account_id', 'robotaccounttoken', ['robot_account_id'], unique=True)
- op.add_column(u'accesstoken', sa.Column('token_code', sa.String(length=255), nullable=True))
- op.add_column(u'accesstoken', sa.Column('token_name', sa.String(length=255), nullable=True))
- op.create_index('accesstoken_token_name', 'accesstoken', ['token_name'], unique=True)
+ op.add_column(
+ u"accesstoken", sa.Column("token_code", sa.String(length=255), nullable=True)
+ )
+ op.add_column(
+ u"accesstoken", sa.Column("token_name", sa.String(length=255), nullable=True)
+ )
+ op.create_index(
+ "accesstoken_token_name", "accesstoken", ["token_name"], unique=True
+ )
- op.add_column(u'appspecificauthtoken', sa.Column('token_name', sa.String(length=255), nullable=True))
- op.add_column(u'appspecificauthtoken', sa.Column('token_secret', sa.String(length=255), nullable=True))
- op.create_index('appspecificauthtoken_token_name', 'appspecificauthtoken', ['token_name'], unique=True)
+ op.add_column(
+ u"appspecificauthtoken",
+ sa.Column("token_name", sa.String(length=255), nullable=True),
+ )
+ op.add_column(
+ u"appspecificauthtoken",
+ sa.Column("token_secret", sa.String(length=255), nullable=True),
+ )
+ op.create_index(
+ "appspecificauthtoken_token_name",
+ "appspecificauthtoken",
+ ["token_name"],
+ unique=True,
+ )
- op.add_column(u'emailconfirmation', sa.Column('verification_code', sa.String(length=255), nullable=True))
+ op.add_column(
+ u"emailconfirmation",
+ sa.Column("verification_code", sa.String(length=255), nullable=True),
+ )
- op.add_column(u'oauthaccesstoken', sa.Column('token_code', sa.String(length=255), nullable=True))
- op.add_column(u'oauthaccesstoken', sa.Column('token_name', sa.String(length=255), nullable=True))
- op.create_index('oauthaccesstoken_token_name', 'oauthaccesstoken', ['token_name'], unique=True)
+ op.add_column(
+ u"oauthaccesstoken",
+ sa.Column("token_code", sa.String(length=255), nullable=True),
+ )
+ op.add_column(
+ u"oauthaccesstoken",
+ sa.Column("token_name", sa.String(length=255), nullable=True),
+ )
+ op.create_index(
+ "oauthaccesstoken_token_name", "oauthaccesstoken", ["token_name"], unique=True
+ )
- op.add_column(u'oauthapplication', sa.Column('secure_client_secret', sa.String(length=255), nullable=True))
- op.add_column(u'oauthapplication', sa.Column('fully_migrated', sa.Boolean(), server_default='0', nullable=False))
+ op.add_column(
+ u"oauthapplication",
+ sa.Column("secure_client_secret", sa.String(length=255), nullable=True),
+ )
+ op.add_column(
+ u"oauthapplication",
+ sa.Column("fully_migrated", sa.Boolean(), server_default="0", nullable=False),
+ )
- op.add_column(u'oauthauthorizationcode', sa.Column('code_credential', sa.String(length=255), nullable=True))
- op.add_column(u'oauthauthorizationcode', sa.Column('code_name', sa.String(length=255), nullable=True))
- op.create_index('oauthauthorizationcode_code_name', 'oauthauthorizationcode', ['code_name'], unique=True)
- op.drop_index('oauthauthorizationcode_code', table_name='oauthauthorizationcode')
- op.create_index('oauthauthorizationcode_code', 'oauthauthorizationcode', ['code'], unique=True)
+ op.add_column(
+ u"oauthauthorizationcode",
+ sa.Column("code_credential", sa.String(length=255), nullable=True),
+ )
+ op.add_column(
+ u"oauthauthorizationcode",
+ sa.Column("code_name", sa.String(length=255), nullable=True),
+ )
+ op.create_index(
+ "oauthauthorizationcode_code_name",
+ "oauthauthorizationcode",
+ ["code_name"],
+ unique=True,
+ )
+ op.drop_index("oauthauthorizationcode_code", table_name="oauthauthorizationcode")
+ op.create_index(
+ "oauthauthorizationcode_code", "oauthauthorizationcode", ["code"], unique=True
+ )
- op.add_column(u'repositorybuildtrigger', sa.Column('secure_auth_token', sa.String(length=255), nullable=True))
- op.add_column(u'repositorybuildtrigger', sa.Column('secure_private_key', sa.Text(), nullable=True))
- op.add_column(u'repositorybuildtrigger', sa.Column('fully_migrated', sa.Boolean(), server_default='0', nullable=False))
+ op.add_column(
+ u"repositorybuildtrigger",
+ sa.Column("secure_auth_token", sa.String(length=255), nullable=True),
+ )
+ op.add_column(
+ u"repositorybuildtrigger",
+ sa.Column("secure_private_key", sa.Text(), nullable=True),
+ )
+ op.add_column(
+ u"repositorybuildtrigger",
+ sa.Column("fully_migrated", sa.Boolean(), server_default="0", nullable=False),
+ )
# ### end Alembic commands ###
# ### population of test data ### #
- tester.populate_table('robotaccounttoken', [
- ('robot_account_id', tester.TestDataType.Foreign('user')),
- ('token', tester.TestDataType.Token),
- ('fully_migrated', tester.TestDataType.Boolean),
- ])
-
- tester.populate_column('accesstoken', 'code', tester.TestDataType.Token)
+ tester.populate_table(
+ "robotaccounttoken",
+ [
+ ("robot_account_id", tester.TestDataType.Foreign("user")),
+ ("token", tester.TestDataType.Token),
+ ("fully_migrated", tester.TestDataType.Boolean),
+ ],
+ )
- tester.populate_column('appspecificauthtoken', 'token_code', tester.TestDataType.Token)
+ tester.populate_column("accesstoken", "code", tester.TestDataType.Token)
- tester.populate_column('emailconfirmation', 'verification_code', tester.TestDataType.Token)
+ tester.populate_column(
+ "appspecificauthtoken", "token_code", tester.TestDataType.Token
+ )
- tester.populate_column('oauthaccesstoken', 'token_code', tester.TestDataType.Token)
+ tester.populate_column(
+ "emailconfirmation", "verification_code", tester.TestDataType.Token
+ )
+
+ tester.populate_column("oauthaccesstoken", "token_code", tester.TestDataType.Token)
# ### end population of test data ### #
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_column(u'repositorybuildtrigger', 'secure_private_key')
- op.drop_column(u'repositorybuildtrigger', 'secure_auth_token')
+ op.drop_column(u"repositorybuildtrigger", "secure_private_key")
+ op.drop_column(u"repositorybuildtrigger", "secure_auth_token")
- op.drop_index('oauthauthorizationcode_code', table_name='oauthauthorizationcode')
- op.create_index('oauthauthorizationcode_code', 'oauthauthorizationcode', ['code'], unique=False)
- op.drop_index('oauthauthorizationcode_code_name', table_name='oauthauthorizationcode')
- op.drop_column(u'oauthauthorizationcode', 'code_name')
- op.drop_column(u'oauthauthorizationcode', 'code_credential')
+ op.drop_index("oauthauthorizationcode_code", table_name="oauthauthorizationcode")
+ op.create_index(
+ "oauthauthorizationcode_code", "oauthauthorizationcode", ["code"], unique=False
+ )
+ op.drop_index(
+ "oauthauthorizationcode_code_name", table_name="oauthauthorizationcode"
+ )
+ op.drop_column(u"oauthauthorizationcode", "code_name")
+ op.drop_column(u"oauthauthorizationcode", "code_credential")
- op.drop_column(u'oauthapplication', 'secure_client_secret')
+ op.drop_column(u"oauthapplication", "secure_client_secret")
- op.drop_index('oauthaccesstoken_token_name', table_name='oauthaccesstoken')
- op.drop_column(u'oauthaccesstoken', 'token_name')
- op.drop_column(u'oauthaccesstoken', 'token_code')
+ op.drop_index("oauthaccesstoken_token_name", table_name="oauthaccesstoken")
+ op.drop_column(u"oauthaccesstoken", "token_name")
+ op.drop_column(u"oauthaccesstoken", "token_code")
- op.drop_column(u'emailconfirmation', 'verification_code')
+ op.drop_column(u"emailconfirmation", "verification_code")
- op.drop_index('appspecificauthtoken_token_name', table_name='appspecificauthtoken')
- op.drop_column(u'appspecificauthtoken', 'token_secret')
- op.drop_column(u'appspecificauthtoken', 'token_name')
+ op.drop_index("appspecificauthtoken_token_name", table_name="appspecificauthtoken")
+ op.drop_column(u"appspecificauthtoken", "token_secret")
+ op.drop_column(u"appspecificauthtoken", "token_name")
- op.drop_index('accesstoken_token_name', table_name='accesstoken')
- op.drop_column(u'accesstoken', 'token_name')
- op.drop_column(u'accesstoken', 'token_code')
+ op.drop_index("accesstoken_token_name", table_name="accesstoken")
+ op.drop_column(u"accesstoken", "token_name")
+ op.drop_column(u"accesstoken", "token_code")
- op.drop_table('robotaccounttoken')
+ op.drop_table("robotaccounttoken")
# ### end Alembic commands ###
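Functionally, c13c8052f7a6 keeps upgrade() and downgrade() symmetric: every column and index added on the way up is dropped on the way down, in reverse order, and the new credential columns are nullable so existing rows remain valid until a later backfill fills them (some tables additionally gain a fully_migrated flag for that purpose). A simplified sketch of the pattern for a single table, using only operations already shown above; the real migrations take (tables, tester, progress_reporter) and wrap op in ProgressWrapper:

    import sqlalchemy as sa
    from alembic import op


    def upgrade():
        # Staging columns are nullable so the migration can run before any backfill.
        op.add_column(
            "accesstoken", sa.Column("token_name", sa.String(length=255), nullable=True)
        )
        op.add_column(
            "accesstoken", sa.Column("token_code", sa.String(length=255), nullable=True)
        )
        op.create_index(
            "accesstoken_token_name", "accesstoken", ["token_name"], unique=True
        )


    def downgrade():
        # Mirror image of upgrade(), applied in reverse order.
        op.drop_index("accesstoken_token_name", table_name="accesstoken")
        op.drop_column("accesstoken", "token_code")
        op.drop_column("accesstoken", "token_name")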
diff --git a/data/migrations/versions/c156deb8845d_reset_our_migrations_with_a_required_.py b/data/migrations/versions/c156deb8845d_reset_our_migrations_with_a_required_.py
index 3277f5ae6..0158dcb61 100644
--- a/data/migrations/versions/c156deb8845d_reset_our_migrations_with_a_required_.py
+++ b/data/migrations/versions/c156deb8845d_reset_our_migrations_with_a_required_.py
@@ -7,7 +7,7 @@ Create Date: 2016-11-08 11:58:11.110762
"""
# revision identifiers, used by Alembic.
-revision = 'c156deb8845d'
+revision = "c156deb8845d"
down_revision = None
from alembic import op as original_op
@@ -16,1239 +16,2149 @@ import sqlalchemy as sa
from util.migrate import UTF8LongText, UTF8CharField
from datetime import datetime
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
now = datetime.now().strftime("'%Y-%m-%d %H:%M:%S'")
- op.create_table('accesstokenkind',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_accesstokenkind'))
+ op.create_table(
+ "accesstokenkind",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_accesstokenkind")),
)
- op.create_index('accesstokenkind_name', 'accesstokenkind', ['name'], unique=True)
- op.create_table('buildtriggerservice',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_buildtriggerservice'))
+ op.create_index("accesstokenkind_name", "accesstokenkind", ["name"], unique=True)
+ op.create_table(
+ "buildtriggerservice",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_buildtriggerservice")),
)
- op.create_index('buildtriggerservice_name', 'buildtriggerservice', ['name'], unique=True)
- op.create_table('externalnotificationevent',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_externalnotificationevent'))
+ op.create_index(
+ "buildtriggerservice_name", "buildtriggerservice", ["name"], unique=True
)
- op.create_index('externalnotificationevent_name', 'externalnotificationevent', ['name'], unique=True)
- op.create_table('externalnotificationmethod',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_externalnotificationmethod'))
+ op.create_table(
+ "externalnotificationevent",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_externalnotificationevent")),
)
- op.create_index('externalnotificationmethod_name', 'externalnotificationmethod', ['name'], unique=True)
- op.create_table('imagestorage',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('uuid', sa.String(length=255), nullable=False),
- sa.Column('checksum', sa.String(length=255), nullable=True),
- sa.Column('image_size', sa.BigInteger(), nullable=True),
- sa.Column('uncompressed_size', sa.BigInteger(), nullable=True),
- sa.Column('uploading', sa.Boolean(), nullable=True),
- sa.Column('cas_path', sa.Boolean(), nullable=False, server_default=sa.sql.expression.false()),
- sa.Column('content_checksum', sa.String(length=255), nullable=True),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_imagestorage'))
+ op.create_index(
+ "externalnotificationevent_name",
+ "externalnotificationevent",
+ ["name"],
+ unique=True,
)
- op.create_index('imagestorage_content_checksum', 'imagestorage', ['content_checksum'], unique=False)
- op.create_index('imagestorage_uuid', 'imagestorage', ['uuid'], unique=True)
- op.create_table('imagestoragelocation',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_imagestoragelocation'))
+ op.create_table(
+ "externalnotificationmethod",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_externalnotificationmethod")),
)
- op.create_index('imagestoragelocation_name', 'imagestoragelocation', ['name'], unique=True)
- op.create_table('imagestoragesignaturekind',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_imagestoragesignaturekind'))
+ op.create_index(
+ "externalnotificationmethod_name",
+ "externalnotificationmethod",
+ ["name"],
+ unique=True,
)
- op.create_index('imagestoragesignaturekind_name', 'imagestoragesignaturekind', ['name'], unique=True)
- op.create_table('imagestoragetransformation',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_imagestoragetransformation'))
+ op.create_table(
+ "imagestorage",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("uuid", sa.String(length=255), nullable=False),
+ sa.Column("checksum", sa.String(length=255), nullable=True),
+ sa.Column("image_size", sa.BigInteger(), nullable=True),
+ sa.Column("uncompressed_size", sa.BigInteger(), nullable=True),
+ sa.Column("uploading", sa.Boolean(), nullable=True),
+ sa.Column(
+ "cas_path",
+ sa.Boolean(),
+ nullable=False,
+ server_default=sa.sql.expression.false(),
+ ),
+ sa.Column("content_checksum", sa.String(length=255), nullable=True),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_imagestorage")),
)
- op.create_index('imagestoragetransformation_name', 'imagestoragetransformation', ['name'], unique=True)
- op.create_table('labelsourcetype',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.Column('mutable', sa.Boolean(), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_labelsourcetype'))
+ op.create_index(
+ "imagestorage_content_checksum",
+ "imagestorage",
+ ["content_checksum"],
+ unique=False,
)
- op.create_index('labelsourcetype_name', 'labelsourcetype', ['name'], unique=True)
- op.create_table('logentrykind',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_logentrykind'))
+ op.create_index("imagestorage_uuid", "imagestorage", ["uuid"], unique=True)
+ op.create_table(
+ "imagestoragelocation",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_imagestoragelocation")),
)
- op.create_index('logentrykind_name', 'logentrykind', ['name'], unique=True)
- op.create_table('loginservice',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_loginservice'))
+ op.create_index(
+ "imagestoragelocation_name", "imagestoragelocation", ["name"], unique=True
)
- op.create_index('loginservice_name', 'loginservice', ['name'], unique=True)
- op.create_table('mediatype',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_mediatype'))
+ op.create_table(
+ "imagestoragesignaturekind",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_imagestoragesignaturekind")),
)
- op.create_index('mediatype_name', 'mediatype', ['name'], unique=True)
- op.create_table('messages',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('content', sa.Text(), nullable=False),
- sa.Column('uuid', sa.String(length=36), nullable=True),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_messages'))
+ op.create_index(
+ "imagestoragesignaturekind_name",
+ "imagestoragesignaturekind",
+ ["name"],
+ unique=True,
)
- op.create_table('notificationkind',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_notificationkind'))
+ op.create_table(
+ "imagestoragetransformation",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_imagestoragetransformation")),
)
- op.create_index('notificationkind_name', 'notificationkind', ['name'], unique=True)
- op.create_table('quayregion',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_quayregion'))
+ op.create_index(
+ "imagestoragetransformation_name",
+ "imagestoragetransformation",
+ ["name"],
+ unique=True,
)
- op.create_index('quayregion_name', 'quayregion', ['name'], unique=True)
- op.create_table('quayservice',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_quayservice'))
+ op.create_table(
+ "labelsourcetype",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.Column("mutable", sa.Boolean(), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_labelsourcetype")),
)
- op.create_index('quayservice_name', 'quayservice', ['name'], unique=True)
- op.create_table('queueitem',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('queue_name', sa.String(length=1024), nullable=False),
- sa.Column('body', sa.Text(), nullable=False),
- sa.Column('available_after', sa.DateTime(), nullable=False),
- sa.Column('available', sa.Boolean(), nullable=False),
- sa.Column('processing_expires', sa.DateTime(), nullable=True),
- sa.Column('retries_remaining', sa.Integer(), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_queueitem'))
+ op.create_index("labelsourcetype_name", "labelsourcetype", ["name"], unique=True)
+ op.create_table(
+ "logentrykind",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_logentrykind")),
)
- op.create_index('queueitem_available', 'queueitem', ['available'], unique=False)
- op.create_index('queueitem_available_after', 'queueitem', ['available_after'], unique=False)
- op.create_index('queueitem_processing_expires', 'queueitem', ['processing_expires'], unique=False)
- op.create_index('queueitem_queue_name', 'queueitem', ['queue_name'], unique=False, mysql_length=767)
- op.create_index('queueitem_retries_remaining', 'queueitem', ['retries_remaining'], unique=False)
- op.create_table('role',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_role'))
+ op.create_index("logentrykind_name", "logentrykind", ["name"], unique=True)
+ op.create_table(
+ "loginservice",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_loginservice")),
)
- op.create_index('role_name', 'role', ['name'], unique=True)
- op.create_table('servicekeyapproval',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('approver_id', sa.Integer(), nullable=True),
- sa.Column('approval_type', sa.String(length=255), nullable=False),
- sa.Column('approved_date', sa.DateTime(), nullable=False),
- sa.Column('notes', UTF8LongText(), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_servicekeyapproval'))
+ op.create_index("loginservice_name", "loginservice", ["name"], unique=True)
+ op.create_table(
+ "mediatype",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_mediatype")),
)
- op.create_index('servicekeyapproval_approval_type', 'servicekeyapproval', ['approval_type'], unique=False)
- op.create_index('servicekeyapproval_approver_id', 'servicekeyapproval', ['approver_id'], unique=False)
- op.create_table('teamrole',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_teamrole'))
+ op.create_index("mediatype_name", "mediatype", ["name"], unique=True)
+ op.create_table(
+ "messages",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("content", sa.Text(), nullable=False),
+ sa.Column("uuid", sa.String(length=36), nullable=True),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_messages")),
)
- op.create_index('teamrole_name', 'teamrole', ['name'], unique=False)
- op.create_table('user',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('uuid', sa.String(length=36), nullable=True),
- sa.Column('username', sa.String(length=255), nullable=False),
- sa.Column('password_hash', sa.String(length=255), nullable=True),
- sa.Column('email', sa.String(length=255), nullable=False),
- sa.Column('verified', sa.Boolean(), nullable=False),
- sa.Column('stripe_id', sa.String(length=255), nullable=True),
- sa.Column('organization', sa.Boolean(), nullable=False),
- sa.Column('robot', sa.Boolean(), nullable=False),
- sa.Column('invoice_email', sa.Boolean(), nullable=False),
- sa.Column('invalid_login_attempts', sa.Integer(), nullable=False, server_default='0'),
- sa.Column('last_invalid_login', sa.DateTime(), nullable=False),
- sa.Column('removed_tag_expiration_s', sa.Integer(), nullable=False, server_default='1209600'),
- sa.Column('enabled', sa.Boolean(), nullable=False, server_default=sa.sql.expression.true()),
- sa.Column('invoice_email_address', sa.String(length=255), nullable=True),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_user'))
+ op.create_table(
+ "notificationkind",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_notificationkind")),
)
- op.create_index('user_email', 'user', ['email'], unique=True)
- op.create_index('user_invoice_email_address', 'user', ['invoice_email_address'], unique=False)
- op.create_index('user_organization', 'user', ['organization'], unique=False)
- op.create_index('user_robot', 'user', ['robot'], unique=False)
- op.create_index('user_stripe_id', 'user', ['stripe_id'], unique=False)
- op.create_index('user_username', 'user', ['username'], unique=True)
- op.create_table('visibility',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_visibility'))
+ op.create_index("notificationkind_name", "notificationkind", ["name"], unique=True)
+ op.create_table(
+ "quayregion",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_quayregion")),
)
- op.create_index('visibility_name', 'visibility', ['name'], unique=True)
- op.create_table('emailconfirmation',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('code', sa.String(length=255), nullable=False),
- sa.Column('user_id', sa.Integer(), nullable=False),
- sa.Column('pw_reset', sa.Boolean(), nullable=False),
- sa.Column('new_email', sa.String(length=255), nullable=True),
- sa.Column('email_confirm', sa.Boolean(), nullable=False),
- sa.Column('created', sa.DateTime(), nullable=False),
- sa.ForeignKeyConstraint(['user_id'], ['user.id'], name=op.f('fk_emailconfirmation_user_id_user')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_emailconfirmation'))
+ op.create_index("quayregion_name", "quayregion", ["name"], unique=True)
+ op.create_table(
+ "quayservice",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_quayservice")),
)
- op.create_index('emailconfirmation_code', 'emailconfirmation', ['code'], unique=True)
- op.create_index('emailconfirmation_user_id', 'emailconfirmation', ['user_id'], unique=False)
- op.create_table('federatedlogin',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('user_id', sa.Integer(), nullable=False),
- sa.Column('service_id', sa.Integer(), nullable=False),
- sa.Column('service_ident', sa.String(length=255), nullable=False),
- sa.Column('metadata_json', sa.Text(), nullable=False),
- sa.ForeignKeyConstraint(['service_id'], ['loginservice.id'], name=op.f('fk_federatedlogin_service_id_loginservice')),
- sa.ForeignKeyConstraint(['user_id'], ['user.id'], name=op.f('fk_federatedlogin_user_id_user')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_federatedlogin'))
+ op.create_index("quayservice_name", "quayservice", ["name"], unique=True)
+ op.create_table(
+ "queueitem",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("queue_name", sa.String(length=1024), nullable=False),
+ sa.Column("body", sa.Text(), nullable=False),
+ sa.Column("available_after", sa.DateTime(), nullable=False),
+ sa.Column("available", sa.Boolean(), nullable=False),
+ sa.Column("processing_expires", sa.DateTime(), nullable=True),
+ sa.Column("retries_remaining", sa.Integer(), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_queueitem")),
)
- op.create_index('federatedlogin_service_id', 'federatedlogin', ['service_id'], unique=False)
- op.create_index('federatedlogin_service_id_service_ident', 'federatedlogin', ['service_id', 'service_ident'], unique=True)
- op.create_index('federatedlogin_service_id_user_id', 'federatedlogin', ['service_id', 'user_id'], unique=True)
- op.create_index('federatedlogin_user_id', 'federatedlogin', ['user_id'], unique=False)
- op.create_table('imagestorageplacement',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('storage_id', sa.Integer(), nullable=False),
- sa.Column('location_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['location_id'], ['imagestoragelocation.id'], name=op.f('fk_imagestorageplacement_location_id_imagestoragelocation')),
- sa.ForeignKeyConstraint(['storage_id'], ['imagestorage.id'], name=op.f('fk_imagestorageplacement_storage_id_imagestorage')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_imagestorageplacement'))
+ op.create_index("queueitem_available", "queueitem", ["available"], unique=False)
+ op.create_index(
+ "queueitem_available_after", "queueitem", ["available_after"], unique=False
)
- op.create_index('imagestorageplacement_location_id', 'imagestorageplacement', ['location_id'], unique=False)
- op.create_index('imagestorageplacement_storage_id', 'imagestorageplacement', ['storage_id'], unique=False)
- op.create_index('imagestorageplacement_storage_id_location_id', 'imagestorageplacement', ['storage_id', 'location_id'], unique=True)
- op.create_table('imagestoragesignature',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('storage_id', sa.Integer(), nullable=False),
- sa.Column('kind_id', sa.Integer(), nullable=False),
- sa.Column('signature', sa.Text(), nullable=True),
- sa.Column('uploading', sa.Boolean(), nullable=True),
- sa.ForeignKeyConstraint(['kind_id'], ['imagestoragesignaturekind.id'], name=op.f('fk_imagestoragesignature_kind_id_imagestoragesignaturekind')),
- sa.ForeignKeyConstraint(['storage_id'], ['imagestorage.id'], name=op.f('fk_imagestoragesignature_storage_id_imagestorage')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_imagestoragesignature'))
+ op.create_index(
+ "queueitem_processing_expires",
+ "queueitem",
+ ["processing_expires"],
+ unique=False,
)
- op.create_index('imagestoragesignature_kind_id', 'imagestoragesignature', ['kind_id'], unique=False)
- op.create_index('imagestoragesignature_kind_id_storage_id', 'imagestoragesignature', ['kind_id', 'storage_id'], unique=True)
- op.create_index('imagestoragesignature_storage_id', 'imagestoragesignature', ['storage_id'], unique=False)
- op.create_table('label',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('uuid', sa.String(length=255), nullable=False),
- sa.Column('key', UTF8CharField(length=255), nullable=False),
- sa.Column('value', UTF8LongText(), nullable=False),
- sa.Column('media_type_id', sa.Integer(), nullable=False),
- sa.Column('source_type_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['media_type_id'], ['mediatype.id'], name=op.f('fk_label_media_type_id_mediatype')),
- sa.ForeignKeyConstraint(['source_type_id'], ['labelsourcetype.id'], name=op.f('fk_label_source_type_id_labelsourcetype')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_label'))
+ op.create_index(
+ "queueitem_queue_name",
+ "queueitem",
+ ["queue_name"],
+ unique=False,
+ mysql_length=767,
)
- op.create_index('label_key', 'label', ['key'], unique=False)
- op.create_index('label_media_type_id', 'label', ['media_type_id'], unique=False)
- op.create_index('label_source_type_id', 'label', ['source_type_id'], unique=False)
- op.create_index('label_uuid', 'label', ['uuid'], unique=True)
- op.create_table('logentry',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('kind_id', sa.Integer(), nullable=False),
- sa.Column('account_id', sa.Integer(), nullable=False),
- sa.Column('performer_id', sa.Integer(), nullable=True),
- sa.Column('repository_id', sa.Integer(), nullable=True),
- sa.Column('datetime', sa.DateTime(), nullable=False),
- sa.Column('ip', sa.String(length=255), nullable=True),
- sa.Column('metadata_json', sa.Text(), nullable=False),
- sa.ForeignKeyConstraint(['kind_id'], ['logentrykind.id'], name=op.f('fk_logentry_kind_id_logentrykind')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_logentry'))
+ op.create_index(
+ "queueitem_retries_remaining", "queueitem", ["retries_remaining"], unique=False
)
- op.create_index('logentry_account_id', 'logentry', ['account_id'], unique=False)
- op.create_index('logentry_account_id_datetime', 'logentry', ['account_id', 'datetime'], unique=False)
- op.create_index('logentry_datetime', 'logentry', ['datetime'], unique=False)
- op.create_index('logentry_kind_id', 'logentry', ['kind_id'], unique=False)
- op.create_index('logentry_performer_id', 'logentry', ['performer_id'], unique=False)
- op.create_index('logentry_performer_id_datetime', 'logentry', ['performer_id', 'datetime'], unique=False)
- op.create_index('logentry_repository_id', 'logentry', ['repository_id'], unique=False)
- op.create_index('logentry_repository_id_datetime', 'logentry', ['repository_id', 'datetime'], unique=False)
- op.create_index('logentry_repository_id_datetime_kind_id', 'logentry', ['repository_id', 'datetime', 'kind_id'], unique=False)
- op.create_table('notification',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('uuid', sa.String(length=255), nullable=False),
- sa.Column('kind_id', sa.Integer(), nullable=False),
- sa.Column('target_id', sa.Integer(), nullable=False),
- sa.Column('metadata_json', sa.Text(), nullable=False),
- sa.Column('created', sa.DateTime(), nullable=False),
- sa.Column('dismissed', sa.Boolean(), nullable=False),
- sa.Column('lookup_path', sa.String(length=255), nullable=True),
- sa.ForeignKeyConstraint(['kind_id'], ['notificationkind.id'], name=op.f('fk_notification_kind_id_notificationkind')),
- sa.ForeignKeyConstraint(['target_id'], ['user.id'], name=op.f('fk_notification_target_id_user')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_notification'))
+ op.create_table(
+ "role",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_role")),
)
- op.create_index('notification_created', 'notification', ['created'], unique=False)
- op.create_index('notification_kind_id', 'notification', ['kind_id'], unique=False)
- op.create_index('notification_lookup_path', 'notification', ['lookup_path'], unique=False)
- op.create_index('notification_target_id', 'notification', ['target_id'], unique=False)
- op.create_index('notification_uuid', 'notification', ['uuid'], unique=False)
- op.create_table('oauthapplication',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('client_id', sa.String(length=255), nullable=False),
- sa.Column('client_secret', sa.String(length=255), nullable=False),
- sa.Column('redirect_uri', sa.String(length=255), nullable=False),
- sa.Column('application_uri', sa.String(length=255), nullable=False),
- sa.Column('organization_id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.Column('description', sa.Text(), nullable=False),
- sa.Column('gravatar_email', sa.String(length=255), nullable=True),
- sa.ForeignKeyConstraint(['organization_id'], ['user.id'], name=op.f('fk_oauthapplication_organization_id_user')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_oauthapplication'))
+ op.create_index("role_name", "role", ["name"], unique=True)
+ op.create_table(
+ "servicekeyapproval",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("approver_id", sa.Integer(), nullable=True),
+ sa.Column("approval_type", sa.String(length=255), nullable=False),
+ sa.Column("approved_date", sa.DateTime(), nullable=False),
+ sa.Column("notes", UTF8LongText(), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_servicekeyapproval")),
)
- op.create_index('oauthapplication_client_id', 'oauthapplication', ['client_id'], unique=False)
- op.create_index('oauthapplication_organization_id', 'oauthapplication', ['organization_id'], unique=False)
- op.create_table('quayrelease',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('service_id', sa.Integer(), nullable=False),
- sa.Column('version', sa.String(length=255), nullable=False),
- sa.Column('region_id', sa.Integer(), nullable=False),
- sa.Column('reverted', sa.Boolean(), nullable=False),
- sa.Column('created', sa.DateTime(), nullable=False),
- sa.ForeignKeyConstraint(['region_id'], ['quayregion.id'], name=op.f('fk_quayrelease_region_id_quayregion')),
- sa.ForeignKeyConstraint(['service_id'], ['quayservice.id'], name=op.f('fk_quayrelease_service_id_quayservice')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_quayrelease'))
+ op.create_index(
+ "servicekeyapproval_approval_type",
+ "servicekeyapproval",
+ ["approval_type"],
+ unique=False,
)
- op.create_index('quayrelease_created', 'quayrelease', ['created'], unique=False)
- op.create_index('quayrelease_region_id', 'quayrelease', ['region_id'], unique=False)
- op.create_index('quayrelease_service_id', 'quayrelease', ['service_id'], unique=False)
- op.create_index('quayrelease_service_id_region_id_created', 'quayrelease', ['service_id', 'region_id', 'created'], unique=False)
- op.create_index('quayrelease_service_id_version_region_id', 'quayrelease', ['service_id', 'version', 'region_id'], unique=True)
- op.create_table('repository',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('namespace_user_id', sa.Integer(), nullable=True),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.Column('visibility_id', sa.Integer(), nullable=False),
- sa.Column('description', sa.Text(), nullable=True),
- sa.Column('badge_token', sa.String(length=255), nullable=False),
- sa.ForeignKeyConstraint(['namespace_user_id'], ['user.id'], name=op.f('fk_repository_namespace_user_id_user')),
- sa.ForeignKeyConstraint(['visibility_id'], ['visibility.id'], name=op.f('fk_repository_visibility_id_visibility')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_repository'))
+ op.create_index(
+ "servicekeyapproval_approver_id",
+ "servicekeyapproval",
+ ["approver_id"],
+ unique=False,
)
- op.create_index('repository_namespace_user_id', 'repository', ['namespace_user_id'], unique=False)
- op.create_index('repository_namespace_user_id_name', 'repository', ['namespace_user_id', 'name'], unique=True)
- op.create_index('repository_visibility_id', 'repository', ['visibility_id'], unique=False)
- op.create_table('servicekey',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.Column('kid', sa.String(length=255), nullable=False),
- sa.Column('service', sa.String(length=255), nullable=False),
- sa.Column('jwk', UTF8LongText(), nullable=False),
- sa.Column('metadata', UTF8LongText(), nullable=False),
- sa.Column('created_date', sa.DateTime(), nullable=False),
- sa.Column('expiration_date', sa.DateTime(), nullable=True),
- sa.Column('rotation_duration', sa.Integer(), nullable=True),
- sa.Column('approval_id', sa.Integer(), nullable=True),
- sa.ForeignKeyConstraint(['approval_id'], ['servicekeyapproval.id'], name=op.f('fk_servicekey_approval_id_servicekeyapproval')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_servicekey'))
+ op.create_table(
+ "teamrole",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_teamrole")),
)
- op.create_index('servicekey_approval_id', 'servicekey', ['approval_id'], unique=False)
- op.create_index('servicekey_kid', 'servicekey', ['kid'], unique=True)
- op.create_index('servicekey_service', 'servicekey', ['service'], unique=False)
- op.create_table('team',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.Column('organization_id', sa.Integer(), nullable=False),
- sa.Column('role_id', sa.Integer(), nullable=False),
- sa.Column('description', sa.Text(), nullable=False),
- sa.ForeignKeyConstraint(['organization_id'], ['user.id'], name=op.f('fk_team_organization_id_user')),
- sa.ForeignKeyConstraint(['role_id'], ['teamrole.id'], name=op.f('fk_team_role_id_teamrole')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_team'))
+ op.create_index("teamrole_name", "teamrole", ["name"], unique=False)
+ op.create_table(
+ "user",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("uuid", sa.String(length=36), nullable=True),
+ sa.Column("username", sa.String(length=255), nullable=False),
+ sa.Column("password_hash", sa.String(length=255), nullable=True),
+ sa.Column("email", sa.String(length=255), nullable=False),
+ sa.Column("verified", sa.Boolean(), nullable=False),
+ sa.Column("stripe_id", sa.String(length=255), nullable=True),
+ sa.Column("organization", sa.Boolean(), nullable=False),
+ sa.Column("robot", sa.Boolean(), nullable=False),
+ sa.Column("invoice_email", sa.Boolean(), nullable=False),
+ sa.Column(
+ "invalid_login_attempts", sa.Integer(), nullable=False, server_default="0"
+ ),
+ sa.Column("last_invalid_login", sa.DateTime(), nullable=False),
+ sa.Column(
+ "removed_tag_expiration_s",
+ sa.Integer(),
+ nullable=False,
+ server_default="1209600",
+ ),
+ sa.Column(
+ "enabled",
+ sa.Boolean(),
+ nullable=False,
+ server_default=sa.sql.expression.true(),
+ ),
+ sa.Column("invoice_email_address", sa.String(length=255), nullable=True),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_user")),
)
- op.create_index('team_name', 'team', ['name'], unique=False)
- op.create_index('team_name_organization_id', 'team', ['name', 'organization_id'], unique=True)
- op.create_index('team_organization_id', 'team', ['organization_id'], unique=False)
- op.create_index('team_role_id', 'team', ['role_id'], unique=False)
- op.create_table('torrentinfo',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('storage_id', sa.Integer(), nullable=False),
- sa.Column('piece_length', sa.Integer(), nullable=False),
- sa.Column('pieces', sa.Text(), nullable=False),
- sa.ForeignKeyConstraint(['storage_id'], ['imagestorage.id'], name=op.f('fk_torrentinfo_storage_id_imagestorage')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_torrentinfo'))
+ op.create_index("user_email", "user", ["email"], unique=True)
+ op.create_index(
+ "user_invoice_email_address", "user", ["invoice_email_address"], unique=False
)
- op.create_index('torrentinfo_storage_id', 'torrentinfo', ['storage_id'], unique=False)
- op.create_index('torrentinfo_storage_id_piece_length', 'torrentinfo', ['storage_id', 'piece_length'], unique=True)
- op.create_table('userregion',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('user_id', sa.Integer(), nullable=False),
- sa.Column('location_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['location_id'], ['imagestoragelocation.id'], name=op.f('fk_userregion_location_id_imagestoragelocation')),
- sa.ForeignKeyConstraint(['user_id'], ['user.id'], name=op.f('fk_userregion_user_id_user')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_userregion'))
+ op.create_index("user_organization", "user", ["organization"], unique=False)
+ op.create_index("user_robot", "user", ["robot"], unique=False)
+ op.create_index("user_stripe_id", "user", ["stripe_id"], unique=False)
+ op.create_index("user_username", "user", ["username"], unique=True)
+ op.create_table(
+ "visibility",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_visibility")),
)
- op.create_index('userregion_location_id', 'userregion', ['location_id'], unique=False)
- op.create_index('userregion_user_id', 'userregion', ['user_id'], unique=False)
- op.create_table('accesstoken',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('friendly_name', sa.String(length=255), nullable=True),
- sa.Column('code', sa.String(length=255), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('created', sa.DateTime(), nullable=False),
- sa.Column('role_id', sa.Integer(), nullable=False),
- sa.Column('temporary', sa.Boolean(), nullable=False),
- sa.Column('kind_id', sa.Integer(), nullable=True),
- sa.ForeignKeyConstraint(['kind_id'], ['accesstokenkind.id'], name=op.f('fk_accesstoken_kind_id_accesstokenkind')),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_accesstoken_repository_id_repository')),
- sa.ForeignKeyConstraint(['role_id'], ['role.id'], name=op.f('fk_accesstoken_role_id_role')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_accesstoken'))
+ op.create_index("visibility_name", "visibility", ["name"], unique=True)
+ op.create_table(
+ "emailconfirmation",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("code", sa.String(length=255), nullable=False),
+ sa.Column("user_id", sa.Integer(), nullable=False),
+ sa.Column("pw_reset", sa.Boolean(), nullable=False),
+ sa.Column("new_email", sa.String(length=255), nullable=True),
+ sa.Column("email_confirm", sa.Boolean(), nullable=False),
+ sa.Column("created", sa.DateTime(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["user_id"], ["user.id"], name=op.f("fk_emailconfirmation_user_id_user")
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_emailconfirmation")),
)
- op.create_index('accesstoken_code', 'accesstoken', ['code'], unique=True)
- op.create_index('accesstoken_kind_id', 'accesstoken', ['kind_id'], unique=False)
- op.create_index('accesstoken_repository_id', 'accesstoken', ['repository_id'], unique=False)
- op.create_index('accesstoken_role_id', 'accesstoken', ['role_id'], unique=False)
- op.create_table('blobupload',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('uuid', sa.String(length=255), nullable=False),
- sa.Column('byte_count', sa.Integer(), nullable=False),
- sa.Column('sha_state', sa.Text(), nullable=True),
- sa.Column('location_id', sa.Integer(), nullable=False),
- sa.Column('storage_metadata', sa.Text(), nullable=True),
- sa.Column('chunk_count', sa.Integer(), nullable=False, server_default='0'),
- sa.Column('uncompressed_byte_count', sa.Integer(), nullable=True),
- sa.Column('created', sa.DateTime(), nullable=False, server_default=sa.text(now)),
- sa.Column('piece_sha_state', UTF8LongText(), nullable=True),
- sa.Column('piece_hashes', UTF8LongText(), nullable=True),
- sa.ForeignKeyConstraint(['location_id'], ['imagestoragelocation.id'], name=op.f('fk_blobupload_location_id_imagestoragelocation')),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_blobupload_repository_id_repository')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_blobupload'))
+ op.create_index(
+ "emailconfirmation_code", "emailconfirmation", ["code"], unique=True
)
- op.create_index('blobupload_created', 'blobupload', ['created'], unique=False)
- op.create_index('blobupload_location_id', 'blobupload', ['location_id'], unique=False)
- op.create_index('blobupload_repository_id', 'blobupload', ['repository_id'], unique=False)
- op.create_index('blobupload_repository_id_uuid', 'blobupload', ['repository_id', 'uuid'], unique=True)
- op.create_index('blobupload_uuid', 'blobupload', ['uuid'], unique=True)
- op.create_table('image',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('docker_image_id', sa.String(length=255), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('ancestors', sa.String(length=60535), nullable=True),
- sa.Column('storage_id', sa.Integer(), nullable=True),
- sa.Column('created', sa.DateTime(), nullable=True),
- sa.Column('comment', UTF8LongText(), nullable=True),
- sa.Column('command', sa.Text(), nullable=True),
- sa.Column('aggregate_size', sa.BigInteger(), nullable=True),
- sa.Column('v1_json_metadata', UTF8LongText(), nullable=True),
- sa.Column('v1_checksum', sa.String(length=255), nullable=True),
- sa.Column('security_indexed', sa.Boolean(), nullable=False, server_default=sa.sql.expression.false()),
- sa.Column('security_indexed_engine', sa.Integer(), nullable=False, server_default='-1'),
- sa.Column('parent_id', sa.Integer(), nullable=True),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_image_repository_id_repository')),
- sa.ForeignKeyConstraint(['storage_id'], ['imagestorage.id'], name=op.f('fk_image_storage_id_imagestorage')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_image'))
+ op.create_index(
+ "emailconfirmation_user_id", "emailconfirmation", ["user_id"], unique=False
)
- op.create_index('image_ancestors', 'image', ['ancestors'], unique=False, mysql_length=767)
- op.create_index('image_docker_image_id', 'image', ['docker_image_id'], unique=False)
- op.create_index('image_parent_id', 'image', ['parent_id'], unique=False)
- op.create_index('image_repository_id', 'image', ['repository_id'], unique=False)
- op.create_index('image_repository_id_docker_image_id', 'image', ['repository_id', 'docker_image_id'], unique=True)
- op.create_index('image_security_indexed', 'image', ['security_indexed'], unique=False)
- op.create_index('image_security_indexed_engine', 'image', ['security_indexed_engine'], unique=False)
- op.create_index('image_security_indexed_engine_security_indexed', 'image', ['security_indexed_engine', 'security_indexed'], unique=False)
- op.create_index('image_storage_id', 'image', ['storage_id'], unique=False)
- op.create_table('oauthaccesstoken',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('uuid', sa.String(length=255), nullable=False),
- sa.Column('application_id', sa.Integer(), nullable=False),
- sa.Column('authorized_user_id', sa.Integer(), nullable=False),
- sa.Column('scope', sa.String(length=255), nullable=False),
- sa.Column('access_token', sa.String(length=255), nullable=False),
- sa.Column('token_type', sa.String(length=255), nullable=False),
- sa.Column('expires_at', sa.DateTime(), nullable=False),
- sa.Column('refresh_token', sa.String(length=255), nullable=True),
- sa.Column('data', sa.Text(), nullable=False),
- sa.ForeignKeyConstraint(['application_id'], ['oauthapplication.id'], name=op.f('fk_oauthaccesstoken_application_id_oauthapplication')),
- sa.ForeignKeyConstraint(['authorized_user_id'], ['user.id'], name=op.f('fk_oauthaccesstoken_authorized_user_id_user')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_oauthaccesstoken'))
+ op.create_table(
+ "federatedlogin",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("user_id", sa.Integer(), nullable=False),
+ sa.Column("service_id", sa.Integer(), nullable=False),
+ sa.Column("service_ident", sa.String(length=255), nullable=False),
+ sa.Column("metadata_json", sa.Text(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["service_id"],
+ ["loginservice.id"],
+ name=op.f("fk_federatedlogin_service_id_loginservice"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["user_id"], ["user.id"], name=op.f("fk_federatedlogin_user_id_user")
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_federatedlogin")),
)
- op.create_index('oauthaccesstoken_access_token', 'oauthaccesstoken', ['access_token'], unique=False)
- op.create_index('oauthaccesstoken_application_id', 'oauthaccesstoken', ['application_id'], unique=False)
- op.create_index('oauthaccesstoken_authorized_user_id', 'oauthaccesstoken', ['authorized_user_id'], unique=False)
- op.create_index('oauthaccesstoken_refresh_token', 'oauthaccesstoken', ['refresh_token'], unique=False)
- op.create_index('oauthaccesstoken_uuid', 'oauthaccesstoken', ['uuid'], unique=False)
- op.create_table('oauthauthorizationcode',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('application_id', sa.Integer(), nullable=False),
- sa.Column('code', sa.String(length=255), nullable=False),
- sa.Column('scope', sa.String(length=255), nullable=False),
- sa.Column('data', sa.Text(), nullable=False),
- sa.ForeignKeyConstraint(['application_id'], ['oauthapplication.id'], name=op.f('fk_oauthauthorizationcode_application_id_oauthapplication')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_oauthauthorizationcode'))
+ op.create_index(
+ "federatedlogin_service_id", "federatedlogin", ["service_id"], unique=False
)
- op.create_index('oauthauthorizationcode_application_id', 'oauthauthorizationcode', ['application_id'], unique=False)
- op.create_index('oauthauthorizationcode_code', 'oauthauthorizationcode', ['code'], unique=False)
- op.create_table('permissionprototype',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('org_id', sa.Integer(), nullable=False),
- sa.Column('uuid', sa.String(length=255), nullable=False),
- sa.Column('activating_user_id', sa.Integer(), nullable=True),
- sa.Column('delegate_user_id', sa.Integer(), nullable=True),
- sa.Column('delegate_team_id', sa.Integer(), nullable=True),
- sa.Column('role_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['activating_user_id'], ['user.id'], name=op.f('fk_permissionprototype_activating_user_id_user')),
- sa.ForeignKeyConstraint(['delegate_team_id'], ['team.id'], name=op.f('fk_permissionprototype_delegate_team_id_team')),
- sa.ForeignKeyConstraint(['delegate_user_id'], ['user.id'], name=op.f('fk_permissionprototype_delegate_user_id_user')),
- sa.ForeignKeyConstraint(['org_id'], ['user.id'], name=op.f('fk_permissionprototype_org_id_user')),
- sa.ForeignKeyConstraint(['role_id'], ['role.id'], name=op.f('fk_permissionprototype_role_id_role')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_permissionprototype'))
+ op.create_index(
+ "federatedlogin_service_id_service_ident",
+ "federatedlogin",
+ ["service_id", "service_ident"],
+ unique=True,
)
- op.create_index('permissionprototype_activating_user_id', 'permissionprototype', ['activating_user_id'], unique=False)
- op.create_index('permissionprototype_delegate_team_id', 'permissionprototype', ['delegate_team_id'], unique=False)
- op.create_index('permissionprototype_delegate_user_id', 'permissionprototype', ['delegate_user_id'], unique=False)
- op.create_index('permissionprototype_org_id', 'permissionprototype', ['org_id'], unique=False)
- op.create_index('permissionprototype_org_id_activating_user_id', 'permissionprototype', ['org_id', 'activating_user_id'], unique=False)
- op.create_index('permissionprototype_role_id', 'permissionprototype', ['role_id'], unique=False)
- op.create_table('repositoryactioncount',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('count', sa.Integer(), nullable=False),
- sa.Column('date', sa.Date(), nullable=False),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_repositoryactioncount_repository_id_repository')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_repositoryactioncount'))
+ op.create_index(
+ "federatedlogin_service_id_user_id",
+ "federatedlogin",
+ ["service_id", "user_id"],
+ unique=True,
)
- op.create_index('repositoryactioncount_date', 'repositoryactioncount', ['date'], unique=False)
- op.create_index('repositoryactioncount_repository_id', 'repositoryactioncount', ['repository_id'], unique=False)
- op.create_index('repositoryactioncount_repository_id_date', 'repositoryactioncount', ['repository_id', 'date'], unique=True)
- op.create_table('repositoryauthorizedemail',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('email', sa.String(length=255), nullable=False),
- sa.Column('code', sa.String(length=255), nullable=False),
- sa.Column('confirmed', sa.Boolean(), nullable=False),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_repositoryauthorizedemail_repository_id_repository')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_repositoryauthorizedemail'))
+ op.create_index(
+ "federatedlogin_user_id", "federatedlogin", ["user_id"], unique=False
)
- op.create_index('repositoryauthorizedemail_code', 'repositoryauthorizedemail', ['code'], unique=True)
- op.create_index('repositoryauthorizedemail_email_repository_id', 'repositoryauthorizedemail', ['email', 'repository_id'], unique=True)
- op.create_index('repositoryauthorizedemail_repository_id', 'repositoryauthorizedemail', ['repository_id'], unique=False)
- op.create_table('repositorynotification',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('uuid', sa.String(length=255), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('event_id', sa.Integer(), nullable=False),
- sa.Column('method_id', sa.Integer(), nullable=False),
- sa.Column('title', sa.String(length=255), nullable=True),
- sa.Column('config_json', sa.Text(), nullable=False),
- sa.Column('event_config_json', UTF8LongText(), nullable=False),
- sa.ForeignKeyConstraint(['event_id'], ['externalnotificationevent.id'], name=op.f('fk_repositorynotification_event_id_externalnotificationevent')),
- sa.ForeignKeyConstraint(['method_id'], ['externalnotificationmethod.id'], name=op.f('fk_repositorynotification_method_id_externalnotificationmethod')),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_repositorynotification_repository_id_repository')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_repositorynotification'))
+ op.create_table(
+ "imagestorageplacement",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("storage_id", sa.Integer(), nullable=False),
+ sa.Column("location_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["location_id"],
+ ["imagestoragelocation.id"],
+ name=op.f("fk_imagestorageplacement_location_id_imagestoragelocation"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["storage_id"],
+ ["imagestorage.id"],
+ name=op.f("fk_imagestorageplacement_storage_id_imagestorage"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_imagestorageplacement")),
)
- op.create_index('repositorynotification_event_id', 'repositorynotification', ['event_id'], unique=False)
- op.create_index('repositorynotification_method_id', 'repositorynotification', ['method_id'], unique=False)
- op.create_index('repositorynotification_repository_id', 'repositorynotification', ['repository_id'], unique=False)
- op.create_index('repositorynotification_uuid', 'repositorynotification', ['uuid'], unique=False)
- op.create_table('repositorypermission',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('team_id', sa.Integer(), nullable=True),
- sa.Column('user_id', sa.Integer(), nullable=True),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('role_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_repositorypermission_repository_id_repository')),
- sa.ForeignKeyConstraint(['role_id'], ['role.id'], name=op.f('fk_repositorypermission_role_id_role')),
- sa.ForeignKeyConstraint(['team_id'], ['team.id'], name=op.f('fk_repositorypermission_team_id_team')),
- sa.ForeignKeyConstraint(['user_id'], ['user.id'], name=op.f('fk_repositorypermission_user_id_user')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_repositorypermission'))
+ op.create_index(
+ "imagestorageplacement_location_id",
+ "imagestorageplacement",
+ ["location_id"],
+ unique=False,
)
- op.create_index('repositorypermission_repository_id', 'repositorypermission', ['repository_id'], unique=False)
- op.create_index('repositorypermission_role_id', 'repositorypermission', ['role_id'], unique=False)
- op.create_index('repositorypermission_team_id', 'repositorypermission', ['team_id'], unique=False)
- op.create_index('repositorypermission_team_id_repository_id', 'repositorypermission', ['team_id', 'repository_id'], unique=True)
- op.create_index('repositorypermission_user_id', 'repositorypermission', ['user_id'], unique=False)
- op.create_index('repositorypermission_user_id_repository_id', 'repositorypermission', ['user_id', 'repository_id'], unique=True)
- op.create_table('star',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('user_id', sa.Integer(), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('created', sa.DateTime(), nullable=False),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_star_repository_id_repository')),
- sa.ForeignKeyConstraint(['user_id'], ['user.id'], name=op.f('fk_star_user_id_user')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_star'))
+ op.create_index(
+ "imagestorageplacement_storage_id",
+ "imagestorageplacement",
+ ["storage_id"],
+ unique=False,
)
- op.create_index('star_repository_id', 'star', ['repository_id'], unique=False)
- op.create_index('star_user_id', 'star', ['user_id'], unique=False)
- op.create_index('star_user_id_repository_id', 'star', ['user_id', 'repository_id'], unique=True)
- op.create_table('teammember',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('user_id', sa.Integer(), nullable=False),
- sa.Column('team_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['team_id'], ['team.id'], name=op.f('fk_teammember_team_id_team')),
- sa.ForeignKeyConstraint(['user_id'], ['user.id'], name=op.f('fk_teammember_user_id_user')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_teammember'))
+ op.create_index(
+ "imagestorageplacement_storage_id_location_id",
+ "imagestorageplacement",
+ ["storage_id", "location_id"],
+ unique=True,
)
- op.create_index('teammember_team_id', 'teammember', ['team_id'], unique=False)
- op.create_index('teammember_user_id', 'teammember', ['user_id'], unique=False)
- op.create_index('teammember_user_id_team_id', 'teammember', ['user_id', 'team_id'], unique=True)
- op.create_table('teammemberinvite',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('user_id', sa.Integer(), nullable=True),
- sa.Column('email', sa.String(length=255), nullable=True),
- sa.Column('team_id', sa.Integer(), nullable=False),
- sa.Column('inviter_id', sa.Integer(), nullable=False),
- sa.Column('invite_token', sa.String(length=255), nullable=False),
- sa.ForeignKeyConstraint(['inviter_id'], ['user.id'], name=op.f('fk_teammemberinvite_inviter_id_user')),
- sa.ForeignKeyConstraint(['team_id'], ['team.id'], name=op.f('fk_teammemberinvite_team_id_team')),
- sa.ForeignKeyConstraint(['user_id'], ['user.id'], name=op.f('fk_teammemberinvite_user_id_user')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_teammemberinvite'))
+ op.create_table(
+ "imagestoragesignature",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("storage_id", sa.Integer(), nullable=False),
+ sa.Column("kind_id", sa.Integer(), nullable=False),
+ sa.Column("signature", sa.Text(), nullable=True),
+ sa.Column("uploading", sa.Boolean(), nullable=True),
+ sa.ForeignKeyConstraint(
+ ["kind_id"],
+ ["imagestoragesignaturekind.id"],
+ name=op.f("fk_imagestoragesignature_kind_id_imagestoragesignaturekind"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["storage_id"],
+ ["imagestorage.id"],
+ name=op.f("fk_imagestoragesignature_storage_id_imagestorage"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_imagestoragesignature")),
)
- op.create_index('teammemberinvite_inviter_id', 'teammemberinvite', ['inviter_id'], unique=False)
- op.create_index('teammemberinvite_team_id', 'teammemberinvite', ['team_id'], unique=False)
- op.create_index('teammemberinvite_user_id', 'teammemberinvite', ['user_id'], unique=False)
- op.create_table('derivedstorageforimage',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('source_image_id', sa.Integer(), nullable=False),
- sa.Column('derivative_id', sa.Integer(), nullable=False),
- sa.Column('transformation_id', sa.Integer(), nullable=False),
- sa.Column('uniqueness_hash', sa.String(length=255), nullable=True),
- sa.ForeignKeyConstraint(['derivative_id'], ['imagestorage.id'], name=op.f('fk_derivedstorageforimage_derivative_id_imagestorage')),
- sa.ForeignKeyConstraint(['source_image_id'], ['image.id'], name=op.f('fk_derivedstorageforimage_source_image_id_image')),
- sa.ForeignKeyConstraint(['transformation_id'], ['imagestoragetransformation.id'], name=op.f('fk_derivedstorageforimage_transformation_constraint')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_derivedstorageforimage'))
+ op.create_index(
+ "imagestoragesignature_kind_id",
+ "imagestoragesignature",
+ ["kind_id"],
+ unique=False,
)
- op.create_index('derivedstorageforimage_derivative_id', 'derivedstorageforimage', ['derivative_id'], unique=False)
- op.create_index('derivedstorageforimage_source_image_id', 'derivedstorageforimage', ['source_image_id'], unique=False)
- op.create_index('uniqueness_index', 'derivedstorageforimage', ['source_image_id', 'transformation_id', 'uniqueness_hash'], unique=True)
- op.create_index('derivedstorageforimage_transformation_id', 'derivedstorageforimage', ['transformation_id'], unique=False)
- op.create_table('repositorybuildtrigger',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('uuid', sa.String(length=255), nullable=False),
- sa.Column('service_id', sa.Integer(), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('connected_user_id', sa.Integer(), nullable=False),
- sa.Column('auth_token', sa.String(length=255), nullable=True),
- sa.Column('private_key', sa.Text(), nullable=True),
- sa.Column('config', sa.Text(), nullable=False),
- sa.Column('write_token_id', sa.Integer(), nullable=True),
- sa.Column('pull_robot_id', sa.Integer(), nullable=True),
- sa.ForeignKeyConstraint(['connected_user_id'], ['user.id'], name=op.f('fk_repositorybuildtrigger_connected_user_id_user')),
- sa.ForeignKeyConstraint(['pull_robot_id'], ['user.id'], name=op.f('fk_repositorybuildtrigger_pull_robot_id_user')),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_repositorybuildtrigger_repository_id_repository')),
- sa.ForeignKeyConstraint(['service_id'], ['buildtriggerservice.id'], name=op.f('fk_repositorybuildtrigger_service_id_buildtriggerservice')),
- sa.ForeignKeyConstraint(['write_token_id'], ['accesstoken.id'], name=op.f('fk_repositorybuildtrigger_write_token_id_accesstoken')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_repositorybuildtrigger'))
+ op.create_index(
+ "imagestoragesignature_kind_id_storage_id",
+ "imagestoragesignature",
+ ["kind_id", "storage_id"],
+ unique=True,
)
- op.create_index('repositorybuildtrigger_connected_user_id', 'repositorybuildtrigger', ['connected_user_id'], unique=False)
- op.create_index('repositorybuildtrigger_pull_robot_id', 'repositorybuildtrigger', ['pull_robot_id'], unique=False)
- op.create_index('repositorybuildtrigger_repository_id', 'repositorybuildtrigger', ['repository_id'], unique=False)
- op.create_index('repositorybuildtrigger_service_id', 'repositorybuildtrigger', ['service_id'], unique=False)
- op.create_index('repositorybuildtrigger_write_token_id', 'repositorybuildtrigger', ['write_token_id'], unique=False)
- op.create_table('repositorytag',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.Column('image_id', sa.Integer(), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('lifetime_start_ts', sa.Integer(), nullable=False, server_default='0'),
- sa.Column('lifetime_end_ts', sa.Integer(), nullable=True),
- sa.Column('hidden', sa.Boolean(), nullable=False, server_default=sa.sql.expression.false()),
- sa.Column('reversion', sa.Boolean(), nullable=False, server_default=sa.sql.expression.false()),
- sa.ForeignKeyConstraint(['image_id'], ['image.id'], name=op.f('fk_repositorytag_image_id_image')),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_repositorytag_repository_id_repository')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_repositorytag'))
+ op.create_index(
+ "imagestoragesignature_storage_id",
+ "imagestoragesignature",
+ ["storage_id"],
+ unique=False,
)
- op.create_index('repositorytag_image_id', 'repositorytag', ['image_id'], unique=False)
- op.create_index('repositorytag_lifetime_end_ts', 'repositorytag', ['lifetime_end_ts'], unique=False)
- op.create_index('repositorytag_repository_id', 'repositorytag', ['repository_id'], unique=False)
- op.create_index('repositorytag_repository_id_name', 'repositorytag', ['repository_id', 'name'], unique=False)
- op.create_index('repositorytag_repository_id_name_lifetime_end_ts', 'repositorytag', ['repository_id', 'name', 'lifetime_end_ts'], unique=True)
- op.create_table('repositorybuild',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('uuid', sa.String(length=255), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('access_token_id', sa.Integer(), nullable=False),
- sa.Column('resource_key', sa.String(length=255), nullable=True),
- sa.Column('job_config', sa.Text(), nullable=False),
- sa.Column('phase', sa.String(length=255), nullable=False),
- sa.Column('started', sa.DateTime(), nullable=False),
- sa.Column('display_name', sa.String(length=255), nullable=False),
- sa.Column('trigger_id', sa.Integer(), nullable=True),
- sa.Column('pull_robot_id', sa.Integer(), nullable=True),
- sa.Column('logs_archived', sa.Boolean(), nullable=False, server_default=sa.sql.expression.false()),
- sa.Column('queue_id', sa.String(length=255), nullable=True),
- sa.ForeignKeyConstraint(['access_token_id'], ['accesstoken.id'], name=op.f('fk_repositorybuild_access_token_id_accesstoken')),
- sa.ForeignKeyConstraint(['pull_robot_id'], ['user.id'], name=op.f('fk_repositorybuild_pull_robot_id_user')),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_repositorybuild_repository_id_repository')),
- sa.ForeignKeyConstraint(['trigger_id'], ['repositorybuildtrigger.id'], name=op.f('fk_repositorybuild_trigger_id_repositorybuildtrigger')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_repositorybuild'))
+ op.create_table(
+ "label",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("uuid", sa.String(length=255), nullable=False),
+ sa.Column("key", UTF8CharField(length=255), nullable=False),
+ sa.Column("value", UTF8LongText(), nullable=False),
+ sa.Column("media_type_id", sa.Integer(), nullable=False),
+ sa.Column("source_type_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["media_type_id"],
+ ["mediatype.id"],
+ name=op.f("fk_label_media_type_id_mediatype"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["source_type_id"],
+ ["labelsourcetype.id"],
+ name=op.f("fk_label_source_type_id_labelsourcetype"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_label")),
)
- op.create_index('repositorybuild_access_token_id', 'repositorybuild', ['access_token_id'], unique=False)
- op.create_index('repositorybuild_pull_robot_id', 'repositorybuild', ['pull_robot_id'], unique=False)
- op.create_index('repositorybuild_queue_id', 'repositorybuild', ['queue_id'], unique=False)
- op.create_index('repositorybuild_repository_id', 'repositorybuild', ['repository_id'], unique=False)
- op.create_index('repositorybuild_repository_id_started_phase', 'repositorybuild', ['repository_id', 'started', 'phase'], unique=False)
- op.create_index('repositorybuild_resource_key', 'repositorybuild', ['resource_key'], unique=False)
- op.create_index('repositorybuild_started', 'repositorybuild', ['started'], unique=False)
- op.create_index('repositorybuild_started_logs_archived_phase', 'repositorybuild', ['started', 'logs_archived', 'phase'], unique=False)
- op.create_index('repositorybuild_trigger_id', 'repositorybuild', ['trigger_id'], unique=False)
- op.create_index('repositorybuild_uuid', 'repositorybuild', ['uuid'], unique=False)
- op.create_table('tagmanifest',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('tag_id', sa.Integer(), nullable=False),
- sa.Column('digest', sa.String(length=255), nullable=False),
- sa.Column('json_data', UTF8LongText(), nullable=False),
- sa.ForeignKeyConstraint(['tag_id'], ['repositorytag.id'], name=op.f('fk_tagmanifest_tag_id_repositorytag')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_tagmanifest'))
+ op.create_index("label_key", "label", ["key"], unique=False)
+ op.create_index("label_media_type_id", "label", ["media_type_id"], unique=False)
+ op.create_index("label_source_type_id", "label", ["source_type_id"], unique=False)
+ op.create_index("label_uuid", "label", ["uuid"], unique=True)
+ op.create_table(
+ "logentry",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("kind_id", sa.Integer(), nullable=False),
+ sa.Column("account_id", sa.Integer(), nullable=False),
+ sa.Column("performer_id", sa.Integer(), nullable=True),
+ sa.Column("repository_id", sa.Integer(), nullable=True),
+ sa.Column("datetime", sa.DateTime(), nullable=False),
+ sa.Column("ip", sa.String(length=255), nullable=True),
+ sa.Column("metadata_json", sa.Text(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["kind_id"],
+ ["logentrykind.id"],
+ name=op.f("fk_logentry_kind_id_logentrykind"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_logentry")),
)
- op.create_index('tagmanifest_digest', 'tagmanifest', ['digest'], unique=False)
- op.create_index('tagmanifest_tag_id', 'tagmanifest', ['tag_id'], unique=True)
- op.create_table('tagmanifestlabel',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('annotated_id', sa.Integer(), nullable=False),
- sa.Column('label_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['annotated_id'], ['tagmanifest.id'], name=op.f('fk_tagmanifestlabel_annotated_id_tagmanifest')),
- sa.ForeignKeyConstraint(['label_id'], ['label.id'], name=op.f('fk_tagmanifestlabel_label_id_label')),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_tagmanifestlabel_repository_id_repository')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_tagmanifestlabel'))
+ op.create_index("logentry_account_id", "logentry", ["account_id"], unique=False)
+ op.create_index(
+ "logentry_account_id_datetime",
+ "logentry",
+ ["account_id", "datetime"],
+ unique=False,
+ )
+ op.create_index("logentry_datetime", "logentry", ["datetime"], unique=False)
+ op.create_index("logentry_kind_id", "logentry", ["kind_id"], unique=False)
+ op.create_index("logentry_performer_id", "logentry", ["performer_id"], unique=False)
+ op.create_index(
+ "logentry_performer_id_datetime",
+ "logentry",
+ ["performer_id", "datetime"],
+ unique=False,
+ )
+ op.create_index(
+ "logentry_repository_id", "logentry", ["repository_id"], unique=False
+ )
+ op.create_index(
+ "logentry_repository_id_datetime",
+ "logentry",
+ ["repository_id", "datetime"],
+ unique=False,
+ )
+ op.create_index(
+ "logentry_repository_id_datetime_kind_id",
+ "logentry",
+ ["repository_id", "datetime", "kind_id"],
+ unique=False,
+ )
+ op.create_table(
+ "notification",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("uuid", sa.String(length=255), nullable=False),
+ sa.Column("kind_id", sa.Integer(), nullable=False),
+ sa.Column("target_id", sa.Integer(), nullable=False),
+ sa.Column("metadata_json", sa.Text(), nullable=False),
+ sa.Column("created", sa.DateTime(), nullable=False),
+ sa.Column("dismissed", sa.Boolean(), nullable=False),
+ sa.Column("lookup_path", sa.String(length=255), nullable=True),
+ sa.ForeignKeyConstraint(
+ ["kind_id"],
+ ["notificationkind.id"],
+ name=op.f("fk_notification_kind_id_notificationkind"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["target_id"], ["user.id"], name=op.f("fk_notification_target_id_user")
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_notification")),
+ )
+ op.create_index("notification_created", "notification", ["created"], unique=False)
+ op.create_index("notification_kind_id", "notification", ["kind_id"], unique=False)
+ op.create_index(
+ "notification_lookup_path", "notification", ["lookup_path"], unique=False
+ )
+ op.create_index(
+ "notification_target_id", "notification", ["target_id"], unique=False
+ )
+ op.create_index("notification_uuid", "notification", ["uuid"], unique=False)
+ op.create_table(
+ "oauthapplication",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("client_id", sa.String(length=255), nullable=False),
+ sa.Column("client_secret", sa.String(length=255), nullable=False),
+ sa.Column("redirect_uri", sa.String(length=255), nullable=False),
+ sa.Column("application_uri", sa.String(length=255), nullable=False),
+ sa.Column("organization_id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.Column("description", sa.Text(), nullable=False),
+ sa.Column("gravatar_email", sa.String(length=255), nullable=True),
+ sa.ForeignKeyConstraint(
+ ["organization_id"],
+ ["user.id"],
+ name=op.f("fk_oauthapplication_organization_id_user"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_oauthapplication")),
+ )
+ op.create_index(
+ "oauthapplication_client_id", "oauthapplication", ["client_id"], unique=False
+ )
+ op.create_index(
+ "oauthapplication_organization_id",
+ "oauthapplication",
+ ["organization_id"],
+ unique=False,
+ )
+ op.create_table(
+ "quayrelease",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("service_id", sa.Integer(), nullable=False),
+ sa.Column("version", sa.String(length=255), nullable=False),
+ sa.Column("region_id", sa.Integer(), nullable=False),
+ sa.Column("reverted", sa.Boolean(), nullable=False),
+ sa.Column("created", sa.DateTime(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["region_id"],
+ ["quayregion.id"],
+ name=op.f("fk_quayrelease_region_id_quayregion"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["service_id"],
+ ["quayservice.id"],
+ name=op.f("fk_quayrelease_service_id_quayservice"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_quayrelease")),
+ )
+ op.create_index("quayrelease_created", "quayrelease", ["created"], unique=False)
+ op.create_index("quayrelease_region_id", "quayrelease", ["region_id"], unique=False)
+ op.create_index(
+ "quayrelease_service_id", "quayrelease", ["service_id"], unique=False
+ )
+ op.create_index(
+ "quayrelease_service_id_region_id_created",
+ "quayrelease",
+ ["service_id", "region_id", "created"],
+ unique=False,
+ )
+ op.create_index(
+ "quayrelease_service_id_version_region_id",
+ "quayrelease",
+ ["service_id", "version", "region_id"],
+ unique=True,
+ )
+ op.create_table(
+ "repository",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("namespace_user_id", sa.Integer(), nullable=True),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.Column("visibility_id", sa.Integer(), nullable=False),
+ sa.Column("description", sa.Text(), nullable=True),
+ sa.Column("badge_token", sa.String(length=255), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["namespace_user_id"],
+ ["user.id"],
+ name=op.f("fk_repository_namespace_user_id_user"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["visibility_id"],
+ ["visibility.id"],
+ name=op.f("fk_repository_visibility_id_visibility"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_repository")),
+ )
+ op.create_index(
+ "repository_namespace_user_id",
+ "repository",
+ ["namespace_user_id"],
+ unique=False,
+ )
+ op.create_index(
+ "repository_namespace_user_id_name",
+ "repository",
+ ["namespace_user_id", "name"],
+ unique=True,
+ )
+ op.create_index(
+ "repository_visibility_id", "repository", ["visibility_id"], unique=False
+ )
+ op.create_table(
+ "servicekey",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.Column("kid", sa.String(length=255), nullable=False),
+ sa.Column("service", sa.String(length=255), nullable=False),
+ sa.Column("jwk", UTF8LongText(), nullable=False),
+ sa.Column("metadata", UTF8LongText(), nullable=False),
+ sa.Column("created_date", sa.DateTime(), nullable=False),
+ sa.Column("expiration_date", sa.DateTime(), nullable=True),
+ sa.Column("rotation_duration", sa.Integer(), nullable=True),
+ sa.Column("approval_id", sa.Integer(), nullable=True),
+ sa.ForeignKeyConstraint(
+ ["approval_id"],
+ ["servicekeyapproval.id"],
+ name=op.f("fk_servicekey_approval_id_servicekeyapproval"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_servicekey")),
+ )
+ op.create_index(
+ "servicekey_approval_id", "servicekey", ["approval_id"], unique=False
+ )
+ op.create_index("servicekey_kid", "servicekey", ["kid"], unique=True)
+ op.create_index("servicekey_service", "servicekey", ["service"], unique=False)
+ op.create_table(
+ "team",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.Column("organization_id", sa.Integer(), nullable=False),
+ sa.Column("role_id", sa.Integer(), nullable=False),
+ sa.Column("description", sa.Text(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["organization_id"], ["user.id"], name=op.f("fk_team_organization_id_user")
+ ),
+ sa.ForeignKeyConstraint(
+ ["role_id"], ["teamrole.id"], name=op.f("fk_team_role_id_teamrole")
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_team")),
+ )
+ op.create_index("team_name", "team", ["name"], unique=False)
+ op.create_index(
+ "team_name_organization_id", "team", ["name", "organization_id"], unique=True
+ )
+ op.create_index("team_organization_id", "team", ["organization_id"], unique=False)
+ op.create_index("team_role_id", "team", ["role_id"], unique=False)
+ op.create_table(
+ "torrentinfo",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("storage_id", sa.Integer(), nullable=False),
+ sa.Column("piece_length", sa.Integer(), nullable=False),
+ sa.Column("pieces", sa.Text(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["storage_id"],
+ ["imagestorage.id"],
+ name=op.f("fk_torrentinfo_storage_id_imagestorage"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_torrentinfo")),
+ )
+ op.create_index(
+ "torrentinfo_storage_id", "torrentinfo", ["storage_id"], unique=False
+ )
+ op.create_index(
+ "torrentinfo_storage_id_piece_length",
+ "torrentinfo",
+ ["storage_id", "piece_length"],
+ unique=True,
+ )
+ op.create_table(
+ "userregion",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("user_id", sa.Integer(), nullable=False),
+ sa.Column("location_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["location_id"],
+ ["imagestoragelocation.id"],
+ name=op.f("fk_userregion_location_id_imagestoragelocation"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["user_id"], ["user.id"], name=op.f("fk_userregion_user_id_user")
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_userregion")),
+ )
+ op.create_index(
+ "userregion_location_id", "userregion", ["location_id"], unique=False
+ )
+ op.create_index("userregion_user_id", "userregion", ["user_id"], unique=False)
+ op.create_table(
+ "accesstoken",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("friendly_name", sa.String(length=255), nullable=True),
+ sa.Column("code", sa.String(length=255), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("created", sa.DateTime(), nullable=False),
+ sa.Column("role_id", sa.Integer(), nullable=False),
+ sa.Column("temporary", sa.Boolean(), nullable=False),
+ sa.Column("kind_id", sa.Integer(), nullable=True),
+ sa.ForeignKeyConstraint(
+ ["kind_id"],
+ ["accesstokenkind.id"],
+ name=op.f("fk_accesstoken_kind_id_accesstokenkind"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_accesstoken_repository_id_repository"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["role_id"], ["role.id"], name=op.f("fk_accesstoken_role_id_role")
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_accesstoken")),
+ )
+ op.create_index("accesstoken_code", "accesstoken", ["code"], unique=True)
+ op.create_index("accesstoken_kind_id", "accesstoken", ["kind_id"], unique=False)
+ op.create_index(
+ "accesstoken_repository_id", "accesstoken", ["repository_id"], unique=False
+ )
+ op.create_index("accesstoken_role_id", "accesstoken", ["role_id"], unique=False)
+ op.create_table(
+ "blobupload",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("uuid", sa.String(length=255), nullable=False),
+ sa.Column("byte_count", sa.Integer(), nullable=False),
+ sa.Column("sha_state", sa.Text(), nullable=True),
+ sa.Column("location_id", sa.Integer(), nullable=False),
+ sa.Column("storage_metadata", sa.Text(), nullable=True),
+ sa.Column("chunk_count", sa.Integer(), nullable=False, server_default="0"),
+ sa.Column("uncompressed_byte_count", sa.Integer(), nullable=True),
+ sa.Column(
+ "created", sa.DateTime(), nullable=False, server_default=sa.text(now)
+ ),
+ sa.Column("piece_sha_state", UTF8LongText(), nullable=True),
+ sa.Column("piece_hashes", UTF8LongText(), nullable=True),
+ sa.ForeignKeyConstraint(
+ ["location_id"],
+ ["imagestoragelocation.id"],
+ name=op.f("fk_blobupload_location_id_imagestoragelocation"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_blobupload_repository_id_repository"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_blobupload")),
+ )
+ op.create_index("blobupload_created", "blobupload", ["created"], unique=False)
+ op.create_index(
+ "blobupload_location_id", "blobupload", ["location_id"], unique=False
+ )
+ op.create_index(
+ "blobupload_repository_id", "blobupload", ["repository_id"], unique=False
+ )
+ op.create_index(
+ "blobupload_repository_id_uuid",
+ "blobupload",
+ ["repository_id", "uuid"],
+ unique=True,
+ )
+ op.create_index("blobupload_uuid", "blobupload", ["uuid"], unique=True)
+ op.create_table(
+ "image",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("docker_image_id", sa.String(length=255), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("ancestors", sa.String(length=60535), nullable=True),
+ sa.Column("storage_id", sa.Integer(), nullable=True),
+ sa.Column("created", sa.DateTime(), nullable=True),
+ sa.Column("comment", UTF8LongText(), nullable=True),
+ sa.Column("command", sa.Text(), nullable=True),
+ sa.Column("aggregate_size", sa.BigInteger(), nullable=True),
+ sa.Column("v1_json_metadata", UTF8LongText(), nullable=True),
+ sa.Column("v1_checksum", sa.String(length=255), nullable=True),
+ sa.Column(
+ "security_indexed",
+ sa.Boolean(),
+ nullable=False,
+ server_default=sa.sql.expression.false(),
+ ),
+ sa.Column(
+ "security_indexed_engine", sa.Integer(), nullable=False, server_default="-1"
+ ),
+ sa.Column("parent_id", sa.Integer(), nullable=True),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_image_repository_id_repository"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["storage_id"],
+ ["imagestorage.id"],
+ name=op.f("fk_image_storage_id_imagestorage"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_image")),
+ )
+ op.create_index(
+ "image_ancestors", "image", ["ancestors"], unique=False, mysql_length=767
+ )
+ op.create_index("image_docker_image_id", "image", ["docker_image_id"], unique=False)
+ op.create_index("image_parent_id", "image", ["parent_id"], unique=False)
+ op.create_index("image_repository_id", "image", ["repository_id"], unique=False)
+ op.create_index(
+ "image_repository_id_docker_image_id",
+ "image",
+ ["repository_id", "docker_image_id"],
+ unique=True,
+ )
+ op.create_index(
+ "image_security_indexed", "image", ["security_indexed"], unique=False
+ )
+ op.create_index(
+ "image_security_indexed_engine",
+ "image",
+ ["security_indexed_engine"],
+ unique=False,
+ )
+ op.create_index(
+ "image_security_indexed_engine_security_indexed",
+ "image",
+ ["security_indexed_engine", "security_indexed"],
+ unique=False,
+ )
+ op.create_index("image_storage_id", "image", ["storage_id"], unique=False)
+ op.create_table(
+ "oauthaccesstoken",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("uuid", sa.String(length=255), nullable=False),
+ sa.Column("application_id", sa.Integer(), nullable=False),
+ sa.Column("authorized_user_id", sa.Integer(), nullable=False),
+ sa.Column("scope", sa.String(length=255), nullable=False),
+ sa.Column("access_token", sa.String(length=255), nullable=False),
+ sa.Column("token_type", sa.String(length=255), nullable=False),
+ sa.Column("expires_at", sa.DateTime(), nullable=False),
+ sa.Column("refresh_token", sa.String(length=255), nullable=True),
+ sa.Column("data", sa.Text(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["application_id"],
+ ["oauthapplication.id"],
+ name=op.f("fk_oauthaccesstoken_application_id_oauthapplication"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["authorized_user_id"],
+ ["user.id"],
+ name=op.f("fk_oauthaccesstoken_authorized_user_id_user"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_oauthaccesstoken")),
+ )
+ op.create_index(
+ "oauthaccesstoken_access_token",
+ "oauthaccesstoken",
+ ["access_token"],
+ unique=False,
+ )
+ op.create_index(
+ "oauthaccesstoken_application_id",
+ "oauthaccesstoken",
+ ["application_id"],
+ unique=False,
+ )
+ op.create_index(
+ "oauthaccesstoken_authorized_user_id",
+ "oauthaccesstoken",
+ ["authorized_user_id"],
+ unique=False,
+ )
+ op.create_index(
+ "oauthaccesstoken_refresh_token",
+ "oauthaccesstoken",
+ ["refresh_token"],
+ unique=False,
+ )
+ op.create_index("oauthaccesstoken_uuid", "oauthaccesstoken", ["uuid"], unique=False)
+ op.create_table(
+ "oauthauthorizationcode",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("application_id", sa.Integer(), nullable=False),
+ sa.Column("code", sa.String(length=255), nullable=False),
+ sa.Column("scope", sa.String(length=255), nullable=False),
+ sa.Column("data", sa.Text(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["application_id"],
+ ["oauthapplication.id"],
+ name=op.f("fk_oauthauthorizationcode_application_id_oauthapplication"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_oauthauthorizationcode")),
+ )
+ op.create_index(
+ "oauthauthorizationcode_application_id",
+ "oauthauthorizationcode",
+ ["application_id"],
+ unique=False,
+ )
+ op.create_index(
+ "oauthauthorizationcode_code", "oauthauthorizationcode", ["code"], unique=False
+ )
+ op.create_table(
+ "permissionprototype",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("org_id", sa.Integer(), nullable=False),
+ sa.Column("uuid", sa.String(length=255), nullable=False),
+ sa.Column("activating_user_id", sa.Integer(), nullable=True),
+ sa.Column("delegate_user_id", sa.Integer(), nullable=True),
+ sa.Column("delegate_team_id", sa.Integer(), nullable=True),
+ sa.Column("role_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["activating_user_id"],
+ ["user.id"],
+ name=op.f("fk_permissionprototype_activating_user_id_user"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["delegate_team_id"],
+ ["team.id"],
+ name=op.f("fk_permissionprototype_delegate_team_id_team"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["delegate_user_id"],
+ ["user.id"],
+ name=op.f("fk_permissionprototype_delegate_user_id_user"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["org_id"], ["user.id"], name=op.f("fk_permissionprototype_org_id_user")
+ ),
+ sa.ForeignKeyConstraint(
+ ["role_id"], ["role.id"], name=op.f("fk_permissionprototype_role_id_role")
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_permissionprototype")),
+ )
+ op.create_index(
+ "permissionprototype_activating_user_id",
+ "permissionprototype",
+ ["activating_user_id"],
+ unique=False,
+ )
+ op.create_index(
+ "permissionprototype_delegate_team_id",
+ "permissionprototype",
+ ["delegate_team_id"],
+ unique=False,
+ )
+ op.create_index(
+ "permissionprototype_delegate_user_id",
+ "permissionprototype",
+ ["delegate_user_id"],
+ unique=False,
+ )
+ op.create_index(
+ "permissionprototype_org_id", "permissionprototype", ["org_id"], unique=False
+ )
+ op.create_index(
+ "permissionprototype_org_id_activating_user_id",
+ "permissionprototype",
+ ["org_id", "activating_user_id"],
+ unique=False,
+ )
+ op.create_index(
+ "permissionprototype_role_id", "permissionprototype", ["role_id"], unique=False
+ )
+ op.create_table(
+ "repositoryactioncount",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("count", sa.Integer(), nullable=False),
+ sa.Column("date", sa.Date(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_repositoryactioncount_repository_id_repository"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_repositoryactioncount")),
+ )
+ op.create_index(
+ "repositoryactioncount_date", "repositoryactioncount", ["date"], unique=False
+ )
+ op.create_index(
+ "repositoryactioncount_repository_id",
+ "repositoryactioncount",
+ ["repository_id"],
+ unique=False,
+ )
+ op.create_index(
+ "repositoryactioncount_repository_id_date",
+ "repositoryactioncount",
+ ["repository_id", "date"],
+ unique=True,
+ )
+ op.create_table(
+ "repositoryauthorizedemail",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("email", sa.String(length=255), nullable=False),
+ sa.Column("code", sa.String(length=255), nullable=False),
+ sa.Column("confirmed", sa.Boolean(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_repositoryauthorizedemail_repository_id_repository"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_repositoryauthorizedemail")),
+ )
+ op.create_index(
+ "repositoryauthorizedemail_code",
+ "repositoryauthorizedemail",
+ ["code"],
+ unique=True,
+ )
+ op.create_index(
+ "repositoryauthorizedemail_email_repository_id",
+ "repositoryauthorizedemail",
+ ["email", "repository_id"],
+ unique=True,
+ )
+ op.create_index(
+ "repositoryauthorizedemail_repository_id",
+ "repositoryauthorizedemail",
+ ["repository_id"],
+ unique=False,
+ )
+ op.create_table(
+ "repositorynotification",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("uuid", sa.String(length=255), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("event_id", sa.Integer(), nullable=False),
+ sa.Column("method_id", sa.Integer(), nullable=False),
+ sa.Column("title", sa.String(length=255), nullable=True),
+ sa.Column("config_json", sa.Text(), nullable=False),
+ sa.Column("event_config_json", UTF8LongText(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["event_id"],
+ ["externalnotificationevent.id"],
+ name=op.f("fk_repositorynotification_event_id_externalnotificationevent"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["method_id"],
+ ["externalnotificationmethod.id"],
+ name=op.f("fk_repositorynotification_method_id_externalnotificationmethod"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_repositorynotification_repository_id_repository"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_repositorynotification")),
+ )
+ op.create_index(
+ "repositorynotification_event_id",
+ "repositorynotification",
+ ["event_id"],
+ unique=False,
+ )
+ op.create_index(
+ "repositorynotification_method_id",
+ "repositorynotification",
+ ["method_id"],
+ unique=False,
+ )
+ op.create_index(
+ "repositorynotification_repository_id",
+ "repositorynotification",
+ ["repository_id"],
+ unique=False,
+ )
+ op.create_index(
+ "repositorynotification_uuid", "repositorynotification", ["uuid"], unique=False
+ )
+ op.create_table(
+ "repositorypermission",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("team_id", sa.Integer(), nullable=True),
+ sa.Column("user_id", sa.Integer(), nullable=True),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("role_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_repositorypermission_repository_id_repository"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["role_id"], ["role.id"], name=op.f("fk_repositorypermission_role_id_role")
+ ),
+ sa.ForeignKeyConstraint(
+ ["team_id"], ["team.id"], name=op.f("fk_repositorypermission_team_id_team")
+ ),
+ sa.ForeignKeyConstraint(
+ ["user_id"], ["user.id"], name=op.f("fk_repositorypermission_user_id_user")
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_repositorypermission")),
+ )
+ op.create_index(
+ "repositorypermission_repository_id",
+ "repositorypermission",
+ ["repository_id"],
+ unique=False,
+ )
+ op.create_index(
+ "repositorypermission_role_id",
+ "repositorypermission",
+ ["role_id"],
+ unique=False,
+ )
+ op.create_index(
+ "repositorypermission_team_id",
+ "repositorypermission",
+ ["team_id"],
+ unique=False,
+ )
+ op.create_index(
+ "repositorypermission_team_id_repository_id",
+ "repositorypermission",
+ ["team_id", "repository_id"],
+ unique=True,
+ )
+ op.create_index(
+ "repositorypermission_user_id",
+ "repositorypermission",
+ ["user_id"],
+ unique=False,
+ )
+ op.create_index(
+ "repositorypermission_user_id_repository_id",
+ "repositorypermission",
+ ["user_id", "repository_id"],
+ unique=True,
+ )
+ op.create_table(
+ "star",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("user_id", sa.Integer(), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("created", sa.DateTime(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_star_repository_id_repository"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["user_id"], ["user.id"], name=op.f("fk_star_user_id_user")
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_star")),
+ )
+ op.create_index("star_repository_id", "star", ["repository_id"], unique=False)
+ op.create_index("star_user_id", "star", ["user_id"], unique=False)
+ op.create_index(
+ "star_user_id_repository_id", "star", ["user_id", "repository_id"], unique=True
+ )
+ op.create_table(
+ "teammember",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("user_id", sa.Integer(), nullable=False),
+ sa.Column("team_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["team_id"], ["team.id"], name=op.f("fk_teammember_team_id_team")
+ ),
+ sa.ForeignKeyConstraint(
+ ["user_id"], ["user.id"], name=op.f("fk_teammember_user_id_user")
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_teammember")),
+ )
+ op.create_index("teammember_team_id", "teammember", ["team_id"], unique=False)
+ op.create_index("teammember_user_id", "teammember", ["user_id"], unique=False)
+ op.create_index(
+ "teammember_user_id_team_id", "teammember", ["user_id", "team_id"], unique=True
+ )
+ op.create_table(
+ "teammemberinvite",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("user_id", sa.Integer(), nullable=True),
+ sa.Column("email", sa.String(length=255), nullable=True),
+ sa.Column("team_id", sa.Integer(), nullable=False),
+ sa.Column("inviter_id", sa.Integer(), nullable=False),
+ sa.Column("invite_token", sa.String(length=255), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["inviter_id"],
+ ["user.id"],
+ name=op.f("fk_teammemberinvite_inviter_id_user"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["team_id"], ["team.id"], name=op.f("fk_teammemberinvite_team_id_team")
+ ),
+ sa.ForeignKeyConstraint(
+ ["user_id"], ["user.id"], name=op.f("fk_teammemberinvite_user_id_user")
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_teammemberinvite")),
+ )
+ op.create_index(
+ "teammemberinvite_inviter_id", "teammemberinvite", ["inviter_id"], unique=False
+ )
+ op.create_index(
+ "teammemberinvite_team_id", "teammemberinvite", ["team_id"], unique=False
+ )
+ op.create_index(
+ "teammemberinvite_user_id", "teammemberinvite", ["user_id"], unique=False
+ )
+ op.create_table(
+ "derivedstorageforimage",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("source_image_id", sa.Integer(), nullable=False),
+ sa.Column("derivative_id", sa.Integer(), nullable=False),
+ sa.Column("transformation_id", sa.Integer(), nullable=False),
+ sa.Column("uniqueness_hash", sa.String(length=255), nullable=True),
+ sa.ForeignKeyConstraint(
+ ["derivative_id"],
+ ["imagestorage.id"],
+ name=op.f("fk_derivedstorageforimage_derivative_id_imagestorage"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["source_image_id"],
+ ["image.id"],
+ name=op.f("fk_derivedstorageforimage_source_image_id_image"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["transformation_id"],
+ ["imagestoragetransformation.id"],
+ name=op.f("fk_derivedstorageforimage_transformation_constraint"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_derivedstorageforimage")),
+ )
+ op.create_index(
+ "derivedstorageforimage_derivative_id",
+ "derivedstorageforimage",
+ ["derivative_id"],
+ unique=False,
+ )
+ op.create_index(
+ "derivedstorageforimage_source_image_id",
+ "derivedstorageforimage",
+ ["source_image_id"],
+ unique=False,
+ )
+ op.create_index(
+ "uniqueness_index",
+ "derivedstorageforimage",
+ ["source_image_id", "transformation_id", "uniqueness_hash"],
+ unique=True,
+ )
+ op.create_index(
+ "derivedstorageforimage_transformation_id",
+ "derivedstorageforimage",
+ ["transformation_id"],
+ unique=False,
+ )
+ op.create_table(
+ "repositorybuildtrigger",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("uuid", sa.String(length=255), nullable=False),
+ sa.Column("service_id", sa.Integer(), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("connected_user_id", sa.Integer(), nullable=False),
+ sa.Column("auth_token", sa.String(length=255), nullable=True),
+ sa.Column("private_key", sa.Text(), nullable=True),
+ sa.Column("config", sa.Text(), nullable=False),
+ sa.Column("write_token_id", sa.Integer(), nullable=True),
+ sa.Column("pull_robot_id", sa.Integer(), nullable=True),
+ sa.ForeignKeyConstraint(
+ ["connected_user_id"],
+ ["user.id"],
+ name=op.f("fk_repositorybuildtrigger_connected_user_id_user"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["pull_robot_id"],
+ ["user.id"],
+ name=op.f("fk_repositorybuildtrigger_pull_robot_id_user"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_repositorybuildtrigger_repository_id_repository"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["service_id"],
+ ["buildtriggerservice.id"],
+ name=op.f("fk_repositorybuildtrigger_service_id_buildtriggerservice"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["write_token_id"],
+ ["accesstoken.id"],
+ name=op.f("fk_repositorybuildtrigger_write_token_id_accesstoken"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_repositorybuildtrigger")),
+ )
+ op.create_index(
+ "repositorybuildtrigger_connected_user_id",
+ "repositorybuildtrigger",
+ ["connected_user_id"],
+ unique=False,
+ )
+ op.create_index(
+ "repositorybuildtrigger_pull_robot_id",
+ "repositorybuildtrigger",
+ ["pull_robot_id"],
+ unique=False,
+ )
+ op.create_index(
+ "repositorybuildtrigger_repository_id",
+ "repositorybuildtrigger",
+ ["repository_id"],
+ unique=False,
+ )
+ op.create_index(
+ "repositorybuildtrigger_service_id",
+ "repositorybuildtrigger",
+ ["service_id"],
+ unique=False,
+ )
+ op.create_index(
+ "repositorybuildtrigger_write_token_id",
+ "repositorybuildtrigger",
+ ["write_token_id"],
+ unique=False,
+ )
+ op.create_table(
+ "repositorytag",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.Column("image_id", sa.Integer(), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column(
+ "lifetime_start_ts", sa.Integer(), nullable=False, server_default="0"
+ ),
+ sa.Column("lifetime_end_ts", sa.Integer(), nullable=True),
+ sa.Column(
+ "hidden",
+ sa.Boolean(),
+ nullable=False,
+ server_default=sa.sql.expression.false(),
+ ),
+ sa.Column(
+ "reversion",
+ sa.Boolean(),
+ nullable=False,
+ server_default=sa.sql.expression.false(),
+ ),
+ sa.ForeignKeyConstraint(
+ ["image_id"], ["image.id"], name=op.f("fk_repositorytag_image_id_image")
+ ),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_repositorytag_repository_id_repository"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_repositorytag")),
+ )
+ op.create_index(
+ "repositorytag_image_id", "repositorytag", ["image_id"], unique=False
+ )
+ op.create_index(
+ "repositorytag_lifetime_end_ts",
+ "repositorytag",
+ ["lifetime_end_ts"],
+ unique=False,
+ )
+ op.create_index(
+ "repositorytag_repository_id", "repositorytag", ["repository_id"], unique=False
+ )
+ op.create_index(
+ "repositorytag_repository_id_name",
+ "repositorytag",
+ ["repository_id", "name"],
+ unique=False,
+ )
+ op.create_index(
+ "repositorytag_repository_id_name_lifetime_end_ts",
+ "repositorytag",
+ ["repository_id", "name", "lifetime_end_ts"],
+ unique=True,
+ )
+ op.create_table(
+ "repositorybuild",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("uuid", sa.String(length=255), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("access_token_id", sa.Integer(), nullable=False),
+ sa.Column("resource_key", sa.String(length=255), nullable=True),
+ sa.Column("job_config", sa.Text(), nullable=False),
+ sa.Column("phase", sa.String(length=255), nullable=False),
+ sa.Column("started", sa.DateTime(), nullable=False),
+ sa.Column("display_name", sa.String(length=255), nullable=False),
+ sa.Column("trigger_id", sa.Integer(), nullable=True),
+ sa.Column("pull_robot_id", sa.Integer(), nullable=True),
+ sa.Column(
+ "logs_archived",
+ sa.Boolean(),
+ nullable=False,
+ server_default=sa.sql.expression.false(),
+ ),
+ sa.Column("queue_id", sa.String(length=255), nullable=True),
+ sa.ForeignKeyConstraint(
+ ["access_token_id"],
+ ["accesstoken.id"],
+ name=op.f("fk_repositorybuild_access_token_id_accesstoken"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["pull_robot_id"],
+ ["user.id"],
+ name=op.f("fk_repositorybuild_pull_robot_id_user"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_repositorybuild_repository_id_repository"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["trigger_id"],
+ ["repositorybuildtrigger.id"],
+ name=op.f("fk_repositorybuild_trigger_id_repositorybuildtrigger"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_repositorybuild")),
+ )
+ op.create_index(
+ "repositorybuild_access_token_id",
+ "repositorybuild",
+ ["access_token_id"],
+ unique=False,
+ )
+ op.create_index(
+ "repositorybuild_pull_robot_id",
+ "repositorybuild",
+ ["pull_robot_id"],
+ unique=False,
+ )
+ op.create_index(
+ "repositorybuild_queue_id", "repositorybuild", ["queue_id"], unique=False
+ )
+ op.create_index(
+ "repositorybuild_repository_id",
+ "repositorybuild",
+ ["repository_id"],
+ unique=False,
+ )
+ op.create_index(
+ "repositorybuild_repository_id_started_phase",
+ "repositorybuild",
+ ["repository_id", "started", "phase"],
+ unique=False,
+ )
+ op.create_index(
+ "repositorybuild_resource_key",
+ "repositorybuild",
+ ["resource_key"],
+ unique=False,
+ )
+ op.create_index(
+ "repositorybuild_started", "repositorybuild", ["started"], unique=False
+ )
+ op.create_index(
+ "repositorybuild_started_logs_archived_phase",
+ "repositorybuild",
+ ["started", "logs_archived", "phase"],
+ unique=False,
+ )
+ op.create_index(
+ "repositorybuild_trigger_id", "repositorybuild", ["trigger_id"], unique=False
+ )
+ op.create_index("repositorybuild_uuid", "repositorybuild", ["uuid"], unique=False)
+ op.create_table(
+ "tagmanifest",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("tag_id", sa.Integer(), nullable=False),
+ sa.Column("digest", sa.String(length=255), nullable=False),
+ sa.Column("json_data", UTF8LongText(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["tag_id"],
+ ["repositorytag.id"],
+ name=op.f("fk_tagmanifest_tag_id_repositorytag"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_tagmanifest")),
+ )
+ op.create_index("tagmanifest_digest", "tagmanifest", ["digest"], unique=False)
+ op.create_index("tagmanifest_tag_id", "tagmanifest", ["tag_id"], unique=True)
+ op.create_table(
+ "tagmanifestlabel",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("annotated_id", sa.Integer(), nullable=False),
+ sa.Column("label_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["annotated_id"],
+ ["tagmanifest.id"],
+ name=op.f("fk_tagmanifestlabel_annotated_id_tagmanifest"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["label_id"], ["label.id"], name=op.f("fk_tagmanifestlabel_label_id_label")
+ ),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_tagmanifestlabel_repository_id_repository"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_tagmanifestlabel")),
+ )
+ op.create_index(
+ "tagmanifestlabel_annotated_id",
+ "tagmanifestlabel",
+ ["annotated_id"],
+ unique=False,
+ )
+ op.create_index(
+ "tagmanifestlabel_annotated_id_label_id",
+ "tagmanifestlabel",
+ ["annotated_id", "label_id"],
+ unique=True,
+ )
+ op.create_index(
+ "tagmanifestlabel_label_id", "tagmanifestlabel", ["label_id"], unique=False
+ )
+ op.create_index(
+ "tagmanifestlabel_repository_id",
+ "tagmanifestlabel",
+ ["repository_id"],
+ unique=False,
)
- op.create_index('tagmanifestlabel_annotated_id', 'tagmanifestlabel', ['annotated_id'], unique=False)
- op.create_index('tagmanifestlabel_annotated_id_label_id', 'tagmanifestlabel', ['annotated_id', 'label_id'], unique=True)
- op.create_index('tagmanifestlabel_label_id', 'tagmanifestlabel', ['label_id'], unique=False)
- op.create_index('tagmanifestlabel_repository_id', 'tagmanifestlabel', ['repository_id'], unique=False)
- op.bulk_insert(tables.accesstokenkind,
- [
- {'name':'build-worker'},
- {'name':'pushpull-token'},
- ])
+ op.bulk_insert(
+ tables.accesstokenkind, [{"name": "build-worker"}, {"name": "pushpull-token"}]
+ )
- op.bulk_insert(tables.buildtriggerservice,
- [
- {'name':'github'},
- {'name':'gitlab'},
- {'name':'bitbucket'},
- {'name':'custom-git'},
- ])
+ op.bulk_insert(
+ tables.buildtriggerservice,
+ [
+ {"name": "github"},
+ {"name": "gitlab"},
+ {"name": "bitbucket"},
+ {"name": "custom-git"},
+ ],
+ )
- op.bulk_insert(tables.externalnotificationevent,
- [
- {'name':'build_failure'},
- {'name':'build_queued'},
- {'name':'build_start'},
- {'name':'build_success'},
- {'name':'repo_push'},
- {'name':'vulnerability_found'},
- ])
+ op.bulk_insert(
+ tables.externalnotificationevent,
+ [
+ {"name": "build_failure"},
+ {"name": "build_queued"},
+ {"name": "build_start"},
+ {"name": "build_success"},
+ {"name": "repo_push"},
+ {"name": "vulnerability_found"},
+ ],
+ )
- op.bulk_insert(tables.externalnotificationmethod,
- [
- {'name':'email'},
- {'name':'flowdock'},
- {'name':'hipchat'},
- {'name':'quay_notification'},
- {'name':'slack'},
- {'name':'webhook'},
- ])
+ op.bulk_insert(
+ tables.externalnotificationmethod,
+ [
+ {"name": "email"},
+ {"name": "flowdock"},
+ {"name": "hipchat"},
+ {"name": "quay_notification"},
+ {"name": "slack"},
+ {"name": "webhook"},
+ ],
+ )
- op.bulk_insert(tables.imagestoragelocation,
- [
- {'name':'s3_us_east_1'},
- {'name':'s3_eu_west_1'},
- {'name':'s3_ap_southeast_1'},
- {'name':'s3_ap_southeast_2'},
- {'name':'s3_ap_northeast_1'},
- {'name':'s3_sa_east_1'},
- {'name':'local'},
- {'name':'s3_us_west_1'},
- ])
+ op.bulk_insert(
+ tables.imagestoragelocation,
+ [
+ {"name": "s3_us_east_1"},
+ {"name": "s3_eu_west_1"},
+ {"name": "s3_ap_southeast_1"},
+ {"name": "s3_ap_southeast_2"},
+ {"name": "s3_ap_northeast_1"},
+ {"name": "s3_sa_east_1"},
+ {"name": "local"},
+ {"name": "s3_us_west_1"},
+ ],
+ )
- op.bulk_insert(tables.imagestoragesignaturekind,
- [
- {'name':'gpg2'},
- ])
+ op.bulk_insert(tables.imagestoragesignaturekind, [{"name": "gpg2"}])
- op.bulk_insert(tables.imagestoragetransformation,
- [
- {'name':'squash'},
- {'name':'aci'},
- ])
+ op.bulk_insert(
+ tables.imagestoragetransformation, [{"name": "squash"}, {"name": "aci"}]
+ )
- op.bulk_insert(tables.labelsourcetype,
- [
- {'name':'manifest', 'mutable': False},
- {'name':'api', 'mutable': True},
- {'name':'internal', 'mutable': False},
- ])
+ op.bulk_insert(
+ tables.labelsourcetype,
+ [
+ {"name": "manifest", "mutable": False},
+ {"name": "api", "mutable": True},
+ {"name": "internal", "mutable": False},
+ ],
+ )
- op.bulk_insert(tables.logentrykind,
- [
- {'name':'account_change_cc'},
- {'name':'account_change_password'},
- {'name':'account_change_plan'},
- {'name':'account_convert'},
- {'name':'add_repo_accesstoken'},
- {'name':'add_repo_notification'},
- {'name':'add_repo_permission'},
- {'name':'add_repo_webhook'},
- {'name':'build_dockerfile'},
- {'name':'change_repo_permission'},
- {'name':'change_repo_visibility'},
- {'name':'create_application'},
- {'name':'create_prototype_permission'},
- {'name':'create_repo'},
- {'name':'create_robot'},
- {'name':'create_tag'},
- {'name':'delete_application'},
- {'name':'delete_prototype_permission'},
- {'name':'delete_repo'},
- {'name':'delete_repo_accesstoken'},
- {'name':'delete_repo_notification'},
- {'name':'delete_repo_permission'},
- {'name':'delete_repo_trigger'},
- {'name':'delete_repo_webhook'},
- {'name':'delete_robot'},
- {'name':'delete_tag'},
- {'name':'manifest_label_add'},
- {'name':'manifest_label_delete'},
- {'name':'modify_prototype_permission'},
- {'name':'move_tag'},
- {'name':'org_add_team_member'},
- {'name':'org_create_team'},
- {'name':'org_delete_team'},
- {'name':'org_delete_team_member_invite'},
- {'name':'org_invite_team_member'},
- {'name':'org_remove_team_member'},
- {'name':'org_set_team_description'},
- {'name':'org_set_team_role'},
- {'name':'org_team_member_invite_accepted'},
- {'name':'org_team_member_invite_declined'},
- {'name':'pull_repo'},
- {'name':'push_repo'},
- {'name':'regenerate_robot_token'},
- {'name':'repo_verb'},
- {'name':'reset_application_client_secret'},
- {'name':'revert_tag'},
- {'name':'service_key_approve'},
- {'name':'service_key_create'},
- {'name':'service_key_delete'},
- {'name':'service_key_extend'},
- {'name':'service_key_modify'},
- {'name':'service_key_rotate'},
- {'name':'setup_repo_trigger'},
- {'name':'set_repo_description'},
- {'name':'take_ownership'},
- {'name':'update_application'},
- ])
+ op.bulk_insert(
+ tables.logentrykind,
+ [
+ {"name": "account_change_cc"},
+ {"name": "account_change_password"},
+ {"name": "account_change_plan"},
+ {"name": "account_convert"},
+ {"name": "add_repo_accesstoken"},
+ {"name": "add_repo_notification"},
+ {"name": "add_repo_permission"},
+ {"name": "add_repo_webhook"},
+ {"name": "build_dockerfile"},
+ {"name": "change_repo_permission"},
+ {"name": "change_repo_visibility"},
+ {"name": "create_application"},
+ {"name": "create_prototype_permission"},
+ {"name": "create_repo"},
+ {"name": "create_robot"},
+ {"name": "create_tag"},
+ {"name": "delete_application"},
+ {"name": "delete_prototype_permission"},
+ {"name": "delete_repo"},
+ {"name": "delete_repo_accesstoken"},
+ {"name": "delete_repo_notification"},
+ {"name": "delete_repo_permission"},
+ {"name": "delete_repo_trigger"},
+ {"name": "delete_repo_webhook"},
+ {"name": "delete_robot"},
+ {"name": "delete_tag"},
+ {"name": "manifest_label_add"},
+ {"name": "manifest_label_delete"},
+ {"name": "modify_prototype_permission"},
+ {"name": "move_tag"},
+ {"name": "org_add_team_member"},
+ {"name": "org_create_team"},
+ {"name": "org_delete_team"},
+ {"name": "org_delete_team_member_invite"},
+ {"name": "org_invite_team_member"},
+ {"name": "org_remove_team_member"},
+ {"name": "org_set_team_description"},
+ {"name": "org_set_team_role"},
+ {"name": "org_team_member_invite_accepted"},
+ {"name": "org_team_member_invite_declined"},
+ {"name": "pull_repo"},
+ {"name": "push_repo"},
+ {"name": "regenerate_robot_token"},
+ {"name": "repo_verb"},
+ {"name": "reset_application_client_secret"},
+ {"name": "revert_tag"},
+ {"name": "service_key_approve"},
+ {"name": "service_key_create"},
+ {"name": "service_key_delete"},
+ {"name": "service_key_extend"},
+ {"name": "service_key_modify"},
+ {"name": "service_key_rotate"},
+ {"name": "setup_repo_trigger"},
+ {"name": "set_repo_description"},
+ {"name": "take_ownership"},
+ {"name": "update_application"},
+ ],
+ )
- op.bulk_insert(tables.loginservice,
- [
- {'name':'github'},
- {'name':'quayrobot'},
- {'name':'ldap'},
- {'name':'google'},
- {'name':'keystone'},
- {'name':'dex'},
- {'name':'jwtauthn'},
- ])
+ op.bulk_insert(
+ tables.loginservice,
+ [
+ {"name": "github"},
+ {"name": "quayrobot"},
+ {"name": "ldap"},
+ {"name": "google"},
+ {"name": "keystone"},
+ {"name": "dex"},
+ {"name": "jwtauthn"},
+ ],
+ )
- op.bulk_insert(tables.mediatype,
- [
- {'name':'text/plain'},
- {'name':'application/json'},
- ])
+ op.bulk_insert(
+ tables.mediatype, [{"name": "text/plain"}, {"name": "application/json"}]
+ )
- op.bulk_insert(tables.notificationkind,
- [
- {'name':'build_failure'},
- {'name':'build_queued'},
- {'name':'build_start'},
- {'name':'build_success'},
- {'name':'expiring_license'},
- {'name':'maintenance'},
- {'name':'org_team_invite'},
- {'name':'over_private_usage'},
- {'name':'password_required'},
- {'name':'repo_push'},
- {'name':'service_key_submitted'},
- {'name':'vulnerability_found'},
- ])
+ op.bulk_insert(
+ tables.notificationkind,
+ [
+ {"name": "build_failure"},
+ {"name": "build_queued"},
+ {"name": "build_start"},
+ {"name": "build_success"},
+ {"name": "expiring_license"},
+ {"name": "maintenance"},
+ {"name": "org_team_invite"},
+ {"name": "over_private_usage"},
+ {"name": "password_required"},
+ {"name": "repo_push"},
+ {"name": "service_key_submitted"},
+ {"name": "vulnerability_found"},
+ ],
+ )
- op.bulk_insert(tables.role,
- [
- {'name':'admin'},
- {'name':'write'},
- {'name':'read'},
- ])
+ op.bulk_insert(
+ tables.role, [{"name": "admin"}, {"name": "write"}, {"name": "read"}]
+ )
- op.bulk_insert(tables.teamrole,
- [
- {'name':'admin'},
- {'name':'creator'},
- {'name':'member'},
- ])
+ op.bulk_insert(
+ tables.teamrole, [{"name": "admin"}, {"name": "creator"}, {"name": "member"}]
+ )
- op.bulk_insert(tables.visibility,
- [
- {'name':'public'},
- {'name':'private'},
- ])
+ op.bulk_insert(tables.visibility, [{"name": "public"}, {"name": "private"}])
# ### population of test data ### #
- tester.populate_table('user', [
- ('uuid', tester.TestDataType.UUID),
- ('username', tester.TestDataType.String),
- ('password_hash', tester.TestDataType.String),
- ('email', tester.TestDataType.String),
- ('verified', tester.TestDataType.Boolean),
- ('organization', tester.TestDataType.Boolean),
- ('robot', tester.TestDataType.Boolean),
- ('invoice_email', tester.TestDataType.Boolean),
- ('invalid_login_attempts', tester.TestDataType.Integer),
- ('last_invalid_login', tester.TestDataType.DateTime),
- ('removed_tag_expiration_s', tester.TestDataType.Integer),
- ('enabled', tester.TestDataType.Boolean),
- ('invoice_email_address', tester.TestDataType.String),
- ])
+ tester.populate_table(
+ "user",
+ [
+ ("uuid", tester.TestDataType.UUID),
+ ("username", tester.TestDataType.String),
+ ("password_hash", tester.TestDataType.String),
+ ("email", tester.TestDataType.String),
+ ("verified", tester.TestDataType.Boolean),
+ ("organization", tester.TestDataType.Boolean),
+ ("robot", tester.TestDataType.Boolean),
+ ("invoice_email", tester.TestDataType.Boolean),
+ ("invalid_login_attempts", tester.TestDataType.Integer),
+ ("last_invalid_login", tester.TestDataType.DateTime),
+ ("removed_tag_expiration_s", tester.TestDataType.Integer),
+ ("enabled", tester.TestDataType.Boolean),
+ ("invoice_email_address", tester.TestDataType.String),
+ ],
+ )
- tester.populate_table('repository', [
- ('namespace_user_id', tester.TestDataType.Foreign('user')),
- ('name', tester.TestDataType.String),
- ('visibility_id', tester.TestDataType.Foreign('visibility')),
- ('description', tester.TestDataType.String),
- ('badge_token', tester.TestDataType.String),
- ])
+ tester.populate_table(
+ "repository",
+ [
+ ("namespace_user_id", tester.TestDataType.Foreign("user")),
+ ("name", tester.TestDataType.String),
+ ("visibility_id", tester.TestDataType.Foreign("visibility")),
+ ("description", tester.TestDataType.String),
+ ("badge_token", tester.TestDataType.String),
+ ],
+ )
- tester.populate_table('emailconfirmation', [
- ('code', tester.TestDataType.String),
- ('user_id', tester.TestDataType.Foreign('user')),
- ('pw_reset', tester.TestDataType.Boolean),
- ('email_confirm', tester.TestDataType.Boolean),
- ('created', tester.TestDataType.DateTime),
- ])
+ tester.populate_table(
+ "emailconfirmation",
+ [
+ ("code", tester.TestDataType.String),
+ ("user_id", tester.TestDataType.Foreign("user")),
+ ("pw_reset", tester.TestDataType.Boolean),
+ ("email_confirm", tester.TestDataType.Boolean),
+ ("created", tester.TestDataType.DateTime),
+ ],
+ )
- tester.populate_table('federatedlogin', [
- ('user_id', tester.TestDataType.Foreign('user')),
- ('service_id', tester.TestDataType.Foreign('loginservice')),
- ('service_ident', tester.TestDataType.String),
- ('metadata_json', tester.TestDataType.JSON),
- ])
+ tester.populate_table(
+ "federatedlogin",
+ [
+ ("user_id", tester.TestDataType.Foreign("user")),
+ ("service_id", tester.TestDataType.Foreign("loginservice")),
+ ("service_ident", tester.TestDataType.String),
+ ("metadata_json", tester.TestDataType.JSON),
+ ],
+ )
- tester.populate_table('imagestorage', [
- ('uuid', tester.TestDataType.UUID),
- ('checksum', tester.TestDataType.String),
- ('image_size', tester.TestDataType.BigInteger),
- ('uncompressed_size', tester.TestDataType.BigInteger),
- ('uploading', tester.TestDataType.Boolean),
- ('cas_path', tester.TestDataType.Boolean),
- ('content_checksum', tester.TestDataType.String),
- ])
+ tester.populate_table(
+ "imagestorage",
+ [
+ ("uuid", tester.TestDataType.UUID),
+ ("checksum", tester.TestDataType.String),
+ ("image_size", tester.TestDataType.BigInteger),
+ ("uncompressed_size", tester.TestDataType.BigInteger),
+ ("uploading", tester.TestDataType.Boolean),
+ ("cas_path", tester.TestDataType.Boolean),
+ ("content_checksum", tester.TestDataType.String),
+ ],
+ )
- tester.populate_table('image', [
- ('docker_image_id', tester.TestDataType.UUID),
- ('repository_id', tester.TestDataType.Foreign('repository')),
- ('ancestors', tester.TestDataType.String),
- ('storage_id', tester.TestDataType.Foreign('imagestorage')),
- ('security_indexed', tester.TestDataType.Boolean),
- ('security_indexed_engine', tester.TestDataType.Integer),
- ])
+ tester.populate_table(
+ "image",
+ [
+ ("docker_image_id", tester.TestDataType.UUID),
+ ("repository_id", tester.TestDataType.Foreign("repository")),
+ ("ancestors", tester.TestDataType.String),
+ ("storage_id", tester.TestDataType.Foreign("imagestorage")),
+ ("security_indexed", tester.TestDataType.Boolean),
+ ("security_indexed_engine", tester.TestDataType.Integer),
+ ],
+ )
- tester.populate_table('imagestorageplacement', [
- ('storage_id', tester.TestDataType.Foreign('imagestorage')),
- ('location_id', tester.TestDataType.Foreign('imagestoragelocation')),
- ])
+ tester.populate_table(
+ "imagestorageplacement",
+ [
+ ("storage_id", tester.TestDataType.Foreign("imagestorage")),
+ ("location_id", tester.TestDataType.Foreign("imagestoragelocation")),
+ ],
+ )
- tester.populate_table('messages', [
- ('content', tester.TestDataType.String),
- ('uuid', tester.TestDataType.UUID),
- ])
+ tester.populate_table(
+ "messages",
+ [("content", tester.TestDataType.String), ("uuid", tester.TestDataType.UUID)],
+ )
- tester.populate_table('queueitem', [
- ('queue_name', tester.TestDataType.String),
- ('body', tester.TestDataType.JSON),
- ('available_after', tester.TestDataType.DateTime),
- ('available', tester.TestDataType.Boolean),
- ('processing_expires', tester.TestDataType.DateTime),
- ('retries_remaining', tester.TestDataType.Integer),
- ])
+ tester.populate_table(
+ "queueitem",
+ [
+ ("queue_name", tester.TestDataType.String),
+ ("body", tester.TestDataType.JSON),
+ ("available_after", tester.TestDataType.DateTime),
+ ("available", tester.TestDataType.Boolean),
+ ("processing_expires", tester.TestDataType.DateTime),
+ ("retries_remaining", tester.TestDataType.Integer),
+ ],
+ )
- tester.populate_table('servicekeyapproval', [
- ('approver_id', tester.TestDataType.Foreign('user')),
- ('approval_type', tester.TestDataType.String),
- ('approved_date', tester.TestDataType.DateTime),
- ('notes', tester.TestDataType.String),
- ])
+ tester.populate_table(
+ "servicekeyapproval",
+ [
+ ("approver_id", tester.TestDataType.Foreign("user")),
+ ("approval_type", tester.TestDataType.String),
+ ("approved_date", tester.TestDataType.DateTime),
+ ("notes", tester.TestDataType.String),
+ ],
+ )
- tester.populate_table('servicekey', [
- ('name', tester.TestDataType.String),
- ('kid', tester.TestDataType.String),
- ('service', tester.TestDataType.String),
- ('jwk', tester.TestDataType.JSON),
- ('metadata', tester.TestDataType.JSON),
- ('created_date', tester.TestDataType.DateTime),
- ('approval_id', tester.TestDataType.Foreign('servicekeyapproval')),
- ])
+ tester.populate_table(
+ "servicekey",
+ [
+ ("name", tester.TestDataType.String),
+ ("kid", tester.TestDataType.String),
+ ("service", tester.TestDataType.String),
+ ("jwk", tester.TestDataType.JSON),
+ ("metadata", tester.TestDataType.JSON),
+ ("created_date", tester.TestDataType.DateTime),
+ ("approval_id", tester.TestDataType.Foreign("servicekeyapproval")),
+ ],
+ )
- tester.populate_table('label', [
- ('uuid', tester.TestDataType.UUID),
- ('key', tester.TestDataType.UTF8Char),
- ('value', tester.TestDataType.JSON),
- ('media_type_id', tester.TestDataType.Foreign('mediatype')),
- ('source_type_id', tester.TestDataType.Foreign('labelsourcetype')),
- ])
+ tester.populate_table(
+ "label",
+ [
+ ("uuid", tester.TestDataType.UUID),
+ ("key", tester.TestDataType.UTF8Char),
+ ("value", tester.TestDataType.JSON),
+ ("media_type_id", tester.TestDataType.Foreign("mediatype")),
+ ("source_type_id", tester.TestDataType.Foreign("labelsourcetype")),
+ ],
+ )
- tester.populate_table('logentry', [
- ('kind_id', tester.TestDataType.Foreign('logentrykind')),
- ('account_id', tester.TestDataType.Foreign('user')),
- ('performer_id', tester.TestDataType.Foreign('user')),
- ('repository_id', tester.TestDataType.Foreign('repository')),
- ('datetime', tester.TestDataType.DateTime),
- ('ip', tester.TestDataType.String),
- ('metadata_json', tester.TestDataType.JSON),
- ])
+ tester.populate_table(
+ "logentry",
+ [
+ ("kind_id", tester.TestDataType.Foreign("logentrykind")),
+ ("account_id", tester.TestDataType.Foreign("user")),
+ ("performer_id", tester.TestDataType.Foreign("user")),
+ ("repository_id", tester.TestDataType.Foreign("repository")),
+ ("datetime", tester.TestDataType.DateTime),
+ ("ip", tester.TestDataType.String),
+ ("metadata_json", tester.TestDataType.JSON),
+ ],
+ )
- tester.populate_table('notification', [
- ('uuid', tester.TestDataType.UUID),
- ('kind_id', tester.TestDataType.Foreign('notificationkind')),
- ('target_id', tester.TestDataType.Foreign('user')),
- ('metadata_json', tester.TestDataType.JSON),
- ('created', tester.TestDataType.DateTime),
- ('dismissed', tester.TestDataType.Boolean),
- ('lookup_path', tester.TestDataType.String),
- ])
+ tester.populate_table(
+ "notification",
+ [
+ ("uuid", tester.TestDataType.UUID),
+ ("kind_id", tester.TestDataType.Foreign("notificationkind")),
+ ("target_id", tester.TestDataType.Foreign("user")),
+ ("metadata_json", tester.TestDataType.JSON),
+ ("created", tester.TestDataType.DateTime),
+ ("dismissed", tester.TestDataType.Boolean),
+ ("lookup_path", tester.TestDataType.String),
+ ],
+ )
- tester.populate_table('oauthapplication', [
- ('client_id', tester.TestDataType.String),
- ('client_secret', tester.TestDataType.String),
- ('redirect_uri', tester.TestDataType.String),
- ('application_uri', tester.TestDataType.String),
- ('organization_id', tester.TestDataType.Foreign('user')),
- ('name', tester.TestDataType.String),
- ('description', tester.TestDataType.String),
- ])
+ tester.populate_table(
+ "oauthapplication",
+ [
+ ("client_id", tester.TestDataType.String),
+ ("client_secret", tester.TestDataType.String),
+ ("redirect_uri", tester.TestDataType.String),
+ ("application_uri", tester.TestDataType.String),
+ ("organization_id", tester.TestDataType.Foreign("user")),
+ ("name", tester.TestDataType.String),
+ ("description", tester.TestDataType.String),
+ ],
+ )
- tester.populate_table('team', [
- ('name', tester.TestDataType.String),
- ('organization_id', tester.TestDataType.Foreign('user')),
- ('role_id', tester.TestDataType.Foreign('teamrole')),
- ('description', tester.TestDataType.String),
- ])
+ tester.populate_table(
+ "team",
+ [
+ ("name", tester.TestDataType.String),
+ ("organization_id", tester.TestDataType.Foreign("user")),
+ ("role_id", tester.TestDataType.Foreign("teamrole")),
+ ("description", tester.TestDataType.String),
+ ],
+ )
- tester.populate_table('torrentinfo', [
- ('storage_id', tester.TestDataType.Foreign('imagestorage')),
- ('piece_length', tester.TestDataType.Integer),
- ('pieces', tester.TestDataType.String),
- ])
+ tester.populate_table(
+ "torrentinfo",
+ [
+ ("storage_id", tester.TestDataType.Foreign("imagestorage")),
+ ("piece_length", tester.TestDataType.Integer),
+ ("pieces", tester.TestDataType.String),
+ ],
+ )
- tester.populate_table('userregion', [
- ('user_id', tester.TestDataType.Foreign('user')),
- ('location_id', tester.TestDataType.Foreign('imagestoragelocation')),
- ])
+ tester.populate_table(
+ "userregion",
+ [
+ ("user_id", tester.TestDataType.Foreign("user")),
+ ("location_id", tester.TestDataType.Foreign("imagestoragelocation")),
+ ],
+ )
- tester.populate_table('accesstoken', [
- ('friendly_name', tester.TestDataType.String),
- ('code', tester.TestDataType.Token),
- ('repository_id', tester.TestDataType.Foreign('repository')),
- ('created', tester.TestDataType.DateTime),
- ('role_id', tester.TestDataType.Foreign('role')),
- ('temporary', tester.TestDataType.Boolean),
- ('kind_id', tester.TestDataType.Foreign('accesstokenkind')),
- ])
+ tester.populate_table(
+ "accesstoken",
+ [
+ ("friendly_name", tester.TestDataType.String),
+ ("code", tester.TestDataType.Token),
+ ("repository_id", tester.TestDataType.Foreign("repository")),
+ ("created", tester.TestDataType.DateTime),
+ ("role_id", tester.TestDataType.Foreign("role")),
+ ("temporary", tester.TestDataType.Boolean),
+ ("kind_id", tester.TestDataType.Foreign("accesstokenkind")),
+ ],
+ )
- tester.populate_table('blobupload', [
- ('repository_id', tester.TestDataType.Foreign('repository')),
- ('uuid', tester.TestDataType.UUID),
- ('byte_count', tester.TestDataType.Integer),
- ('sha_state', tester.TestDataType.String),
- ('location_id', tester.TestDataType.Foreign('imagestoragelocation')),
- ('chunk_count', tester.TestDataType.Integer),
- ('created', tester.TestDataType.DateTime),
- ])
+ tester.populate_table(
+ "blobupload",
+ [
+ ("repository_id", tester.TestDataType.Foreign("repository")),
+ ("uuid", tester.TestDataType.UUID),
+ ("byte_count", tester.TestDataType.Integer),
+ ("sha_state", tester.TestDataType.String),
+ ("location_id", tester.TestDataType.Foreign("imagestoragelocation")),
+ ("chunk_count", tester.TestDataType.Integer),
+ ("created", tester.TestDataType.DateTime),
+ ],
+ )
- tester.populate_table('oauthaccesstoken', [
- ('uuid', tester.TestDataType.UUID),
- ('application_id', tester.TestDataType.Foreign('oauthapplication')),
- ('authorized_user_id', tester.TestDataType.Foreign('user')),
- ('scope', tester.TestDataType.String),
- ('access_token', tester.TestDataType.Token),
- ('token_type', tester.TestDataType.String),
- ('expires_at', tester.TestDataType.DateTime),
- ('data', tester.TestDataType.JSON),
- ])
+ tester.populate_table(
+ "oauthaccesstoken",
+ [
+ ("uuid", tester.TestDataType.UUID),
+ ("application_id", tester.TestDataType.Foreign("oauthapplication")),
+ ("authorized_user_id", tester.TestDataType.Foreign("user")),
+ ("scope", tester.TestDataType.String),
+ ("access_token", tester.TestDataType.Token),
+ ("token_type", tester.TestDataType.String),
+ ("expires_at", tester.TestDataType.DateTime),
+ ("data", tester.TestDataType.JSON),
+ ],
+ )
- tester.populate_table('oauthauthorizationcode', [
- ('application_id', tester.TestDataType.Foreign('oauthapplication')),
- ('code', tester.TestDataType.Token),
- ('scope', tester.TestDataType.String),
- ('data', tester.TestDataType.JSON),
- ])
+ tester.populate_table(
+ "oauthauthorizationcode",
+ [
+ ("application_id", tester.TestDataType.Foreign("oauthapplication")),
+ ("code", tester.TestDataType.Token),
+ ("scope", tester.TestDataType.String),
+ ("data", tester.TestDataType.JSON),
+ ],
+ )
- tester.populate_table('permissionprototype', [
- ('org_id', tester.TestDataType.Foreign('user')),
- ('uuid', tester.TestDataType.UUID),
- ('activating_user_id', tester.TestDataType.Foreign('user')),
- ('delegate_user_id', tester.TestDataType.Foreign('user')),
- ('role_id', tester.TestDataType.Foreign('role')),
- ])
+ tester.populate_table(
+ "permissionprototype",
+ [
+ ("org_id", tester.TestDataType.Foreign("user")),
+ ("uuid", tester.TestDataType.UUID),
+ ("activating_user_id", tester.TestDataType.Foreign("user")),
+ ("delegate_user_id", tester.TestDataType.Foreign("user")),
+ ("role_id", tester.TestDataType.Foreign("role")),
+ ],
+ )
- tester.populate_table('repositoryactioncount', [
- ('repository_id', tester.TestDataType.Foreign('repository')),
- ('count', tester.TestDataType.Integer),
- ('date', tester.TestDataType.Date),
- ])
+ tester.populate_table(
+ "repositoryactioncount",
+ [
+ ("repository_id", tester.TestDataType.Foreign("repository")),
+ ("count", tester.TestDataType.Integer),
+ ("date", tester.TestDataType.Date),
+ ],
+ )
- tester.populate_table('repositoryauthorizedemail', [
- ('repository_id', tester.TestDataType.Foreign('repository')),
- ('email', tester.TestDataType.String),
- ('code', tester.TestDataType.String),
- ('confirmed', tester.TestDataType.Boolean),
- ])
+ tester.populate_table(
+ "repositoryauthorizedemail",
+ [
+ ("repository_id", tester.TestDataType.Foreign("repository")),
+ ("email", tester.TestDataType.String),
+ ("code", tester.TestDataType.String),
+ ("confirmed", tester.TestDataType.Boolean),
+ ],
+ )
- tester.populate_table('repositorynotification', [
- ('uuid', tester.TestDataType.UUID),
- ('repository_id', tester.TestDataType.Foreign('repository')),
- ('event_id', tester.TestDataType.Foreign('externalnotificationevent')),
- ('method_id', tester.TestDataType.Foreign('externalnotificationmethod')),
- ('title', tester.TestDataType.String),
- ('config_json', tester.TestDataType.JSON),
- ('event_config_json', tester.TestDataType.JSON),
- ])
+ tester.populate_table(
+ "repositorynotification",
+ [
+ ("uuid", tester.TestDataType.UUID),
+ ("repository_id", tester.TestDataType.Foreign("repository")),
+ ("event_id", tester.TestDataType.Foreign("externalnotificationevent")),
+ ("method_id", tester.TestDataType.Foreign("externalnotificationmethod")),
+ ("title", tester.TestDataType.String),
+ ("config_json", tester.TestDataType.JSON),
+ ("event_config_json", tester.TestDataType.JSON),
+ ],
+ )
- tester.populate_table('repositorypermission', [
- ('team_id', tester.TestDataType.Foreign('team')),
- ('user_id', tester.TestDataType.Foreign('user')),
- ('repository_id', tester.TestDataType.Foreign('repository')),
- ('role_id', tester.TestDataType.Foreign('role')),
- ])
+ tester.populate_table(
+ "repositorypermission",
+ [
+ ("team_id", tester.TestDataType.Foreign("team")),
+ ("user_id", tester.TestDataType.Foreign("user")),
+ ("repository_id", tester.TestDataType.Foreign("repository")),
+ ("role_id", tester.TestDataType.Foreign("role")),
+ ],
+ )
- tester.populate_table('star', [
- ('user_id', tester.TestDataType.Foreign('user')),
- ('repository_id', tester.TestDataType.Foreign('repository')),
- ('created', tester.TestDataType.DateTime),
- ])
+ tester.populate_table(
+ "star",
+ [
+ ("user_id", tester.TestDataType.Foreign("user")),
+ ("repository_id", tester.TestDataType.Foreign("repository")),
+ ("created", tester.TestDataType.DateTime),
+ ],
+ )
- tester.populate_table('teammember', [
- ('user_id', tester.TestDataType.Foreign('user')),
- ('team_id', tester.TestDataType.Foreign('team')),
- ])
+ tester.populate_table(
+ "teammember",
+ [
+ ("user_id", tester.TestDataType.Foreign("user")),
+ ("team_id", tester.TestDataType.Foreign("team")),
+ ],
+ )
- tester.populate_table('teammemberinvite', [
- ('user_id', tester.TestDataType.Foreign('user')),
- ('email', tester.TestDataType.String),
- ('team_id', tester.TestDataType.Foreign('team')),
- ('inviter_id', tester.TestDataType.Foreign('user')),
- ('invite_token', tester.TestDataType.String),
- ])
+ tester.populate_table(
+ "teammemberinvite",
+ [
+ ("user_id", tester.TestDataType.Foreign("user")),
+ ("email", tester.TestDataType.String),
+ ("team_id", tester.TestDataType.Foreign("team")),
+ ("inviter_id", tester.TestDataType.Foreign("user")),
+ ("invite_token", tester.TestDataType.String),
+ ],
+ )
- tester.populate_table('derivedstorageforimage', [
- ('source_image_id', tester.TestDataType.Foreign('image')),
- ('derivative_id', tester.TestDataType.Foreign('imagestorage')),
- ('transformation_id', tester.TestDataType.Foreign('imagestoragetransformation')),
- ('uniqueness_hash', tester.TestDataType.String),
- ])
+ tester.populate_table(
+ "derivedstorageforimage",
+ [
+ ("source_image_id", tester.TestDataType.Foreign("image")),
+ ("derivative_id", tester.TestDataType.Foreign("imagestorage")),
+ (
+ "transformation_id",
+ tester.TestDataType.Foreign("imagestoragetransformation"),
+ ),
+ ("uniqueness_hash", tester.TestDataType.String),
+ ],
+ )
- tester.populate_table('repositorybuildtrigger', [
- ('uuid', tester.TestDataType.UUID),
- ('service_id', tester.TestDataType.Foreign('buildtriggerservice')),
- ('repository_id', tester.TestDataType.Foreign('repository')),
- ('connected_user_id', tester.TestDataType.Foreign('user')),
- ('auth_token', tester.TestDataType.String),
- ('config', tester.TestDataType.JSON),
- ])
+ tester.populate_table(
+ "repositorybuildtrigger",
+ [
+ ("uuid", tester.TestDataType.UUID),
+ ("service_id", tester.TestDataType.Foreign("buildtriggerservice")),
+ ("repository_id", tester.TestDataType.Foreign("repository")),
+ ("connected_user_id", tester.TestDataType.Foreign("user")),
+ ("auth_token", tester.TestDataType.String),
+ ("config", tester.TestDataType.JSON),
+ ],
+ )
- tester.populate_table('repositorytag', [
- ('name', tester.TestDataType.String),
- ('image_id', tester.TestDataType.Foreign('image')),
- ('repository_id', tester.TestDataType.Foreign('repository')),
- ('lifetime_start_ts', tester.TestDataType.Integer),
- ('hidden', tester.TestDataType.Boolean),
- ('reversion', tester.TestDataType.Boolean),
- ])
+ tester.populate_table(
+ "repositorytag",
+ [
+ ("name", tester.TestDataType.String),
+ ("image_id", tester.TestDataType.Foreign("image")),
+ ("repository_id", tester.TestDataType.Foreign("repository")),
+ ("lifetime_start_ts", tester.TestDataType.Integer),
+ ("hidden", tester.TestDataType.Boolean),
+ ("reversion", tester.TestDataType.Boolean),
+ ],
+ )
- tester.populate_table('repositorybuild', [
- ('uuid', tester.TestDataType.UUID),
- ('phase', tester.TestDataType.String),
- ('repository_id', tester.TestDataType.Foreign('repository')),
- ('access_token_id', tester.TestDataType.Foreign('accesstoken')),
- ('resource_key', tester.TestDataType.String),
- ('job_config', tester.TestDataType.JSON),
- ('started', tester.TestDataType.DateTime),
- ('display_name', tester.TestDataType.JSON),
- ('trigger_id', tester.TestDataType.Foreign('repositorybuildtrigger')),
- ('logs_archived', tester.TestDataType.Boolean),
- ])
+ tester.populate_table(
+ "repositorybuild",
+ [
+ ("uuid", tester.TestDataType.UUID),
+ ("phase", tester.TestDataType.String),
+ ("repository_id", tester.TestDataType.Foreign("repository")),
+ ("access_token_id", tester.TestDataType.Foreign("accesstoken")),
+ ("resource_key", tester.TestDataType.String),
+ ("job_config", tester.TestDataType.JSON),
+ ("started", tester.TestDataType.DateTime),
+ ("display_name", tester.TestDataType.JSON),
+ ("trigger_id", tester.TestDataType.Foreign("repositorybuildtrigger")),
+ ("logs_archived", tester.TestDataType.Boolean),
+ ],
+ )
- tester.populate_table('tagmanifest', [
- ('tag_id', tester.TestDataType.Foreign('repositorytag')),
- ('digest', tester.TestDataType.String),
- ('json_data', tester.TestDataType.JSON),
- ])
+ tester.populate_table(
+ "tagmanifest",
+ [
+ ("tag_id", tester.TestDataType.Foreign("repositorytag")),
+ ("digest", tester.TestDataType.String),
+ ("json_data", tester.TestDataType.JSON),
+ ],
+ )
- tester.populate_table('tagmanifestlabel', [
- ('repository_id', tester.TestDataType.Foreign('repository')),
- ('annotated_id', tester.TestDataType.Foreign('tagmanifest')),
- ('label_id', tester.TestDataType.Foreign('label')),
- ])
+ tester.populate_table(
+ "tagmanifestlabel",
+ [
+ ("repository_id", tester.TestDataType.Foreign("repository")),
+ ("annotated_id", tester.TestDataType.Foreign("tagmanifest")),
+ ("label_id", tester.TestDataType.Foreign("label")),
+ ],
+ )
# ### end population of test data ### #
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
- op.drop_table('tagmanifestlabel')
- op.drop_table('tagmanifest')
- op.drop_table('repositorybuild')
- op.drop_table('repositorytag')
- op.drop_table('repositorybuildtrigger')
- op.drop_table('derivedstorageforimage')
- op.drop_table('teammemberinvite')
- op.drop_table('teammember')
- op.drop_table('star')
- op.drop_table('repositorypermission')
- op.drop_table('repositorynotification')
- op.drop_table('repositoryauthorizedemail')
- op.drop_table('repositoryactioncount')
- op.drop_table('permissionprototype')
- op.drop_table('oauthauthorizationcode')
- op.drop_table('oauthaccesstoken')
- op.drop_table('image')
- op.drop_table('blobupload')
- op.drop_table('accesstoken')
- op.drop_table('userregion')
- op.drop_table('torrentinfo')
- op.drop_table('team')
- op.drop_table('servicekey')
- op.drop_table('repository')
- op.drop_table('quayrelease')
- op.drop_table('oauthapplication')
- op.drop_table('notification')
- op.drop_table('logentry')
- op.drop_table('label')
- op.drop_table('imagestoragesignature')
- op.drop_table('imagestorageplacement')
- op.drop_table('federatedlogin')
- op.drop_table('emailconfirmation')
- op.drop_table('visibility')
- op.drop_table('user')
- op.drop_table('teamrole')
- op.drop_table('servicekeyapproval')
- op.drop_table('role')
- op.drop_table('queueitem')
- op.drop_table('quayservice')
- op.drop_table('quayregion')
- op.drop_table('notificationkind')
- op.drop_table('messages')
- op.drop_table('mediatype')
- op.drop_table('loginservice')
- op.drop_table('logentrykind')
- op.drop_table('labelsourcetype')
- op.drop_table('imagestoragetransformation')
- op.drop_table('imagestoragesignaturekind')
- op.drop_table('imagestoragelocation')
- op.drop_table('imagestorage')
- op.drop_table('externalnotificationmethod')
- op.drop_table('externalnotificationevent')
- op.drop_table('buildtriggerservice')
- op.drop_table('accesstokenkind')
+ op.drop_table("tagmanifestlabel")
+ op.drop_table("tagmanifest")
+ op.drop_table("repositorybuild")
+ op.drop_table("repositorytag")
+ op.drop_table("repositorybuildtrigger")
+ op.drop_table("derivedstorageforimage")
+ op.drop_table("teammemberinvite")
+ op.drop_table("teammember")
+ op.drop_table("star")
+ op.drop_table("repositorypermission")
+ op.drop_table("repositorynotification")
+ op.drop_table("repositoryauthorizedemail")
+ op.drop_table("repositoryactioncount")
+ op.drop_table("permissionprototype")
+ op.drop_table("oauthauthorizationcode")
+ op.drop_table("oauthaccesstoken")
+ op.drop_table("image")
+ op.drop_table("blobupload")
+ op.drop_table("accesstoken")
+ op.drop_table("userregion")
+ op.drop_table("torrentinfo")
+ op.drop_table("team")
+ op.drop_table("servicekey")
+ op.drop_table("repository")
+ op.drop_table("quayrelease")
+ op.drop_table("oauthapplication")
+ op.drop_table("notification")
+ op.drop_table("logentry")
+ op.drop_table("label")
+ op.drop_table("imagestoragesignature")
+ op.drop_table("imagestorageplacement")
+ op.drop_table("federatedlogin")
+ op.drop_table("emailconfirmation")
+ op.drop_table("visibility")
+ op.drop_table("user")
+ op.drop_table("teamrole")
+ op.drop_table("servicekeyapproval")
+ op.drop_table("role")
+ op.drop_table("queueitem")
+ op.drop_table("quayservice")
+ op.drop_table("quayregion")
+ op.drop_table("notificationkind")
+ op.drop_table("messages")
+ op.drop_table("mediatype")
+ op.drop_table("loginservice")
+ op.drop_table("logentrykind")
+ op.drop_table("labelsourcetype")
+ op.drop_table("imagestoragetransformation")
+ op.drop_table("imagestoragesignaturekind")
+ op.drop_table("imagestoragelocation")
+ op.drop_table("imagestorage")
+ op.drop_table("externalnotificationmethod")
+ op.drop_table("externalnotificationevent")
+ op.drop_table("buildtriggerservice")
+ op.drop_table("accesstokenkind")
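Note on the pattern above: the upgrade seeds the enum-style lookup tables (role, visibility, loginservice, logentrykind, and so on) with op.bulk_insert, and the downgrade drops tables child-first so rows holding foreign keys disappear before the tables they reference. A minimal standalone sketch of the same seed/unseed pattern, using a hand-built lightweight table and plain Alembic op rather than Quay's ProgressWrapper and pre-bound tables namespace (illustrative only, not code from this patch):

# Sketch: seed and unseed a simple name-lookup table in an Alembic migration.
import sqlalchemy as sa
from sqlalchemy.sql import table, column
from alembic import op

# Lightweight table construct; only the columns bulk_insert needs.
role_table = table(
    "role",
    column("id", sa.Integer()),
    column("name", sa.String(length=255)),
)


def upgrade():
    # Insert the fixed set of role names.
    op.bulk_insert(
        role_table,
        [{"name": "admin"}, {"name": "write"}, {"name": "read"}],
    )


def downgrade():
    # Remove exactly the rows seeded above.
    op.execute(
        role_table.delete().where(
            role_table.c.name.in_(["admin", "write", "read"])
        )
    )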
diff --git a/data/migrations/versions/c3d4b7ebcdf7_backfill_repositorysearchscore_table.py b/data/migrations/versions/c3d4b7ebcdf7_backfill_repositorysearchscore_table.py
index 8e0a8ab8c..a2f61ae44 100644
--- a/data/migrations/versions/c3d4b7ebcdf7_backfill_repositorysearchscore_table.py
+++ b/data/migrations/versions/c3d4b7ebcdf7_backfill_repositorysearchscore_table.py
@@ -7,19 +7,23 @@ Create Date: 2017-04-13 12:01:59.572775
"""
# revision identifiers, used by Alembic.
-revision = 'c3d4b7ebcdf7'
-down_revision = 'f30984525c86'
+revision = "c3d4b7ebcdf7"
+down_revision = "f30984525c86"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
import sqlalchemy as sa
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# Add a 0 entry into the RepositorySearchScore table for each repository that isn't present
conn = op.get_bind()
- conn.execute("insert into repositorysearchscore (repository_id, score) SELECT id, 0 FROM " +
- "repository WHERE id not in (select repository_id from repositorysearchscore)")
+ conn.execute(
+ "insert into repositorysearchscore (repository_id, score) SELECT id, 0 FROM "
+ + "repository WHERE id not in (select repository_id from repositorysearchscore)"
+ )
+
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
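Note on the backfill above: the INSERT ... SELECT is issued as a raw SQL string on the migration's connection. For comparison, the same statement expressed with SQLAlchemy Core would look roughly like the sketch below (assuming SQLAlchemy 1.4+ and lightweight table constructs; the migration itself keeps the raw string):

# Sketch: Core form of the repositorysearchscore backfill.
import sqlalchemy as sa
from sqlalchemy.sql import table, column

repository = table("repository", column("id", sa.Integer()))
repositorysearchscore = table(
    "repositorysearchscore",
    column("repository_id", sa.Integer()),
    column("score", sa.Integer()),
)

# SELECT id, 0 FROM repository
#   WHERE id NOT IN (SELECT repository_id FROM repositorysearchscore)
missing = sa.select(repository.c.id, sa.literal(0)).where(
    ~repository.c.id.in_(sa.select(repositorysearchscore.c.repository_id))
)

# INSERT INTO repositorysearchscore (repository_id, score) ... the SELECT above
stmt = repositorysearchscore.insert().from_select(["repository_id", "score"], missing)
# conn.execute(stmt)  # run on op.get_bind(), as the migration does with the raw string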
diff --git a/data/migrations/versions/c91c564aad34_drop_checksum_on_imagestorage.py b/data/migrations/versions/c91c564aad34_drop_checksum_on_imagestorage.py
index dc1567bd5..0497e1559 100644
--- a/data/migrations/versions/c91c564aad34_drop_checksum_on_imagestorage.py
+++ b/data/migrations/versions/c91c564aad34_drop_checksum_on_imagestorage.py
@@ -7,8 +7,8 @@ Create Date: 2018-02-21 12:17:52.405644
"""
# revision identifiers, used by Alembic.
-revision = 'c91c564aad34'
-down_revision = '152bb29a1bb3'
+revision = "c91c564aad34"
+down_revision = "152bb29a1bb3"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
@@ -17,9 +17,11 @@ import sqlalchemy as sa
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
- op.drop_column('imagestorage', 'checksum')
+ op.drop_column("imagestorage", "checksum")
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
- op.add_column('imagestorage', sa.Column('checksum', sa.String(length=255), nullable=True))
+ op.add_column(
+ "imagestorage", sa.Column("checksum", sa.String(length=255), nullable=True)
+ )
diff --git a/data/migrations/versions/cbc8177760d9_add_user_location_field.py b/data/migrations/versions/cbc8177760d9_add_user_location_field.py
index cbdc87706..866aafd38 100644
--- a/data/migrations/versions/cbc8177760d9_add_user_location_field.py
+++ b/data/migrations/versions/cbc8177760d9_add_user_location_field.py
@@ -7,8 +7,8 @@ Create Date: 2018-02-02 17:39:16.589623
"""
# revision identifiers, used by Alembic.
-revision = 'cbc8177760d9'
-down_revision = '7367229b38d9'
+revision = "cbc8177760d9"
+down_revision = "7367229b38d9"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
@@ -16,15 +16,18 @@ import sqlalchemy as sa
from sqlalchemy.dialects import mysql
from util.migrate import UTF8CharField
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
- op.add_column('user', sa.Column('location', UTF8CharField(length=255), nullable=True))
+ op.add_column(
+ "user", sa.Column("location", UTF8CharField(length=255), nullable=True)
+ )
# ### population of test data ### #
- tester.populate_column('user', 'location', tester.TestDataType.UTF8Char)
+ tester.populate_column("user", "location", tester.TestDataType.UTF8Char)
# ### end population of test data ### #
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
- op.drop_column('user', 'location')
+ op.drop_column("user", "location")
diff --git a/data/migrations/versions/cc6778199cdb_repository_mirror_notification.py b/data/migrations/versions/cc6778199cdb_repository_mirror_notification.py
index a44704eec..8b64e1f65 100644
--- a/data/migrations/versions/cc6778199cdb_repository_mirror_notification.py
+++ b/data/migrations/versions/cc6778199cdb_repository_mirror_notification.py
@@ -7,62 +7,73 @@ Create Date: 2019-10-03 17:41:23.316914
"""
# revision identifiers, used by Alembic.
-revision = 'cc6778199cdb'
-down_revision = 'c059b952ed76'
+revision = "cc6778199cdb"
+down_revision = "c059b952ed76"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
-def upgrade(tables, tester, progress_reporter):
- op = ProgressWrapper(original_op, progress_reporter)
- op.bulk_insert(tables.notificationkind,
- [
- {'name': 'repo_mirror_sync_started'},
- {'name': 'repo_mirror_sync_success'},
- {'name': 'repo_mirror_sync_failed'},
- ])
- op.bulk_insert(tables.externalnotificationevent,
- [
- {'name': 'repo_mirror_sync_started'},
- {'name': 'repo_mirror_sync_success'},
- {'name': 'repo_mirror_sync_failed'},
- ])
+def upgrade(tables, tester, progress_reporter):
+ op = ProgressWrapper(original_op, progress_reporter)
+
+ op.bulk_insert(
+ tables.notificationkind,
+ [
+ {"name": "repo_mirror_sync_started"},
+ {"name": "repo_mirror_sync_success"},
+ {"name": "repo_mirror_sync_failed"},
+ ],
+ )
+ op.bulk_insert(
+ tables.externalnotificationevent,
+ [
+ {"name": "repo_mirror_sync_started"},
+ {"name": "repo_mirror_sync_success"},
+ {"name": "repo_mirror_sync_failed"},
+ ],
+ )
def downgrade(tables, tester, progress_reporter):
- op = ProgressWrapper(original_op, progress_reporter)
+ op = ProgressWrapper(original_op, progress_reporter)
- op.execute(tables
- .notificationkind
- .delete()
- .where(tables.
- notificationkind.c.name == op.inline_literal('repo_mirror_sync_started')))
- op.execute(tables
- .notificationkind
- .delete()
- .where(tables.
- notificationkind.c.name == op.inline_literal('repo_mirror_sync_success')))
- op.execute(tables
- .notificationkind
- .delete()
- .where(tables.
- notificationkind.c.name == op.inline_literal('repo_mirror_sync_failed')))
+ op.execute(
+ tables.notificationkind.delete().where(
+ tables.notificationkind.c.name
+ == op.inline_literal("repo_mirror_sync_started")
+ )
+ )
+ op.execute(
+ tables.notificationkind.delete().where(
+ tables.notificationkind.c.name
+ == op.inline_literal("repo_mirror_sync_success")
+ )
+ )
+ op.execute(
+ tables.notificationkind.delete().where(
+ tables.notificationkind.c.name
+ == op.inline_literal("repo_mirror_sync_failed")
+ )
+ )
- op.execute(tables
- .externalnotificationevent
- .delete()
- .where(tables.
- externalnotificationevent.c.name == op.inline_literal('repo_mirror_sync_started')))
- op.execute(tables
- .externalnotificationevent
- .delete()
- .where(tables.
- externalnotificationevent.c.name == op.inline_literal('repo_mirror_sync_success')))
- op.execute(tables
- .externalnotificationevent
- .delete()
- .where(tables.
- externalnotificationevent.c.name == op.inline_literal('repo_mirror_sync_failed')))
+ op.execute(
+ tables.externalnotificationevent.delete().where(
+ tables.externalnotificationevent.c.name
+ == op.inline_literal("repo_mirror_sync_started")
+ )
+ )
+ op.execute(
+ tables.externalnotificationevent.delete().where(
+ tables.externalnotificationevent.c.name
+ == op.inline_literal("repo_mirror_sync_success")
+ )
+ )
+ op.execute(
+ tables.externalnotificationevent.delete().where(
+ tables.externalnotificationevent.c.name
+ == op.inline_literal("repo_mirror_sync_failed")
+ )
+ )
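Note on the downgrade above: each seeded kind row is removed with a delete-where, and op.inline_literal renders the name directly into the statement so the delete also works under offline --sql generation. The six statements could equally be factored into a small helper; a rough sketch assuming the same tables object the migration receives (the helper name is illustrative):

# Sketch: factor the repeated delete-where pattern used in the downgrade.
MIRROR_SYNC_KINDS = [
    "repo_mirror_sync_started",
    "repo_mirror_sync_success",
    "repo_mirror_sync_failed",
]


def _remove_kinds(op, kind_table, names):
    # One DELETE per seeded name; inline_literal keeps the value literal in the SQL.
    for name in names:
        op.execute(
            kind_table.delete().where(kind_table.c.name == op.inline_literal(name))
        )

# Usage inside downgrade(tables, tester, progress_reporter):
#     _remove_kinds(op, tables.notificationkind, MIRROR_SYNC_KINDS)
#     _remove_kinds(op, tables.externalnotificationevent, MIRROR_SYNC_KINDS)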
diff --git a/data/migrations/versions/d17c695859ea_delete_old_appr_tables.py b/data/migrations/versions/d17c695859ea_delete_old_appr_tables.py
index 9e847e8e2..df4b56814 100644
--- a/data/migrations/versions/d17c695859ea_delete_old_appr_tables.py
+++ b/data/migrations/versions/d17c695859ea_delete_old_appr_tables.py
@@ -7,8 +7,8 @@ Create Date: 2018-07-16 15:21:11.593040
"""
# revision identifiers, used by Alembic.
-revision = 'd17c695859ea'
-down_revision = '5d463ea1e8a8'
+revision = "d17c695859ea"
+down_revision = "5d463ea1e8a8"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
@@ -16,18 +16,19 @@ import sqlalchemy as sa
from sqlalchemy.sql import table, column
from util.migrate import UTF8LongText, UTF8CharField
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_table('tag')
- op.drop_table('manifestlistmanifest')
- op.drop_table('manifestlist')
- op.drop_table('manifestblob')
- op.drop_table('manifest')
- op.drop_table('blobplacement')
- op.drop_table('blob')
- op.drop_table('blobplacementlocation')
- op.drop_table('tagkind')
+ op.drop_table("tag")
+ op.drop_table("manifestlistmanifest")
+ op.drop_table("manifestlist")
+ op.drop_table("manifestblob")
+ op.drop_table("manifest")
+ op.drop_table("blobplacement")
+ op.drop_table("blob")
+ op.drop_table("blobplacementlocation")
+ op.drop_table("tagkind")
# ### end Alembic commands ###
@@ -35,158 +36,257 @@ def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
- 'tagkind',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_tagkind'))
+ "tagkind",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_tagkind")),
)
- op.create_index('tagkind_name', 'tagkind', ['name'], unique=True)
+ op.create_index("tagkind_name", "tagkind", ["name"], unique=True)
op.create_table(
- 'blobplacementlocation',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', sa.String(length=255), nullable=False),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_blobplacementlocation'))
+ "blobplacementlocation",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", sa.String(length=255), nullable=False),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_blobplacementlocation")),
+ )
+ op.create_index(
+ "blobplacementlocation_name", "blobplacementlocation", ["name"], unique=True
)
- op.create_index('blobplacementlocation_name', 'blobplacementlocation', ['name'], unique=True)
op.create_table(
- 'blob',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('digest', sa.String(length=255), nullable=False),
- sa.Column('media_type_id', sa.Integer(), nullable=False),
- sa.Column('size', sa.BigInteger(), nullable=False),
- sa.Column('uncompressed_size', sa.BigInteger(), nullable=True),
- sa.ForeignKeyConstraint(['media_type_id'], ['mediatype.id'], name=op.f('fk_blob_media_type_id_mediatype')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_blob'))
+ "blob",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("digest", sa.String(length=255), nullable=False),
+ sa.Column("media_type_id", sa.Integer(), nullable=False),
+ sa.Column("size", sa.BigInteger(), nullable=False),
+ sa.Column("uncompressed_size", sa.BigInteger(), nullable=True),
+ sa.ForeignKeyConstraint(
+ ["media_type_id"],
+ ["mediatype.id"],
+ name=op.f("fk_blob_media_type_id_mediatype"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_blob")),
)
- op.create_index('blob_digest', 'blob', ['digest'], unique=True)
- op.create_index('blob_media_type_id', 'blob', ['media_type_id'], unique=False)
+ op.create_index("blob_digest", "blob", ["digest"], unique=True)
+ op.create_index("blob_media_type_id", "blob", ["media_type_id"], unique=False)
op.create_table(
- 'manifest',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('digest', sa.String(length=255), nullable=False),
- sa.Column('media_type_id', sa.Integer(), nullable=False),
- sa.Column('manifest_json', UTF8LongText, nullable=False),
- sa.ForeignKeyConstraint(['media_type_id'], ['mediatype.id'], name=op.f('fk_manifest_media_type_id_mediatype')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_manifest'))
+ "manifest",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("digest", sa.String(length=255), nullable=False),
+ sa.Column("media_type_id", sa.Integer(), nullable=False),
+ sa.Column("manifest_json", UTF8LongText, nullable=False),
+ sa.ForeignKeyConstraint(
+ ["media_type_id"],
+ ["mediatype.id"],
+ name=op.f("fk_manifest_media_type_id_mediatype"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_manifest")),
+ )
+ op.create_index("manifest_digest", "manifest", ["digest"], unique=True)
+ op.create_index(
+ "manifest_media_type_id", "manifest", ["media_type_id"], unique=False
)
- op.create_index('manifest_digest', 'manifest', ['digest'], unique=True)
- op.create_index('manifest_media_type_id', 'manifest', ['media_type_id'], unique=False)
op.create_table(
- 'manifestlist',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('digest', sa.String(length=255), nullable=False),
- sa.Column('manifest_list_json', UTF8LongText, nullable=False),
- sa.Column('schema_version', UTF8CharField(length=255), nullable=False),
- sa.Column('media_type_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['media_type_id'], ['mediatype.id'], name=op.f('fk_manifestlist_media_type_id_mediatype')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_manifestlist'))
+ "manifestlist",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("digest", sa.String(length=255), nullable=False),
+ sa.Column("manifest_list_json", UTF8LongText, nullable=False),
+ sa.Column("schema_version", UTF8CharField(length=255), nullable=False),
+ sa.Column("media_type_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["media_type_id"],
+ ["mediatype.id"],
+ name=op.f("fk_manifestlist_media_type_id_mediatype"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_manifestlist")),
+ )
+ op.create_index("manifestlist_digest", "manifestlist", ["digest"], unique=True)
+ op.create_index(
+ "manifestlist_media_type_id", "manifestlist", ["media_type_id"], unique=False
)
- op.create_index('manifestlist_digest', 'manifestlist', ['digest'], unique=True)
- op.create_index('manifestlist_media_type_id', 'manifestlist', ['media_type_id'], unique=False)
op.create_table(
- 'blobplacement',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('blob_id', sa.Integer(), nullable=False),
- sa.Column('location_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['blob_id'], ['blob.id'], name=op.f('fk_blobplacement_blob_id_blob')),
- sa.ForeignKeyConstraint(['location_id'], ['blobplacementlocation.id'], name=op.f('fk_blobplacement_location_id_blobplacementlocation')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_blobplacement'))
+ "blobplacement",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("blob_id", sa.Integer(), nullable=False),
+ sa.Column("location_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["blob_id"], ["blob.id"], name=op.f("fk_blobplacement_blob_id_blob")
+ ),
+ sa.ForeignKeyConstraint(
+ ["location_id"],
+ ["blobplacementlocation.id"],
+ name=op.f("fk_blobplacement_location_id_blobplacementlocation"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_blobplacement")),
+ )
+ op.create_index("blobplacement_blob_id", "blobplacement", ["blob_id"], unique=False)
+ op.create_index(
+ "blobplacement_blob_id_location_id",
+ "blobplacement",
+ ["blob_id", "location_id"],
+ unique=True,
+ )
+ op.create_index(
+ "blobplacement_location_id", "blobplacement", ["location_id"], unique=False
)
- op.create_index('blobplacement_blob_id', 'blobplacement', ['blob_id'], unique=False)
- op.create_index('blobplacement_blob_id_location_id', 'blobplacement', ['blob_id', 'location_id'], unique=True)
- op.create_index('blobplacement_location_id', 'blobplacement', ['location_id'], unique=False)
op.create_table(
- 'manifestblob',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('manifest_id', sa.Integer(), nullable=False),
- sa.Column('blob_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['blob_id'], ['blob.id'], name=op.f('fk_manifestblob_blob_id_blob')),
- sa.ForeignKeyConstraint(['manifest_id'], ['manifest.id'], name=op.f('fk_manifestblob_manifest_id_manifest')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_manifestblob'))
+ "manifestblob",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("manifest_id", sa.Integer(), nullable=False),
+ sa.Column("blob_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["blob_id"], ["blob.id"], name=op.f("fk_manifestblob_blob_id_blob")
+ ),
+ sa.ForeignKeyConstraint(
+ ["manifest_id"],
+ ["manifest.id"],
+ name=op.f("fk_manifestblob_manifest_id_manifest"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_manifestblob")),
+ )
+ op.create_index("manifestblob_blob_id", "manifestblob", ["blob_id"], unique=False)
+ op.create_index(
+ "manifestblob_manifest_id", "manifestblob", ["manifest_id"], unique=False
+ )
+ op.create_index(
+ "manifestblob_manifest_id_blob_id",
+ "manifestblob",
+ ["manifest_id", "blob_id"],
+ unique=True,
)
- op.create_index('manifestblob_blob_id', 'manifestblob', ['blob_id'], unique=False)
- op.create_index('manifestblob_manifest_id', 'manifestblob', ['manifest_id'], unique=False)
- op.create_index('manifestblob_manifest_id_blob_id', 'manifestblob', ['manifest_id', 'blob_id'], unique=True)
op.create_table(
- 'manifestlistmanifest',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('manifest_list_id', sa.Integer(), nullable=False),
- sa.Column('manifest_id', sa.Integer(), nullable=False),
- sa.Column('operating_system', UTF8CharField(length=255), nullable=True),
- sa.Column('architecture', UTF8CharField(length=255), nullable=True),
- sa.Column('platform_json', UTF8LongText, nullable=True),
- sa.Column('media_type_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['manifest_id'], ['manifest.id'], name=op.f('fk_manifestlistmanifest_manifest_id_manifest')),
- sa.ForeignKeyConstraint(['manifest_list_id'], ['manifestlist.id'], name=op.f('fk_manifestlistmanifest_manifest_list_id_manifestlist')),
- sa.ForeignKeyConstraint(['media_type_id'], ['mediatype.id'], name=op.f('fk_manifestlistmanifest_media_type_id_mediatype')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_manifestlistmanifest'))
+ "manifestlistmanifest",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("manifest_list_id", sa.Integer(), nullable=False),
+ sa.Column("manifest_id", sa.Integer(), nullable=False),
+ sa.Column("operating_system", UTF8CharField(length=255), nullable=True),
+ sa.Column("architecture", UTF8CharField(length=255), nullable=True),
+ sa.Column("platform_json", UTF8LongText, nullable=True),
+ sa.Column("media_type_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["manifest_id"],
+ ["manifest.id"],
+ name=op.f("fk_manifestlistmanifest_manifest_id_manifest"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["manifest_list_id"],
+ ["manifestlist.id"],
+ name=op.f("fk_manifestlistmanifest_manifest_list_id_manifestlist"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["media_type_id"],
+ ["mediatype.id"],
+ name=op.f("fk_manifestlistmanifest_media_type_id_mediatype"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_manifestlistmanifest")),
+ )
+ op.create_index(
+ "manifestlistmanifest_manifest_id",
+ "manifestlistmanifest",
+ ["manifest_id"],
+ unique=False,
+ )
+ op.create_index(
+ "manifestlistmanifest_manifest_list_id",
+ "manifestlistmanifest",
+ ["manifest_list_id"],
+ unique=False,
+ )
+ op.create_index(
+ "manifestlistmanifest_manifest_listid_os_arch_mtid",
+ "manifestlistmanifest",
+ ["manifest_list_id", "operating_system", "architecture", "media_type_id"],
+ unique=False,
+ )
+ op.create_index(
+ "manifestlistmanifest_manifest_listid_mtid",
+ "manifestlistmanifest",
+ ["manifest_list_id", "media_type_id"],
+ unique=False,
+ )
+ op.create_index(
+ "manifestlistmanifest_media_type_id",
+ "manifestlistmanifest",
+ ["media_type_id"],
+ unique=False,
)
- op.create_index('manifestlistmanifest_manifest_id', 'manifestlistmanifest', ['manifest_id'], unique=False)
- op.create_index('manifestlistmanifest_manifest_list_id', 'manifestlistmanifest', ['manifest_list_id'], unique=False)
- op.create_index('manifestlistmanifest_manifest_listid_os_arch_mtid', 'manifestlistmanifest', ['manifest_list_id', 'operating_system', 'architecture', 'media_type_id'], unique=False)
- op.create_index('manifestlistmanifest_manifest_listid_mtid', 'manifestlistmanifest', ['manifest_list_id', 'media_type_id'], unique=False)
- op.create_index('manifestlistmanifest_media_type_id', 'manifestlistmanifest', ['media_type_id'], unique=False)
op.create_table(
- 'tag',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('name', UTF8CharField(length=190), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('manifest_list_id', sa.Integer(), nullable=True),
- sa.Column('lifetime_start', sa.BigInteger(), nullable=False),
- sa.Column('lifetime_end', sa.BigInteger(), nullable=True),
- sa.Column('hidden', sa.Boolean(), nullable=False),
- sa.Column('reverted', sa.Boolean(), nullable=False),
- sa.Column('protected', sa.Boolean(), nullable=False),
- sa.Column('tag_kind_id', sa.Integer(), nullable=False),
- sa.Column('linked_tag_id', sa.Integer(), nullable=True),
- sa.ForeignKeyConstraint(['linked_tag_id'], ['tag.id'], name=op.f('fk_tag_linked_tag_id_tag')),
- sa.ForeignKeyConstraint(['manifest_list_id'], ['manifestlist.id'], name=op.f('fk_tag_manifest_list_id_manifestlist')),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_tag_repository_id_repository')),
- sa.ForeignKeyConstraint(['tag_kind_id'], ['tagkind.id'], name=op.f('fk_tag_tag_kind_id_tagkind')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_tag'))
+ "tag",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("name", UTF8CharField(length=190), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("manifest_list_id", sa.Integer(), nullable=True),
+ sa.Column("lifetime_start", sa.BigInteger(), nullable=False),
+ sa.Column("lifetime_end", sa.BigInteger(), nullable=True),
+ sa.Column("hidden", sa.Boolean(), nullable=False),
+ sa.Column("reverted", sa.Boolean(), nullable=False),
+ sa.Column("protected", sa.Boolean(), nullable=False),
+ sa.Column("tag_kind_id", sa.Integer(), nullable=False),
+ sa.Column("linked_tag_id", sa.Integer(), nullable=True),
+ sa.ForeignKeyConstraint(
+ ["linked_tag_id"], ["tag.id"], name=op.f("fk_tag_linked_tag_id_tag")
+ ),
+ sa.ForeignKeyConstraint(
+ ["manifest_list_id"],
+ ["manifestlist.id"],
+ name=op.f("fk_tag_manifest_list_id_manifestlist"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_tag_repository_id_repository"),
+ ),
+ sa.ForeignKeyConstraint(
+ ["tag_kind_id"], ["tagkind.id"], name=op.f("fk_tag_tag_kind_id_tagkind")
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_tag")),
)
- op.create_index('tag_lifetime_end', 'tag', ['lifetime_end'], unique=False)
- op.create_index('tag_linked_tag_id', 'tag', ['linked_tag_id'], unique=False)
- op.create_index('tag_manifest_list_id', 'tag', ['manifest_list_id'], unique=False)
- op.create_index('tag_repository_id', 'tag', ['repository_id'], unique=False)
- op.create_index('tag_repository_id_name_hidden', 'tag', ['repository_id', 'name', 'hidden'], unique=False)
- op.create_index('tag_repository_id_name_lifetime_end', 'tag', ['repository_id', 'name', 'lifetime_end'], unique=True)
- op.create_index('tag_repository_id_name', 'tag', ['repository_id', 'name'], unique=False)
- op.create_index('tag_tag_kind_id', 'tag', ['tag_kind_id'], unique=False)
+ op.create_index("tag_lifetime_end", "tag", ["lifetime_end"], unique=False)
+ op.create_index("tag_linked_tag_id", "tag", ["linked_tag_id"], unique=False)
+ op.create_index("tag_manifest_list_id", "tag", ["manifest_list_id"], unique=False)
+ op.create_index("tag_repository_id", "tag", ["repository_id"], unique=False)
+ op.create_index(
+ "tag_repository_id_name_hidden",
+ "tag",
+ ["repository_id", "name", "hidden"],
+ unique=False,
+ )
+ op.create_index(
+ "tag_repository_id_name_lifetime_end",
+ "tag",
+ ["repository_id", "name", "lifetime_end"],
+ unique=True,
+ )
+ op.create_index(
+ "tag_repository_id_name", "tag", ["repository_id", "name"], unique=False
+ )
+ op.create_index("tag_tag_kind_id", "tag", ["tag_kind_id"], unique=False)
# ### end Alembic commands ###
- blobplacementlocation_table = table('blobplacementlocation',
- column('id', sa.Integer()),
- column('name', sa.String()),
+ blobplacementlocation_table = table(
+ "blobplacementlocation", column("id", sa.Integer()), column("name", sa.String())
)
op.bulk_insert(
- blobplacementlocation_table,
- [
- {'name': 'local_eu'},
- {'name': 'local_us'},
- ],
+ blobplacementlocation_table, [{"name": "local_eu"}, {"name": "local_us"}]
)
- tagkind_table = table('tagkind',
- column('id', sa.Integer()),
- column('name', sa.String()),
+ tagkind_table = table(
+ "tagkind", column("id", sa.Integer()), column("name", sa.String())
)
op.bulk_insert(
tagkind_table,
[
- {'id': 1, 'name': 'tag'},
- {'id': 2, 'name': 'release'},
- {'id': 3, 'name': 'channel'},
- ]
- )
\ No newline at end of file
+ {"id": 1, "name": "tag"},
+ {"id": 2, "name": "release"},
+ {"id": 3, "name": "channel"},
+ ],
+ )
diff --git a/data/migrations/versions/d42c175b439a_backfill_state_id_and_make_it_unique.py b/data/migrations/versions/d42c175b439a_backfill_state_id_and_make_it_unique.py
index 24a65b8a4..a31f1acb3 100644
--- a/data/migrations/versions/d42c175b439a_backfill_state_id_and_make_it_unique.py
+++ b/data/migrations/versions/d42c175b439a_backfill_state_id_and_make_it_unique.py
@@ -7,14 +7,15 @@ Create Date: 2017-01-18 15:11:01.635632
"""
# revision identifiers, used by Alembic.
-revision = 'd42c175b439a'
-down_revision = '3e8cc74a1e7b'
+revision = "d42c175b439a"
+down_revision = "3e8cc74a1e7b"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# Backfill the queueitem table's state_id field with unique values for all entries which are
@@ -23,14 +24,14 @@ def upgrade(tables, tester, progress_reporter):
conn.execute("update queueitem set state_id = id where state_id = ''")
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_index('queueitem_state_id', table_name='queueitem')
- op.create_index('queueitem_state_id', 'queueitem', ['state_id'], unique=True)
+ op.drop_index("queueitem_state_id", table_name="queueitem")
+ op.create_index("queueitem_state_id", "queueitem", ["state_id"], unique=True)
# ### end Alembic commands ###
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_index('queueitem_state_id', table_name='queueitem')
- op.create_index('queueitem_state_id', 'queueitem', ['state_id'], unique=False)
+ op.drop_index("queueitem_state_id", table_name="queueitem")
+ op.create_index("queueitem_state_id", "queueitem", ["state_id"], unique=False)
# ### end Alembic commands ###
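Note on the backfill above: the raw UPDATE copies the primary key into state_id wherever it is still empty, which is what makes recreating queueitem_state_id as a unique index safe. The same statement in SQLAlchemy Core terms, as a rough sketch with a hand-built table construct (the migration itself uses the raw string):

# Sketch: Core form of "update queueitem set state_id = id where state_id = ''".
import sqlalchemy as sa
from sqlalchemy.sql import table, column

queueitem = table(
    "queueitem",
    column("id", sa.Integer()),
    column("state_id", sa.String(length=255)),
)

stmt = (
    queueitem.update()
    .where(queueitem.c.state_id == "")
    .values(state_id=queueitem.c.id)
)
# conn.execute(stmt)  # run on op.get_bind() before the unique index is created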
diff --git a/data/migrations/versions/d8989249f8f6_add_change_tag_expiration_log_type.py b/data/migrations/versions/d8989249f8f6_add_change_tag_expiration_log_type.py
index 42ec883eb..27a1dbc49 100644
--- a/data/migrations/versions/d8989249f8f6_add_change_tag_expiration_log_type.py
+++ b/data/migrations/versions/d8989249f8f6_add_change_tag_expiration_log_type.py
@@ -7,22 +7,22 @@ Create Date: 2017-06-21 21:18:25.948689
"""
# revision identifiers, used by Alembic.
-revision = 'd8989249f8f6'
-down_revision = 'dc4af11a5f90'
+revision = "d8989249f8f6"
+down_revision = "dc4af11a5f90"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
+
def upgrade(tables, tester, progress_reporter):
- op = ProgressWrapper(original_op, progress_reporter)
- op.bulk_insert(tables.logentrykind, [
- {'name': 'change_tag_expiration'},
- ])
+ op = ProgressWrapper(original_op, progress_reporter)
+ op.bulk_insert(tables.logentrykind, [{"name": "change_tag_expiration"}])
def downgrade(tables, tester, progress_reporter):
- op = ProgressWrapper(original_op, progress_reporter)
- op.execute(tables
- .logentrykind
- .delete()
- .where(tables.logentrykind.c.name == op.inline_literal('change_tag_expiration')))
+ op = ProgressWrapper(original_op, progress_reporter)
+ op.execute(
+ tables.logentrykind.delete().where(
+ tables.logentrykind.c.name == op.inline_literal("change_tag_expiration")
+ )
+ )
diff --git a/data/migrations/versions/dc4af11a5f90_add_notification_number_of_failures_.py b/data/migrations/versions/dc4af11a5f90_add_notification_number_of_failures_.py
index dc8512026..5ec6a1ddd 100644
--- a/data/migrations/versions/dc4af11a5f90_add_notification_number_of_failures_.py
+++ b/data/migrations/versions/dc4af11a5f90_add_notification_number_of_failures_.py
@@ -7,8 +7,8 @@ Create Date: 2017-05-16 17:24:02.630365
"""
# revision identifiers, used by Alembic.
-revision = 'dc4af11a5f90'
-down_revision = '53e2ac668296'
+revision = "dc4af11a5f90"
+down_revision = "53e2ac668296"
import sqlalchemy as sa
from alembic import op as original_op
@@ -16,24 +16,27 @@ from data.migrations.progress import ProgressWrapper
def upgrade(tables, tester, progress_reporter):
- op = ProgressWrapper(original_op, progress_reporter)
- op.add_column('repositorynotification', sa.Column('number_of_failures',
- sa.Integer(),
- nullable=False,
- server_default='0'))
- op.bulk_insert(tables.logentrykind, [
- {'name': 'reset_repo_notification'},
- ])
+ op = ProgressWrapper(original_op, progress_reporter)
+ op.add_column(
+ "repositorynotification",
+ sa.Column(
+ "number_of_failures", sa.Integer(), nullable=False, server_default="0"
+ ),
+ )
+ op.bulk_insert(tables.logentrykind, [{"name": "reset_repo_notification"}])
- # ### population of test data ### #
- tester.populate_column('repositorynotification', 'number_of_failures', tester.TestDataType.Integer)
- # ### end population of test data ### #
+ # ### population of test data ### #
+ tester.populate_column(
+ "repositorynotification", "number_of_failures", tester.TestDataType.Integer
+ )
+ # ### end population of test data ### #
def downgrade(tables, tester, progress_reporter):
- op = ProgressWrapper(original_op, progress_reporter)
- op.drop_column('repositorynotification', 'number_of_failures')
- op.execute(tables
- .logentrykind
- .delete()
- .where(tables.logentrykind.c.name == op.inline_literal('reset_repo_notification')))
+ op = ProgressWrapper(original_op, progress_reporter)
+ op.drop_column("repositorynotification", "number_of_failures")
+ op.execute(
+ tables.logentrykind.delete().where(
+ tables.logentrykind.c.name == op.inline_literal("reset_repo_notification")
+ )
+ )
diff --git a/data/migrations/versions/e184af42242d_add_missing_index_on_uuid_fields.py b/data/migrations/versions/e184af42242d_add_missing_index_on_uuid_fields.py
index b4513ce6d..4477260ec 100644
--- a/data/migrations/versions/e184af42242d_add_missing_index_on_uuid_fields.py
+++ b/data/migrations/versions/e184af42242d_add_missing_index_on_uuid_fields.py
@@ -7,25 +7,30 @@ Create Date: 2019-02-14 16:35:47.768086
"""
# revision identifiers, used by Alembic.
-revision = 'e184af42242d'
-down_revision = '6ec8726c0ace'
+revision = "e184af42242d"
+down_revision = "6ec8726c0ace"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.create_index('permissionprototype_uuid', 'permissionprototype', ['uuid'], unique=False)
- op.create_index('repositorybuildtrigger_uuid', 'repositorybuildtrigger', ['uuid'], unique=False)
- op.create_index('user_uuid', 'user', ['uuid'], unique=False)
+ op.create_index(
+ "permissionprototype_uuid", "permissionprototype", ["uuid"], unique=False
+ )
+ op.create_index(
+ "repositorybuildtrigger_uuid", "repositorybuildtrigger", ["uuid"], unique=False
+ )
+ op.create_index("user_uuid", "user", ["uuid"], unique=False)
# ### end Alembic commands ###
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_index('user_uuid', table_name='user')
- op.drop_index('repositorybuildtrigger_uuid', table_name='repositorybuildtrigger')
- op.drop_index('permissionprototype_uuid', table_name='permissionprototype')
+ op.drop_index("user_uuid", table_name="user")
+ op.drop_index("repositorybuildtrigger_uuid", table_name="repositorybuildtrigger")
+ op.drop_index("permissionprototype_uuid", table_name="permissionprototype")
# ### end Alembic commands ###
diff --git a/data/migrations/versions/e2894a3a3c19_add_full_text_search_indexing_for_repo_.py b/data/migrations/versions/e2894a3a3c19_add_full_text_search_indexing_for_repo_.py
index 13ed12ba5..f858e1a33 100644
--- a/data/migrations/versions/e2894a3a3c19_add_full_text_search_indexing_for_repo_.py
+++ b/data/migrations/versions/e2894a3a3c19_add_full_text_search_indexing_for_repo_.py
@@ -7,25 +7,42 @@ Create Date: 2017-01-11 13:55:54.890774
"""
# revision identifiers, used by Alembic.
-revision = 'e2894a3a3c19'
-down_revision = 'd42c175b439a'
+revision = "e2894a3a3c19"
+down_revision = "d42c175b439a"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.create_index('repository_description__fulltext', 'repository', ['description'], unique=False, postgresql_using='gin', postgresql_ops={'description': 'gin_trgm_ops'}, mysql_prefix='FULLTEXT')
- op.create_index('repository_name__fulltext', 'repository', ['name'], unique=False, postgresql_using='gin', postgresql_ops={'name': 'gin_trgm_ops'}, mysql_prefix='FULLTEXT')
+ op.create_index(
+ "repository_description__fulltext",
+ "repository",
+ ["description"],
+ unique=False,
+ postgresql_using="gin",
+ postgresql_ops={"description": "gin_trgm_ops"},
+ mysql_prefix="FULLTEXT",
+ )
+ op.create_index(
+ "repository_name__fulltext",
+ "repository",
+ ["name"],
+ unique=False,
+ postgresql_using="gin",
+ postgresql_ops={"name": "gin_trgm_ops"},
+ mysql_prefix="FULLTEXT",
+ )
# ### end Alembic commands ###
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_index('repository_name__fulltext', table_name='repository')
- op.drop_index('repository_description__fulltext', table_name='repository')
+ op.drop_index("repository_name__fulltext", table_name="repository")
+ op.drop_index("repository_description__fulltext", table_name="repository")
# ### end Alembic commands ###
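The dialect-specific keyword arguments in e2894a3a3c19 above are what let a single create_index call target both backends: PostgreSQL builds a GIN trigram index (which needs the pg_trgm extension) while MySQL builds a FULLTEXT index. The same call shape on a hypothetical item/title column:

from alembic import op


def upgrade():
    op.create_index(
        "item_title__fulltext",
        "item",
        ["title"],
        unique=False,
        postgresql_using="gin",                    # GIN index on PostgreSQL
        postgresql_ops={"title": "gin_trgm_ops"},  # requires the pg_trgm extension
        mysql_prefix="FULLTEXT",                   # FULLTEXT index on MySQL
    )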
diff --git a/data/migrations/versions/eafdeadcebc7_remove_blob_index_from_manifestblob_.py b/data/migrations/versions/eafdeadcebc7_remove_blob_index_from_manifestblob_.py
index e2e69d99f..3ed8c3856 100644
--- a/data/migrations/versions/eafdeadcebc7_remove_blob_index_from_manifestblob_.py
+++ b/data/migrations/versions/eafdeadcebc7_remove_blob_index_from_manifestblob_.py
@@ -7,25 +7,39 @@ Create Date: 2018-08-07 15:57:54.001225
"""
# revision identifiers, used by Alembic.
-revision = 'eafdeadcebc7'
-down_revision = '9093adccc784'
+revision = "eafdeadcebc7"
+down_revision = "9093adccc784"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_index('manifestblob_manifest_id_blob_index', table_name='manifestblob')
- op.drop_column('manifestblob', 'blob_index')
+ op.drop_index("manifestblob_manifest_id_blob_index", table_name="manifestblob")
+ op.drop_column("manifestblob", "blob_index")
# ### end Alembic commands ###
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.add_column('manifestblob', sa.Column('blob_index', mysql.INTEGER(display_width=11), autoincrement=False, nullable=True))
- op.create_index('manifestblob_manifest_id_blob_index', 'manifestblob', ['manifest_id', 'blob_index'], unique=True)
+ op.add_column(
+ "manifestblob",
+ sa.Column(
+ "blob_index",
+ mysql.INTEGER(display_width=11),
+ autoincrement=False,
+ nullable=True,
+ ),
+ )
+ op.create_index(
+ "manifestblob_manifest_id_blob_index",
+ "manifestblob",
+ ["manifest_id", "blob_index"],
+ unique=True,
+ )
# ### end Alembic commands ###
diff --git a/data/migrations/versions/ed01e313d3cb_add_trust_enabled_to_repository.py b/data/migrations/versions/ed01e313d3cb_add_trust_enabled_to_repository.py
index 2a59ee4ec..85d584e0c 100644
--- a/data/migrations/versions/ed01e313d3cb_add_trust_enabled_to_repository.py
+++ b/data/migrations/versions/ed01e313d3cb_add_trust_enabled_to_repository.py
@@ -7,35 +7,42 @@ Create Date: 2017-04-14 17:38:03.319695
"""
# revision identifiers, used by Alembic.
-revision = 'ed01e313d3cb'
-down_revision = 'c3d4b7ebcdf7'
+revision = "ed01e313d3cb"
+down_revision = "c3d4b7ebcdf7"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
import sqlalchemy as sa
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
### commands auto generated by Alembic - please adjust! ###
- op.add_column('repository', sa.Column('trust_enabled', sa.Boolean(), nullable=False, server_default=sa.sql.expression.false()))
+ op.add_column(
+ "repository",
+ sa.Column(
+ "trust_enabled",
+ sa.Boolean(),
+ nullable=False,
+ server_default=sa.sql.expression.false(),
+ ),
+ )
### end Alembic commands ###
- op.bulk_insert(tables.logentrykind, [
- {'name': 'change_repo_trust'},
- ])
+ op.bulk_insert(tables.logentrykind, [{"name": "change_repo_trust"}])
# ### population of test data ### #
- tester.populate_column('repository', 'trust_enabled', tester.TestDataType.Boolean)
+ tester.populate_column("repository", "trust_enabled", tester.TestDataType.Boolean)
# ### end population of test data ### #
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
### commands auto generated by Alembic - please adjust! ###
- op.drop_column('repository', 'trust_enabled')
+ op.drop_column("repository", "trust_enabled")
### end Alembic commands ###
- op.execute(tables
- .logentrykind
- .delete()
- .where(tables.
- logentrykind.name == op.inline_literal('change_repo_trust')))
+ op.execute(
+ tables.logentrykind.delete().where(
+ tables.logentrykind.name == op.inline_literal("change_repo_trust")
+ )
+ )
diff --git a/data/migrations/versions/f30984525c86_add_repositorysearchscore_table.py b/data/migrations/versions/f30984525c86_add_repositorysearchscore_table.py
index f4a0d4045..bdec2639e 100644
--- a/data/migrations/versions/f30984525c86_add_repositorysearchscore_table.py
+++ b/data/migrations/versions/f30984525c86_add_repositorysearchscore_table.py
@@ -7,40 +7,56 @@ Create Date: 2017-04-04 14:30:13.270728
"""
# revision identifiers, used by Alembic.
-revision = 'f30984525c86'
-down_revision = 'be8d1c402ce0'
+revision = "f30984525c86"
+down_revision = "be8d1c402ce0"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
### commands auto generated by Alembic - please adjust! ###
- op.create_table('repositorysearchscore',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('repository_id', sa.Integer(), nullable=False),
- sa.Column('score', sa.BigInteger(), nullable=False),
- sa.Column('last_updated', sa.DateTime(), nullable=True),
- sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], name=op.f('fk_repositorysearchscore_repository_id_repository')),
- sa.PrimaryKeyConstraint('id', name=op.f('pk_repositorysearchscore'))
+ op.create_table(
+ "repositorysearchscore",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("repository_id", sa.Integer(), nullable=False),
+ sa.Column("score", sa.BigInteger(), nullable=False),
+ sa.Column("last_updated", sa.DateTime(), nullable=True),
+ sa.ForeignKeyConstraint(
+ ["repository_id"],
+ ["repository.id"],
+ name=op.f("fk_repositorysearchscore_repository_id_repository"),
+ ),
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_repositorysearchscore")),
+ )
+ op.create_index(
+ "repositorysearchscore_repository_id",
+ "repositorysearchscore",
+ ["repository_id"],
+ unique=True,
+ )
+ op.create_index(
+ "repositorysearchscore_score", "repositorysearchscore", ["score"], unique=False
)
- op.create_index('repositorysearchscore_repository_id', 'repositorysearchscore', ['repository_id'], unique=True)
- op.create_index('repositorysearchscore_score', 'repositorysearchscore', ['score'], unique=False)
### end Alembic commands ###
# ### population of test data ### #
- tester.populate_table('repositorysearchscore', [
- ('repository_id', tester.TestDataType.Foreign('repository')),
- ('score', tester.TestDataType.BigInteger),
- ('last_updated', tester.TestDataType.DateTime),
- ])
+ tester.populate_table(
+ "repositorysearchscore",
+ [
+ ("repository_id", tester.TestDataType.Foreign("repository")),
+ ("score", tester.TestDataType.BigInteger),
+ ("last_updated", tester.TestDataType.DateTime),
+ ],
+ )
# ### end population of test data ### #
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
### commands auto generated by Alembic - please adjust! ###
- op.drop_table('repositorysearchscore')
+ op.drop_table("repositorysearchscore")
### end Alembic commands ###
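The op.f() wrapper on the constraint names in f30984525c86 above marks each name as already final, so Alembic does not run it through the configured naming convention a second time. A condensed sketch of the same create_table shape on a hypothetical widgetscore table:

from alembic import op
import sqlalchemy as sa


def upgrade():
    op.create_table(
        "widgetscore",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("widget_id", sa.Integer(), nullable=False),
        sa.Column("score", sa.BigInteger(), nullable=False),
        sa.ForeignKeyConstraint(
            ["widget_id"], ["widget.id"], name=op.f("fk_widgetscore_widget_id_widget")
        ),
        sa.PrimaryKeyConstraint("id", name=op.f("pk_widgetscore")),
    )
    # One row per widget, so the lookup index is unique; the score index is not.
    op.create_index("widgetscore_widget_id", "widgetscore", ["widget_id"], unique=True)
    op.create_index("widgetscore_score", "widgetscore", ["score"], unique=False)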
diff --git a/data/migrations/versions/f5167870dd66_update_queue_item_table_indices.py b/data/migrations/versions/f5167870dd66_update_queue_item_table_indices.py
index d801764c1..7875f1f0c 100644
--- a/data/migrations/versions/f5167870dd66_update_queue_item_table_indices.py
+++ b/data/migrations/versions/f5167870dd66_update_queue_item_table_indices.py
@@ -7,37 +7,79 @@ Create Date: 2016-12-08 17:26:20.333846
"""
# revision identifiers, used by Alembic.
-revision = 'f5167870dd66'
-down_revision = '45fd8b9869d4'
+revision = "f5167870dd66"
+down_revision = "45fd8b9869d4"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
### commands auto generated by Alembic - please adjust! ###
- op.create_index('queueitem_processing_expires_available', 'queueitem', ['processing_expires', 'available'], unique=False)
- op.create_index('queueitem_pe_aafter_qname_rremaining_available', 'queueitem', ['processing_expires', 'available_after', 'queue_name', 'retries_remaining', 'available'], unique=False)
- op.create_index('queueitem_pexpires_aafter_rremaining_available', 'queueitem', ['processing_expires', 'available_after', 'retries_remaining', 'available'], unique=False)
- op.create_index('queueitem_processing_expires_queue_name_available', 'queueitem', ['processing_expires', 'queue_name', 'available'], unique=False)
- op.drop_index('queueitem_available', table_name='queueitem')
- op.drop_index('queueitem_available_after', table_name='queueitem')
- op.drop_index('queueitem_processing_expires', table_name='queueitem')
- op.drop_index('queueitem_retries_remaining', table_name='queueitem')
+ op.create_index(
+ "queueitem_processing_expires_available",
+ "queueitem",
+ ["processing_expires", "available"],
+ unique=False,
+ )
+ op.create_index(
+ "queueitem_pe_aafter_qname_rremaining_available",
+ "queueitem",
+ [
+ "processing_expires",
+ "available_after",
+ "queue_name",
+ "retries_remaining",
+ "available",
+ ],
+ unique=False,
+ )
+ op.create_index(
+ "queueitem_pexpires_aafter_rremaining_available",
+ "queueitem",
+ ["processing_expires", "available_after", "retries_remaining", "available"],
+ unique=False,
+ )
+ op.create_index(
+ "queueitem_processing_expires_queue_name_available",
+ "queueitem",
+ ["processing_expires", "queue_name", "available"],
+ unique=False,
+ )
+ op.drop_index("queueitem_available", table_name="queueitem")
+ op.drop_index("queueitem_available_after", table_name="queueitem")
+ op.drop_index("queueitem_processing_expires", table_name="queueitem")
+ op.drop_index("queueitem_retries_remaining", table_name="queueitem")
### end Alembic commands ###
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
### commands auto generated by Alembic - please adjust! ###
- op.create_index('queueitem_retries_remaining', 'queueitem', ['retries_remaining'], unique=False)
- op.create_index('queueitem_processing_expires', 'queueitem', ['processing_expires'], unique=False)
- op.create_index('queueitem_available_after', 'queueitem', ['available_after'], unique=False)
- op.create_index('queueitem_available', 'queueitem', ['available'], unique=False)
- op.drop_index('queueitem_processing_expires_queue_name_available', table_name='queueitem')
- op.drop_index('queueitem_pexpires_aafter_rremaining_available', table_name='queueitem')
- op.drop_index('queueitem_pe_aafter_qname_rremaining_available', table_name='queueitem')
- op.drop_index('queueitem_processing_expires_available', table_name='queueitem')
+ op.create_index(
+ "queueitem_retries_remaining", "queueitem", ["retries_remaining"], unique=False
+ )
+ op.create_index(
+ "queueitem_processing_expires",
+ "queueitem",
+ ["processing_expires"],
+ unique=False,
+ )
+ op.create_index(
+ "queueitem_available_after", "queueitem", ["available_after"], unique=False
+ )
+ op.create_index("queueitem_available", "queueitem", ["available"], unique=False)
+ op.drop_index(
+ "queueitem_processing_expires_queue_name_available", table_name="queueitem"
+ )
+ op.drop_index(
+ "queueitem_pexpires_aafter_rremaining_available", table_name="queueitem"
+ )
+ op.drop_index(
+ "queueitem_pe_aafter_qname_rremaining_available", table_name="queueitem"
+ )
+ op.drop_index("queueitem_processing_expires_available", table_name="queueitem")
### end Alembic commands ###
diff --git a/data/migrations/versions/faf752bd2e0a_add_user_metadata_fields.py b/data/migrations/versions/faf752bd2e0a_add_user_metadata_fields.py
index 3e3b9b9a6..95c3e9a56 100644
--- a/data/migrations/versions/faf752bd2e0a_add_user_metadata_fields.py
+++ b/data/migrations/versions/faf752bd2e0a_add_user_metadata_fields.py
@@ -7,8 +7,8 @@ Create Date: 2016-11-14 17:29:03.984665
"""
# revision identifiers, used by Alembic.
-revision = 'faf752bd2e0a'
-down_revision = '6c7014e84a5e'
+revision = "faf752bd2e0a"
+down_revision = "6c7014e84a5e"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
@@ -16,41 +16,52 @@ import sqlalchemy as sa
from util.migrate import UTF8CharField
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
### commands auto generated by Alembic - please adjust! ###
- op.add_column('user', sa.Column('company', UTF8CharField(length=255), nullable=True))
- op.add_column('user', sa.Column('family_name', UTF8CharField(length=255), nullable=True))
- op.add_column('user', sa.Column('given_name', UTF8CharField(length=255), nullable=True))
+ op.add_column(
+ "user", sa.Column("company", UTF8CharField(length=255), nullable=True)
+ )
+ op.add_column(
+ "user", sa.Column("family_name", UTF8CharField(length=255), nullable=True)
+ )
+ op.add_column(
+ "user", sa.Column("given_name", UTF8CharField(length=255), nullable=True)
+ )
### end Alembic commands ###
- op.bulk_insert(tables.userpromptkind,
- [
- {'name':'enter_name'},
- {'name':'enter_company'},
- ])
+ op.bulk_insert(
+ tables.userpromptkind, [{"name": "enter_name"}, {"name": "enter_company"}]
+ )
# ### population of test data ### #
- tester.populate_column('user', 'company', tester.TestDataType.UTF8Char)
- tester.populate_column('user', 'family_name', tester.TestDataType.UTF8Char)
- tester.populate_column('user', 'given_name', tester.TestDataType.UTF8Char)
+ tester.populate_column("user", "company", tester.TestDataType.UTF8Char)
+ tester.populate_column("user", "family_name", tester.TestDataType.UTF8Char)
+ tester.populate_column("user", "given_name", tester.TestDataType.UTF8Char)
# ### end population of test data ### #
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
### commands auto generated by Alembic - please adjust! ###
- op.drop_column('user', 'given_name')
- op.drop_column('user', 'family_name')
- op.drop_column('user', 'company')
+ op.drop_column("user", "given_name")
+ op.drop_column("user", "family_name")
+ op.drop_column("user", "company")
### end Alembic commands ###
op.execute(
- (tables.userpromptkind.delete()
- .where(tables.userpromptkind.c.name == op.inline_literal('enter_name')))
+ (
+ tables.userpromptkind.delete().where(
+ tables.userpromptkind.c.name == op.inline_literal("enter_name")
+ )
)
+ )
op.execute(
- (tables.userpromptkind.delete()
- .where(tables.userpromptkind.c.name == op.inline_literal('enter_company')))
+ (
+ tables.userpromptkind.delete().where(
+ tables.userpromptkind.c.name == op.inline_literal("enter_company")
+ )
)
+ )
diff --git a/data/migrations/versions/fc47c1ec019f_add_state_id_field_to_queueitem.py b/data/migrations/versions/fc47c1ec019f_add_state_id_field_to_queueitem.py
index dd0363ce3..b5defe0a9 100644
--- a/data/migrations/versions/fc47c1ec019f_add_state_id_field_to_queueitem.py
+++ b/data/migrations/versions/fc47c1ec019f_add_state_id_field_to_queueitem.py
@@ -7,29 +7,33 @@ Create Date: 2017-01-12 15:44:23.643016
"""
# revision identifiers, used by Alembic.
-revision = 'fc47c1ec019f'
-down_revision = 'f5167870dd66'
+revision = "fc47c1ec019f"
+down_revision = "f5167870dd66"
from alembic import op as original_op
from data.migrations.progress import ProgressWrapper
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
+
def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.add_column('queueitem', sa.Column('state_id', sa.String(length=255), nullable=False, server_default=''))
- op.create_index('queueitem_state_id', 'queueitem', ['state_id'], unique=False)
+ op.add_column(
+ "queueitem",
+ sa.Column("state_id", sa.String(length=255), nullable=False, server_default=""),
+ )
+ op.create_index("queueitem_state_id", "queueitem", ["state_id"], unique=False)
# ### end Alembic commands ###
# ### population of test data ### #
- tester.populate_column('queueitem', 'state_id', tester.TestDataType.String)
+ tester.populate_column("queueitem", "state_id", tester.TestDataType.String)
# ### end population of test data ### #
def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter)
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_index('queueitem_state_id', table_name='queueitem')
- op.drop_column('queueitem', 'state_id')
+ op.drop_index("queueitem_state_id", table_name="queueitem")
+ op.drop_column("queueitem", "state_id")
# ### end Alembic commands ###
diff --git a/data/migrationutil.py b/data/migrationutil.py
index a433605f5..b68223624 100644
--- a/data/migrationutil.py
+++ b/data/migrationutil.py
@@ -4,66 +4,73 @@ from abc import ABCMeta, abstractmethod, abstractproperty
from collections import namedtuple
from six import add_metaclass
-MigrationPhase = namedtuple('MigrationPhase', ['name', 'alembic_revision', 'flags'])
+MigrationPhase = namedtuple("MigrationPhase", ["name", "alembic_revision", "flags"])
@add_metaclass(ABCMeta)
class DataMigration(object):
- @abstractproperty
- def alembic_migration_revision(self):
- """ Returns the alembic migration revision corresponding to the currently configured phase.
+ @abstractproperty
+ def alembic_migration_revision(self):
+ """ Returns the alembic migration revision corresponding to the currently configured phase.
"""
- @abstractmethod
- def has_flag(self, flag):
- """ Returns true if the data migration's current phase has the given flag set. """
+ @abstractmethod
+ def has_flag(self, flag):
+ """ Returns true if the data migration's current phase has the given flag set. """
class NullDataMigration(DataMigration):
- @property
- def alembic_migration_revision(self):
- return 'head'
+ @property
+ def alembic_migration_revision(self):
+ return "head"
- def has_flag(self, flag):
- raise NotImplementedError()
+ def has_flag(self, flag):
+ raise NotImplementedError()
class DefinedDataMigration(DataMigration):
- def __init__(self, name, env_var, phases):
- assert phases
+ def __init__(self, name, env_var, phases):
+ assert phases
- self.name = name
- self.phases = {phase.name: phase for phase in phases}
+ self.name = name
+ self.phases = {phase.name: phase for phase in phases}
- # Add a synthetic phase for new installations that skips the entire migration.
- self.phases['new-installation'] = phases[-1]._replace(name='new-installation',
- alembic_revision='head')
+ # Add a synthetic phase for new installations that skips the entire migration.
+ self.phases["new-installation"] = phases[-1]._replace(
+ name="new-installation", alembic_revision="head"
+ )
- phase_name = os.getenv(env_var)
- if phase_name is None:
- msg = 'Missing env var `%s` for data migration `%s`. %s' % (env_var, self.name,
- self._error_suffix)
- raise Exception(msg)
+ phase_name = os.getenv(env_var)
+ if phase_name is None:
+ msg = "Missing env var `%s` for data migration `%s`. %s" % (
+ env_var,
+ self.name,
+ self._error_suffix,
+ )
+ raise Exception(msg)
- current_phase = self.phases.get(phase_name)
- if current_phase is None:
- msg = 'Unknown phase `%s` for data migration `%s`. %s' % (phase_name, self.name,
- self._error_suffix)
- raise Exception(msg)
+ current_phase = self.phases.get(phase_name)
+ if current_phase is None:
+ msg = "Unknown phase `%s` for data migration `%s`. %s" % (
+ phase_name,
+ self.name,
+ self._error_suffix,
+ )
+ raise Exception(msg)
- self.current_phase = current_phase
+ self.current_phase = current_phase
- @property
- def _error_suffix(self):
- message = 'Available values for this migration: %s. ' % (self.phases.keys())
- message += 'If this is a new installation, please use `new-installation`.'
- return message
+ @property
+ def _error_suffix(self):
+ message = "Available values for this migration: %s. " % (self.phases.keys())
+ message += "If this is a new installation, please use `new-installation`."
+ return message
- @property
- def alembic_migration_revision(self):
- assert self.current_phase
- return self.current_phase.alembic_revision
+ @property
+ def alembic_migration_revision(self):
+ assert self.current_phase
+ return self.current_phase.alembic_revision
- def has_flag(self, flag):
- assert self.current_phase
- return flag in self.current_phase.flags
+ def has_flag(self, flag):
+ assert self.current_phase
+ return flag in self.current_phase.flags
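DefinedDataMigration, reformatted above, resolves its phase from an environment variable when it is constructed: the variable must name one of the declared phases (or the synthetic new-installation phase, which maps straight to alembic "head"), otherwise construction raises. A small usage sketch; the revision string and flag values are placeholders, and in the real tree the flags are enum members rather than strings:

import os

from data.migrationutil import DefinedDataMigration, MigrationPhase

phases = [
    MigrationPhase("add-new-fields", "abc123def456", flags=["write-old"]),
    MigrationPhase("stop-writing-both", "abc123def456", flags=[]),
]

os.environ["EXAMPLE_MIGRATION_PHASE"] = "add-new-fields"
migration = DefinedDataMigration("example", "EXAMPLE_MIGRATION_PHASE", phases)

assert migration.alembic_migration_revision == "abc123def456"
assert migration.has_flag("write-old")
# Leaving the variable unset, or naming an unknown phase, raises with the list
# of valid values plus the `new-installation` hint.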
diff --git a/data/model/__init__.py b/data/model/__init__.py
index 2c9260469..357474122 100644
--- a/data/model/__init__.py
+++ b/data/model/__init__.py
@@ -2,122 +2,123 @@ from data.database import db, db_transaction
class DataModelException(Exception):
- pass
+ pass
class InvalidLabelKeyException(DataModelException):
- pass
+ pass
class InvalidMediaTypeException(DataModelException):
- pass
+ pass
class BlobDoesNotExist(DataModelException):
- pass
+ pass
class TorrentInfoDoesNotExist(DataModelException):
- pass
+ pass
class InvalidBlobUpload(DataModelException):
- pass
+ pass
class InvalidEmailAddressException(DataModelException):
- pass
+ pass
class InvalidOrganizationException(DataModelException):
- pass
+ pass
class InvalidPasswordException(DataModelException):
- pass
+ pass
class InvalidRobotException(DataModelException):
- pass
+ pass
class InvalidUsernameException(DataModelException):
- pass
+ pass
class InvalidRepositoryBuildException(DataModelException):
- pass
+ pass
class InvalidBuildTriggerException(DataModelException):
- pass
+ pass
class InvalidTokenException(DataModelException):
- pass
+ pass
class InvalidNotificationException(DataModelException):
- pass
+ pass
class InvalidImageException(DataModelException):
- pass
+ pass
class UserAlreadyInTeam(DataModelException):
- pass
+ pass
class InvalidTeamException(DataModelException):
- pass
+ pass
class InvalidTeamMemberException(DataModelException):
- pass
+ pass
class InvalidManifestException(DataModelException):
- pass
+ pass
class ServiceKeyDoesNotExist(DataModelException):
- pass
+ pass
class ServiceKeyAlreadyApproved(DataModelException):
- pass
+ pass
class ServiceNameInvalid(DataModelException):
- pass
+ pass
class TagAlreadyCreatedException(DataModelException):
- pass
+ pass
+
class StaleTagException(DataModelException):
- pass
+ pass
class TooManyLoginAttemptsException(Exception):
- def __init__(self, message, retry_after):
- super(TooManyLoginAttemptsException, self).__init__(message)
- self.retry_after = retry_after
+ def __init__(self, message, retry_after):
+ super(TooManyLoginAttemptsException, self).__init__(message)
+ self.retry_after = retry_after
class Config(object):
- def __init__(self):
- self.app_config = None
- self.store = None
- self.image_cleanup_callbacks = []
- self.repo_cleanup_callbacks = []
+ def __init__(self):
+ self.app_config = None
+ self.store = None
+ self.image_cleanup_callbacks = []
+ self.repo_cleanup_callbacks = []
- def register_image_cleanup_callback(self, callback):
- self.image_cleanup_callbacks.append(callback)
-
- def register_repo_cleanup_callback(self, callback):
- self.repo_cleanup_callbacks.append(callback)
+ def register_image_cleanup_callback(self, callback):
+ self.image_cleanup_callbacks.append(callback)
+
+ def register_repo_cleanup_callback(self, callback):
+ self.repo_cleanup_callbacks.append(callback)
config = Config()
@@ -126,28 +127,28 @@ config = Config()
# There MUST NOT be any circular dependencies between these subsections. If there are fix it by
# moving the minimal number of things to _basequery
from data.model import (
- appspecifictoken,
- blob,
- build,
- gc,
- image,
- label,
- log,
- message,
- modelutil,
- notification,
- oauth,
- organization,
- permission,
- repositoryactioncount,
- repo_mirror,
- release,
- repo_mirror,
- repository,
- service_keys,
- storage,
- tag,
- team,
- token,
- user,
+ appspecifictoken,
+ blob,
+ build,
+ gc,
+ image,
+ label,
+ log,
+ message,
+ modelutil,
+ notification,
+ oauth,
+ organization,
+ permission,
+ repositoryactioncount,
+ repo_mirror,
+ release,
+ repo_mirror,
+ repository,
+ service_keys,
+ storage,
+ tag,
+ team,
+ token,
+ user,
)
diff --git a/data/model/_basequery.py b/data/model/_basequery.py
index 5fc1733e0..c7297c87c 100644
--- a/data/model/_basequery.py
+++ b/data/model/_basequery.py
@@ -7,192 +7,228 @@ from datetime import datetime, timedelta
from data.model import DataModelException, config
from data.readreplica import ReadOnlyModeException
-from data.database import (Repository, User, Team, TeamMember, RepositoryPermission, TeamRole,
- Namespace, Visibility, ImageStorage, Image, RepositoryKind,
- db_for_update)
+from data.database import (
+ Repository,
+ User,
+ Team,
+ TeamMember,
+ RepositoryPermission,
+ TeamRole,
+ Namespace,
+ Visibility,
+ ImageStorage,
+ Image,
+ RepositoryKind,
+ db_for_update,
+)
logger = logging.getLogger(__name__)
+
def reduce_as_tree(queries_to_reduce):
- """ This method will split a list of queries into halves recursively until we reach individual
+ """ This method will split a list of queries into halves recursively until we reach individual
queries, at which point it will start unioning the queries, or the already unioned subqueries.
This works around a bug in peewee SQL generation where reducing linearly generates a chain
of queries that will exceed the recursion depth limit when it has around 80 queries.
"""
- mid = len(queries_to_reduce)/2
- left = queries_to_reduce[:mid]
- right = queries_to_reduce[mid:]
+ mid = len(queries_to_reduce) / 2
+ left = queries_to_reduce[:mid]
+ right = queries_to_reduce[mid:]
- to_reduce_right = right[0]
- if len(right) > 1:
- to_reduce_right = reduce_as_tree(right)
+ to_reduce_right = right[0]
+ if len(right) > 1:
+ to_reduce_right = reduce_as_tree(right)
- if len(left) > 1:
- to_reduce_left = reduce_as_tree(left)
- elif len(left) == 1:
- to_reduce_left = left[0]
- else:
- return to_reduce_right
+ if len(left) > 1:
+ to_reduce_left = reduce_as_tree(left)
+ elif len(left) == 1:
+ to_reduce_left = left[0]
+ else:
+ return to_reduce_right
- return to_reduce_left.union_all(to_reduce_right)
+ return to_reduce_left.union_all(to_reduce_right)
-def get_existing_repository(namespace_name, repository_name, for_update=False, kind_filter=None):
- query = (Repository
- .select(Repository, Namespace)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .where(Namespace.username == namespace_name,
- Repository.name == repository_name))
+def get_existing_repository(
+ namespace_name, repository_name, for_update=False, kind_filter=None
+):
+ query = (
+ Repository.select(Repository, Namespace)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .where(Namespace.username == namespace_name, Repository.name == repository_name)
+ )
- if kind_filter:
- query = (query
- .switch(Repository)
- .join(RepositoryKind)
- .where(RepositoryKind.name == kind_filter))
+ if kind_filter:
+ query = (
+ query.switch(Repository)
+ .join(RepositoryKind)
+ .where(RepositoryKind.name == kind_filter)
+ )
- if for_update:
- query = db_for_update(query)
+ if for_update:
+ query = db_for_update(query)
- return query.get()
+ return query.get()
@lru_cache(maxsize=1)
def get_public_repo_visibility():
- return Visibility.get(name='public')
+ return Visibility.get(name="public")
def _lookup_team_role(name):
- return _lookup_team_roles()[name]
+ return _lookup_team_roles()[name]
@lru_cache(maxsize=1)
def _lookup_team_roles():
- return {role.name:role for role in TeamRole.select()}
+ return {role.name: role for role in TeamRole.select()}
-def filter_to_repos_for_user(query, user_id=None, namespace=None, repo_kind='image',
- include_public=True, start_id=None):
- if not include_public and not user_id:
- return Repository.select().where(Repository.id == '-1')
+def filter_to_repos_for_user(
+ query,
+ user_id=None,
+ namespace=None,
+ repo_kind="image",
+ include_public=True,
+ start_id=None,
+):
+ if not include_public and not user_id:
+ return Repository.select().where(Repository.id == "-1")
- # Filter on the type of repository.
- if repo_kind is not None:
- try:
- query = query.where(Repository.kind == Repository.kind.get_id(repo_kind))
- except RepositoryKind.DoesNotExist:
- raise DataModelException('Unknown repository kind')
+ # Filter on the type of repository.
+ if repo_kind is not None:
+ try:
+ query = query.where(Repository.kind == Repository.kind.get_id(repo_kind))
+ except RepositoryKind.DoesNotExist:
+ raise DataModelException("Unknown repository kind")
- # Add the start ID if necessary.
- if start_id is not None:
- query = query.where(Repository.id >= start_id)
+ # Add the start ID if necessary.
+ if start_id is not None:
+ query = query.where(Repository.id >= start_id)
- # Add a namespace filter if necessary.
- if namespace:
- query = query.where(Namespace.username == namespace)
+ # Add a namespace filter if necessary.
+ if namespace:
+ query = query.where(Namespace.username == namespace)
- # Build a set of queries that, when unioned together, return the full set of visible repositories
- # for the filters specified.
- queries = []
+ # Build a set of queries that, when unioned together, return the full set of visible repositories
+ # for the filters specified.
+ queries = []
- if include_public:
- queries.append(query.where(Repository.visibility == get_public_repo_visibility()))
+ if include_public:
+ queries.append(
+ query.where(Repository.visibility == get_public_repo_visibility())
+ )
- if user_id is not None:
- AdminTeam = Team.alias()
- AdminTeamMember = TeamMember.alias()
+ if user_id is not None:
+ AdminTeam = Team.alias()
+ AdminTeamMember = TeamMember.alias()
- # Add repositories in which the user has permission.
- queries.append(query
- .switch(RepositoryPermission)
- .where(RepositoryPermission.user == user_id))
+ # Add repositories in which the user has permission.
+ queries.append(
+ query.switch(RepositoryPermission).where(
+ RepositoryPermission.user == user_id
+ )
+ )
- # Add repositories in which the user is a member of a team that has permission.
- queries.append(query
- .switch(RepositoryPermission)
- .join(Team)
- .join(TeamMember)
- .where(TeamMember.user == user_id))
+ # Add repositories in which the user is a member of a team that has permission.
+ queries.append(
+ query.switch(RepositoryPermission)
+ .join(Team)
+ .join(TeamMember)
+ .where(TeamMember.user == user_id)
+ )
- # Add repositories under namespaces in which the user is the org admin.
- queries.append(query
- .switch(Repository)
- .join(AdminTeam, on=(Repository.namespace_user == AdminTeam.organization))
- .join(AdminTeamMember, on=(AdminTeam.id == AdminTeamMember.team))
- .where(AdminTeam.role == _lookup_team_role('admin'))
- .where(AdminTeamMember.user == user_id))
+ # Add repositories under namespaces in which the user is the org admin.
+ queries.append(
+ query.switch(Repository)
+ .join(AdminTeam, on=(Repository.namespace_user == AdminTeam.organization))
+ .join(AdminTeamMember, on=(AdminTeam.id == AdminTeamMember.team))
+ .where(AdminTeam.role == _lookup_team_role("admin"))
+ .where(AdminTeamMember.user == user_id)
+ )
- return reduce(lambda l, r: l | r, queries)
+ return reduce(lambda l, r: l | r, queries)
def get_user_organizations(username):
- UserAlias = User.alias()
- return (User
- .select()
- .distinct()
- .join(Team)
- .join(TeamMember)
- .join(UserAlias, on=(UserAlias.id == TeamMember.user))
- .where(User.organization == True, UserAlias.username == username))
+ UserAlias = User.alias()
+ return (
+ User.select()
+ .distinct()
+ .join(Team)
+ .join(TeamMember)
+ .join(UserAlias, on=(UserAlias.id == TeamMember.user))
+ .where(User.organization == True, UserAlias.username == username)
+ )
def calculate_image_aggregate_size(ancestors_str, image_size, parent_image):
- ancestors = ancestors_str.split('/')[1:-1]
- if not ancestors:
- return image_size
+ ancestors = ancestors_str.split("/")[1:-1]
+ if not ancestors:
+ return image_size
- if parent_image is None:
- raise DataModelException('Could not load parent image')
+ if parent_image is None:
+ raise DataModelException("Could not load parent image")
+
+ ancestor_size = parent_image.aggregate_size
+ if ancestor_size is not None:
+ return ancestor_size + image_size
+
+ # Fallback to a slower path if the parent doesn't have an aggregate size saved.
+ # TODO: remove this code if/when we do a full backfill.
+ ancestor_size = (
+ ImageStorage.select(fn.Sum(ImageStorage.image_size))
+ .join(Image)
+ .where(Image.id << ancestors)
+ .scalar()
+ )
+ if ancestor_size is None:
+ return None
- ancestor_size = parent_image.aggregate_size
- if ancestor_size is not None:
return ancestor_size + image_size
- # Fallback to a slower path if the parent doesn't have an aggregate size saved.
- # TODO: remove this code if/when we do a full backfill.
- ancestor_size = (ImageStorage
- .select(fn.Sum(ImageStorage.image_size))
- .join(Image)
- .where(Image.id << ancestors)
- .scalar())
- if ancestor_size is None:
- return None
-
- return ancestor_size + image_size
-
def update_last_accessed(token_or_user):
- """ Updates the `last_accessed` field on the given token or user. If the existing field's value
+ """ Updates the `last_accessed` field on the given token or user. If the existing field's value
is within the configured threshold, the update is skipped. """
- if not config.app_config.get('FEATURE_USER_LAST_ACCESSED'):
- return
+ if not config.app_config.get("FEATURE_USER_LAST_ACCESSED"):
+ return
- threshold = timedelta(seconds=config.app_config.get('LAST_ACCESSED_UPDATE_THRESHOLD_S', 120))
- if (token_or_user.last_accessed is not None and
- datetime.utcnow() - token_or_user.last_accessed < threshold):
- # Skip updating, as we don't want to put undue pressure on the database.
- return
+ threshold = timedelta(
+ seconds=config.app_config.get("LAST_ACCESSED_UPDATE_THRESHOLD_S", 120)
+ )
+ if (
+ token_or_user.last_accessed is not None
+ and datetime.utcnow() - token_or_user.last_accessed < threshold
+ ):
+ # Skip updating, as we don't want to put undue pressure on the database.
+ return
- model_class = token_or_user.__class__
- last_accessed = datetime.utcnow()
+ model_class = token_or_user.__class__
+ last_accessed = datetime.utcnow()
- try:
- (model_class
- .update(last_accessed=last_accessed)
- .where(model_class.id == token_or_user.id)
- .execute())
- token_or_user.last_accessed = last_accessed
- except ReadOnlyModeException:
- pass
- except PeeweeException as ex:
- # If there is any form of DB exception, only fail if strict logging is enabled.
- strict_logging_disabled = config.app_config.get('ALLOW_PULLS_WITHOUT_STRICT_LOGGING')
- if strict_logging_disabled:
- data = {
- 'exception': ex,
- 'token_or_user': token_or_user.id,
- 'class': str(model_class),
- }
+ try:
+ (
+ model_class.update(last_accessed=last_accessed)
+ .where(model_class.id == token_or_user.id)
+ .execute()
+ )
+ token_or_user.last_accessed = last_accessed
+ except ReadOnlyModeException:
+ pass
+ except PeeweeException as ex:
+ # If there is any form of DB exception, only fail if strict logging is enabled.
+ strict_logging_disabled = config.app_config.get(
+ "ALLOW_PULLS_WITHOUT_STRICT_LOGGING"
+ )
+ if strict_logging_disabled:
+ data = {
+ "exception": ex,
+ "token_or_user": token_or_user.id,
+ "class": str(model_class),
+ }
- logger.exception('update last_accessed for token/user failed', extra=data)
- else:
- raise
+ logger.exception("update last_accessed for token/user failed", extra=data)
+ else:
+ raise
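The reduce_as_tree helper near the top of the _basequery.py hunk unions queries pairwise so the generated SQL nests to a depth of O(log n) rather than O(n), which is what tripped peewee's recursion limit at roughly 80 queries. (The unchanged len(queries_to_reduce) / 2 relies on Python 2 integer division; a Python 3 port would want //.) The same balanced-reduction shape, demonstrated on plain lists so it stays self-contained:

def reduce_pairwise(items, combine):
    """Combine items as a balanced tree: O(log n) nesting instead of O(n)."""
    if len(items) == 1:
        return items[0]
    mid = len(items) // 2
    return combine(reduce_pairwise(items[:mid], combine),
                   reduce_pairwise(items[mid:], combine))


# With 80 one-element lists the combine chain is about 7 levels deep, not 80.
parts = [[i] for i in range(80)]
assert reduce_pairwise(parts, lambda a, b: a + b) == list(range(80))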
diff --git a/data/model/appspecifictoken.py b/data/model/appspecifictoken.py
index c0ead9440..d3fa87ecc 100644
--- a/data/model/appspecifictoken.py
+++ b/data/model/appspecifictoken.py
@@ -17,156 +17,176 @@ MINIMUM_TOKEN_SUFFIX_LENGTH = 60
def _default_expiration_duration():
- expiration_str = config.app_config.get('APP_SPECIFIC_TOKEN_EXPIRATION')
- return convert_to_timedelta(expiration_str) if expiration_str else None
+ expiration_str = config.app_config.get("APP_SPECIFIC_TOKEN_EXPIRATION")
+ return convert_to_timedelta(expiration_str) if expiration_str else None
# Define a "unique" value so that callers can specifiy an expiration of None and *not* have it
# use the default.
-_default_expiration_duration_opt = '__deo'
+_default_expiration_duration_opt = "__deo"
+
def create_token(user, title, expiration=_default_expiration_duration_opt):
- """ Creates and returns an app specific token for the given user. If no expiration is specified
+ """ Creates and returns an app specific token for the given user. If no expiration is specified
(including `None`), then the default from config is used. """
- if expiration == _default_expiration_duration_opt:
- duration = _default_expiration_duration()
- expiration = duration + datetime.now() if duration else None
+ if expiration == _default_expiration_duration_opt:
+ duration = _default_expiration_duration()
+ expiration = duration + datetime.now() if duration else None
- token_code = random_string_generator(TOKEN_NAME_PREFIX_LENGTH + MINIMUM_TOKEN_SUFFIX_LENGTH)()
- token_name = token_code[:TOKEN_NAME_PREFIX_LENGTH]
- token_secret = token_code[TOKEN_NAME_PREFIX_LENGTH:]
+ token_code = random_string_generator(
+ TOKEN_NAME_PREFIX_LENGTH + MINIMUM_TOKEN_SUFFIX_LENGTH
+ )()
+ token_name = token_code[:TOKEN_NAME_PREFIX_LENGTH]
+ token_secret = token_code[TOKEN_NAME_PREFIX_LENGTH:]
- assert token_name
- assert token_secret
+ assert token_name
+ assert token_secret
- # TODO(remove-unenc): Remove legacy handling.
- old_token_code = (token_code
- if ActiveDataMigration.has_flag(ERTMigrationFlags.WRITE_OLD_FIELDS)
- else None)
- return AppSpecificAuthToken.create(user=user,
- title=title,
- expiration=expiration,
- token_name=token_name,
- token_secret=DecryptedValue(token_secret),
- token_code=old_token_code)
+ # TODO(remove-unenc): Remove legacy handling.
+ old_token_code = (
+ token_code
+ if ActiveDataMigration.has_flag(ERTMigrationFlags.WRITE_OLD_FIELDS)
+ else None
+ )
+ return AppSpecificAuthToken.create(
+ user=user,
+ title=title,
+ expiration=expiration,
+ token_name=token_name,
+ token_secret=DecryptedValue(token_secret),
+ token_code=old_token_code,
+ )
def list_tokens(user):
- """ Lists all tokens for the given user. """
- return AppSpecificAuthToken.select().where(AppSpecificAuthToken.user == user)
+ """ Lists all tokens for the given user. """
+ return AppSpecificAuthToken.select().where(AppSpecificAuthToken.user == user)
def revoke_token(token):
- """ Revokes an app specific token by deleting it. """
- token.delete_instance()
+ """ Revokes an app specific token by deleting it. """
+ token.delete_instance()
def revoke_token_by_uuid(uuid, owner):
- """ Revokes an app specific token by deleting it. """
- try:
- token = AppSpecificAuthToken.get(uuid=uuid, user=owner)
- except AppSpecificAuthToken.DoesNotExist:
- return None
+ """ Revokes an app specific token by deleting it. """
+ try:
+ token = AppSpecificAuthToken.get(uuid=uuid, user=owner)
+ except AppSpecificAuthToken.DoesNotExist:
+ return None
- revoke_token(token)
- return token
+ revoke_token(token)
+ return token
def get_expiring_tokens(user, soon):
- """ Returns all tokens owned by the given user that will be expiring "soon", where soon is defined
+ """ Returns all tokens owned by the given user that will be expiring "soon", where soon is defined
by the soon parameter (a timedelta from now).
"""
- soon_datetime = datetime.now() + soon
- return (AppSpecificAuthToken
- .select()
- .where(AppSpecificAuthToken.user == user,
- AppSpecificAuthToken.expiration <= soon_datetime,
- AppSpecificAuthToken.expiration > datetime.now()))
+ soon_datetime = datetime.now() + soon
+ return AppSpecificAuthToken.select().where(
+ AppSpecificAuthToken.user == user,
+ AppSpecificAuthToken.expiration <= soon_datetime,
+ AppSpecificAuthToken.expiration > datetime.now(),
+ )
def gc_expired_tokens(expiration_window):
- """ Deletes all expired tokens outside of the expiration window. """
- (AppSpecificAuthToken
- .delete()
- .where(AppSpecificAuthToken.expiration < (datetime.now() - expiration_window))
- .execute())
+ """ Deletes all expired tokens outside of the expiration window. """
+ (
+ AppSpecificAuthToken.delete()
+ .where(AppSpecificAuthToken.expiration < (datetime.now() - expiration_window))
+ .execute()
+ )
def get_token_by_uuid(uuid, owner=None):
- """ Looks up an unexpired app specific token with the given uuid. Returns it if found or
+ """ Looks up an unexpired app specific token with the given uuid. Returns it if found or
None if none. If owner is specified, only tokens owned by the owner user will be
returned.
"""
- try:
- query = (AppSpecificAuthToken
- .select()
- .where(AppSpecificAuthToken.uuid == uuid,
- ((AppSpecificAuthToken.expiration > datetime.now()) |
- (AppSpecificAuthToken.expiration >> None))))
- if owner is not None:
- query = query.where(AppSpecificAuthToken.user == owner)
+ try:
+ query = AppSpecificAuthToken.select().where(
+ AppSpecificAuthToken.uuid == uuid,
+ (
+ (AppSpecificAuthToken.expiration > datetime.now())
+ | (AppSpecificAuthToken.expiration >> None)
+ ),
+ )
+ if owner is not None:
+ query = query.where(AppSpecificAuthToken.user == owner)
- return query.get()
- except AppSpecificAuthToken.DoesNotExist:
- return None
+ return query.get()
+ except AppSpecificAuthToken.DoesNotExist:
+ return None
def access_valid_token(token_code):
- """ Looks up an unexpired app specific token with the given token code. If found, the token's
+ """ Looks up an unexpired app specific token with the given token code. If found, the token's
last_accessed field is set to now and the token is returned. If not found, returns None.
"""
- token_code = remove_unicode(token_code)
+ token_code = remove_unicode(token_code)
- prefix = token_code[:TOKEN_NAME_PREFIX_LENGTH]
- if len(prefix) != TOKEN_NAME_PREFIX_LENGTH:
- return None
+ prefix = token_code[:TOKEN_NAME_PREFIX_LENGTH]
+ if len(prefix) != TOKEN_NAME_PREFIX_LENGTH:
+ return None
- suffix = token_code[TOKEN_NAME_PREFIX_LENGTH:]
+ suffix = token_code[TOKEN_NAME_PREFIX_LENGTH:]
- # Lookup the token by its prefix.
- try:
- token = (AppSpecificAuthToken
- .select(AppSpecificAuthToken, User)
- .join(User)
- .where(AppSpecificAuthToken.token_name == prefix,
- ((AppSpecificAuthToken.expiration > datetime.now()) |
- (AppSpecificAuthToken.expiration >> None)))
- .get())
-
- if not token.token_secret.matches(suffix):
- return None
-
- assert len(prefix) == TOKEN_NAME_PREFIX_LENGTH
- assert len(suffix) >= MINIMUM_TOKEN_SUFFIX_LENGTH
- update_last_accessed(token)
- return token
- except AppSpecificAuthToken.DoesNotExist:
- pass
-
- # TODO(remove-unenc): Remove legacy handling.
- if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
+ # Lookup the token by its prefix.
try:
- token = (AppSpecificAuthToken
- .select(AppSpecificAuthToken, User)
- .join(User)
- .where(AppSpecificAuthToken.token_code == token_code,
- ((AppSpecificAuthToken.expiration > datetime.now()) |
- (AppSpecificAuthToken.expiration >> None)))
- .get())
+ token = (
+ AppSpecificAuthToken.select(AppSpecificAuthToken, User)
+ .join(User)
+ .where(
+ AppSpecificAuthToken.token_name == prefix,
+ (
+ (AppSpecificAuthToken.expiration > datetime.now())
+ | (AppSpecificAuthToken.expiration >> None)
+ ),
+ )
+ .get()
+ )
- update_last_accessed(token)
- return token
+ if not token.token_secret.matches(suffix):
+ return None
+
+ assert len(prefix) == TOKEN_NAME_PREFIX_LENGTH
+ assert len(suffix) >= MINIMUM_TOKEN_SUFFIX_LENGTH
+ update_last_accessed(token)
+ return token
except AppSpecificAuthToken.DoesNotExist:
- return None
+ pass
- return None
+ # TODO(remove-unenc): Remove legacy handling.
+ if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
+ try:
+ token = (
+ AppSpecificAuthToken.select(AppSpecificAuthToken, User)
+ .join(User)
+ .where(
+ AppSpecificAuthToken.token_code == token_code,
+ (
+ (AppSpecificAuthToken.expiration > datetime.now())
+ | (AppSpecificAuthToken.expiration >> None)
+ ),
+ )
+ .get()
+ )
+
+ update_last_accessed(token)
+ return token
+ except AppSpecificAuthToken.DoesNotExist:
+ return None
+
+ return None
def get_full_token_string(token):
- # TODO(remove-unenc): Remove legacy handling.
- if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
- if not token.token_name:
- return token.token_code
+ # TODO(remove-unenc): Remove legacy handling.
+ if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
+ if not token.token_name:
+ return token.token_code
- assert token.token_name
- return '%s%s' % (token.token_name, token.token_secret.decrypt())
+ assert token.token_name
+ return "%s%s" % (token.token_name, token.token_secret.decrypt())
diff --git a/data/model/blob.py b/data/model/blob.py
index ac14891e8..5eed2fd7b 100644
--- a/data/model/blob.py
+++ b/data/model/blob.py
@@ -3,235 +3,302 @@ import logging
from datetime import datetime
from uuid import uuid4
-from data.model import (tag, _basequery, BlobDoesNotExist, InvalidBlobUpload, db_transaction,
- storage as storage_model, InvalidImageException)
-from data.database import (Repository, Namespace, ImageStorage, Image, ImageStoragePlacement,
- BlobUpload, ImageStorageLocation, db_random_func)
+from data.model import (
+ tag,
+ _basequery,
+ BlobDoesNotExist,
+ InvalidBlobUpload,
+ db_transaction,
+ storage as storage_model,
+ InvalidImageException,
+)
+from data.database import (
+ Repository,
+ Namespace,
+ ImageStorage,
+ Image,
+ ImageStoragePlacement,
+ BlobUpload,
+ ImageStorageLocation,
+ db_random_func,
+)
logger = logging.getLogger(__name__)
def get_repository_blob_by_digest(repository, blob_digest):
- """ Find the content-addressable blob linked to the specified repository.
+ """ Find the content-addressable blob linked to the specified repository.
"""
- assert blob_digest
- try:
- storage = (ImageStorage
- .select(ImageStorage.uuid)
- .join(Image)
- .where(Image.repository == repository,
- ImageStorage.content_checksum == blob_digest,
- ImageStorage.uploading == False)
- .get())
+ assert blob_digest
+ try:
+ storage = (
+ ImageStorage.select(ImageStorage.uuid)
+ .join(Image)
+ .where(
+ Image.repository == repository,
+ ImageStorage.content_checksum == blob_digest,
+ ImageStorage.uploading == False,
+ )
+ .get()
+ )
- return storage_model.get_storage_by_uuid(storage.uuid)
- except (ImageStorage.DoesNotExist, InvalidImageException):
- raise BlobDoesNotExist('Blob does not exist with digest: {0}'.format(blob_digest))
+ return storage_model.get_storage_by_uuid(storage.uuid)
+ except (ImageStorage.DoesNotExist, InvalidImageException):
+ raise BlobDoesNotExist(
+ "Blob does not exist with digest: {0}".format(blob_digest)
+ )
def get_repo_blob_by_digest(namespace, repo_name, blob_digest):
- """ Find the content-addressable blob linked to the specified repository.
+ """ Find the content-addressable blob linked to the specified repository.
"""
- assert blob_digest
- try:
- storage = (ImageStorage
- .select(ImageStorage.uuid)
- .join(Image)
- .join(Repository)
- .join(Namespace, on=(Namespace.id == Repository.namespace_user))
- .where(Repository.name == repo_name, Namespace.username == namespace,
- ImageStorage.content_checksum == blob_digest,
- ImageStorage.uploading == False)
- .get())
-
- return storage_model.get_storage_by_uuid(storage.uuid)
- except (ImageStorage.DoesNotExist, InvalidImageException):
- raise BlobDoesNotExist('Blob does not exist with digest: {0}'.format(blob_digest))
-
-
-def store_blob_record_and_temp_link(namespace, repo_name, blob_digest, location_obj, byte_count,
- link_expiration_s, uncompressed_byte_count=None):
- repo = _basequery.get_existing_repository(namespace, repo_name)
- assert repo
-
- return store_blob_record_and_temp_link_in_repo(repo.id, blob_digest, location_obj, byte_count,
- link_expiration_s, uncompressed_byte_count)
-
-
-def store_blob_record_and_temp_link_in_repo(repository_id, blob_digest, location_obj, byte_count,
- link_expiration_s, uncompressed_byte_count=None):
- """ Store a record of the blob and temporarily link it to the specified repository.
- """
- assert blob_digest
- assert byte_count is not None
-
- with db_transaction():
+ assert blob_digest
try:
- storage = ImageStorage.get(content_checksum=blob_digest)
- save_changes = False
+ storage = (
+ ImageStorage.select(ImageStorage.uuid)
+ .join(Image)
+ .join(Repository)
+ .join(Namespace, on=(Namespace.id == Repository.namespace_user))
+ .where(
+ Repository.name == repo_name,
+ Namespace.username == namespace,
+ ImageStorage.content_checksum == blob_digest,
+ ImageStorage.uploading == False,
+ )
+ .get()
+ )
- if storage.image_size is None:
- storage.image_size = byte_count
- save_changes = True
+ return storage_model.get_storage_by_uuid(storage.uuid)
+ except (ImageStorage.DoesNotExist, InvalidImageException):
+ raise BlobDoesNotExist(
+ "Blob does not exist with digest: {0}".format(blob_digest)
+ )
- if storage.uncompressed_size is None and uncompressed_byte_count is not None:
- storage.uncompressed_size = uncompressed_byte_count
- save_changes = True
- if save_changes:
- storage.save()
+def store_blob_record_and_temp_link(
+ namespace,
+ repo_name,
+ blob_digest,
+ location_obj,
+ byte_count,
+ link_expiration_s,
+ uncompressed_byte_count=None,
+):
+ repo = _basequery.get_existing_repository(namespace, repo_name)
+ assert repo
- ImageStoragePlacement.get(storage=storage, location=location_obj)
- except ImageStorage.DoesNotExist:
- storage = ImageStorage.create(content_checksum=blob_digest, uploading=False,
- image_size=byte_count,
- uncompressed_size=uncompressed_byte_count)
- ImageStoragePlacement.create(storage=storage, location=location_obj)
- except ImageStoragePlacement.DoesNotExist:
- ImageStoragePlacement.create(storage=storage, location=location_obj)
+ return store_blob_record_and_temp_link_in_repo(
+ repo.id,
+ blob_digest,
+ location_obj,
+ byte_count,
+ link_expiration_s,
+ uncompressed_byte_count,
+ )
- _temp_link_blob(repository_id, storage, link_expiration_s)
- return storage
+
+def store_blob_record_and_temp_link_in_repo(
+ repository_id,
+ blob_digest,
+ location_obj,
+ byte_count,
+ link_expiration_s,
+ uncompressed_byte_count=None,
+):
+ """ Store a record of the blob and temporarily link it to the specified repository.
+ """
+ assert blob_digest
+ assert byte_count is not None
+
+ with db_transaction():
+ try:
+ storage = ImageStorage.get(content_checksum=blob_digest)
+ save_changes = False
+
+ if storage.image_size is None:
+ storage.image_size = byte_count
+ save_changes = True
+
+ if (
+ storage.uncompressed_size is None
+ and uncompressed_byte_count is not None
+ ):
+ storage.uncompressed_size = uncompressed_byte_count
+ save_changes = True
+
+ if save_changes:
+ storage.save()
+
+ ImageStoragePlacement.get(storage=storage, location=location_obj)
+ except ImageStorage.DoesNotExist:
+ storage = ImageStorage.create(
+ content_checksum=blob_digest,
+ uploading=False,
+ image_size=byte_count,
+ uncompressed_size=uncompressed_byte_count,
+ )
+ ImageStoragePlacement.create(storage=storage, location=location_obj)
+ except ImageStoragePlacement.DoesNotExist:
+ ImageStoragePlacement.create(storage=storage, location=location_obj)
+
+ _temp_link_blob(repository_id, storage, link_expiration_s)
+ return storage
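Worth noting while reviewing the reflow: store_blob_record_and_temp_link is a thin wrapper that resolves the repository, while the *_in_repo variant does the real work inside one db_transaction, reusing an existing ImageStorage row for the digest when it can and always finishing with _temp_link_blob. A hedged usage sketch with placeholder inputs:

    # All right-hand values are placeholders the caller would already have.
    storage = store_blob_record_and_temp_link(
        namespace,                     # e.g. the pushing user's namespace
        repo_name,
        blob_digest,                   # "sha256:<hex>" content checksum
        location_obj,                  # an ImageStorageLocation row
        byte_count,
        link_expiration_s=60 * 60,     # keep the temporary link for an hour
    )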
def temp_link_blob(repository_id, blob_digest, link_expiration_s):
- """ Temporarily links to the blob record from the given repository. If the blob record is not
+ """ Temporarily links to the blob record from the given repository. If the blob record is not
found, return None.
"""
- assert blob_digest
+ assert blob_digest
- with db_transaction():
- try:
- storage = ImageStorage.get(content_checksum=blob_digest)
- except ImageStorage.DoesNotExist:
- return None
+ with db_transaction():
+ try:
+ storage = ImageStorage.get(content_checksum=blob_digest)
+ except ImageStorage.DoesNotExist:
+ return None
- _temp_link_blob(repository_id, storage, link_expiration_s)
- return storage
+ _temp_link_blob(repository_id, storage, link_expiration_s)
+ return storage
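temp_link_blob only re-links an ImageStorage row that already exists for the digest, so callers must treat None as "blob not present" rather than as an error. A short hedged sketch:

    storage = temp_link_blob(repository_id, blob_digest, link_expiration_s=300)
    if storage is None:
        ...  # no ImageStorage row for this digest; the client has to upload it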
def _temp_link_blob(repository_id, storage, link_expiration_s):
- """ Note: Should *always* be called by a parent under a transaction. """
- random_image_name = str(uuid4())
+ """ Note: Should *always* be called by a parent under a transaction. """
+ random_image_name = str(uuid4())
- # Create a temporary link into the repository, to be replaced by the v1 metadata later
- # and create a temporary tag to reference it
- image = Image.create(storage=storage, docker_image_id=random_image_name, repository=repository_id)
- tag.create_temporary_hidden_tag(repository_id, image, link_expiration_s)
+ # Create a temporary link into the repository, to be replaced by the v1 metadata later
+ # and create a temporary tag to reference it
+ image = Image.create(
+ storage=storage, docker_image_id=random_image_name, repository=repository_id
+ )
+ tag.create_temporary_hidden_tag(repository_id, image, link_expiration_s)
def get_stale_blob_upload(stale_timespan):
- """ Returns a random blob upload which was created before the stale timespan. """
- stale_threshold = datetime.now() - stale_timespan
+ """ Returns a random blob upload which was created before the stale timespan. """
+ stale_threshold = datetime.now() - stale_timespan
- try:
- candidates = (BlobUpload
- .select()
- .where(BlobUpload.created <= stale_threshold)
- .limit(500)
- .distinct()
- .alias('candidates'))
+ try:
+ candidates = (
+ BlobUpload.select()
+ .where(BlobUpload.created <= stale_threshold)
+ .limit(500)
+ .distinct()
+ .alias("candidates")
+ )
- found = (BlobUpload
- .select(candidates.c.id)
- .from_(candidates)
- .order_by(db_random_func())
- .get())
- if not found:
- return None
+ found = (
+ BlobUpload.select(candidates.c.id)
+ .from_(candidates)
+ .order_by(db_random_func())
+ .get()
+ )
+ if not found:
+ return None
- return (BlobUpload
- .select(BlobUpload, ImageStorageLocation)
+ return (
+ BlobUpload.select(BlobUpload, ImageStorageLocation)
.join(ImageStorageLocation)
.where(BlobUpload.id == found.id)
- .get())
- except BlobUpload.DoesNotExist:
- return None
+ .get()
+ )
+ except BlobUpload.DoesNotExist:
+ return None
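The stale-upload query gathers up to 500 candidates older than the threshold and then picks one at random via db_random_func, which keeps concurrent cleanup workers from fighting over the same row. A hedged worker sketch; the deletion step is hypothetical, not a helper defined in this module:

    from datetime import timedelta

    stale = get_stale_blob_upload(timedelta(days=2))     # example timespan
    while stale is not None:
        cleanup_upload(stale)   # hypothetical: must remove the BlobUpload row and its chunks
        stale = get_stale_blob_upload(timedelta(days=2))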
def get_blob_upload_by_uuid(upload_uuid):
- """ Loads the upload with the given UUID, if any. """
- try:
- return (BlobUpload
- .select()
- .where(BlobUpload.uuid == upload_uuid)
- .get())
- except BlobUpload.DoesNotExist:
- return None
+ """ Loads the upload with the given UUID, if any. """
+ try:
+ return BlobUpload.select().where(BlobUpload.uuid == upload_uuid).get()
+ except BlobUpload.DoesNotExist:
+ return None
def get_blob_upload(namespace, repo_name, upload_uuid):
- """ Load the upload which is already in progress.
+ """ Load the upload which is already in progress.
"""
- try:
- return (BlobUpload
- .select(BlobUpload, ImageStorageLocation)
+ try:
+ return (
+ BlobUpload.select(BlobUpload, ImageStorageLocation)
.join(ImageStorageLocation)
.switch(BlobUpload)
.join(Repository)
.join(Namespace, on=(Namespace.id == Repository.namespace_user))
- .where(Repository.name == repo_name, Namespace.username == namespace,
- BlobUpload.uuid == upload_uuid)
- .get())
- except BlobUpload.DoesNotExist:
- raise InvalidBlobUpload()
+ .where(
+ Repository.name == repo_name,
+ Namespace.username == namespace,
+ BlobUpload.uuid == upload_uuid,
+ )
+ .get()
+ )
+ except BlobUpload.DoesNotExist:
+ raise InvalidBlobUpload()
def initiate_upload(namespace, repo_name, uuid, location_name, storage_metadata):
- """ Initiates a blob upload for the repository with the given namespace and name,
+ """ Initiates a blob upload for the repository with the given namespace and name,
in a specific location. """
- repo = _basequery.get_existing_repository(namespace, repo_name)
- return initiate_upload_for_repo(repo, uuid, location_name, storage_metadata)
+ repo = _basequery.get_existing_repository(namespace, repo_name)
+ return initiate_upload_for_repo(repo, uuid, location_name, storage_metadata)
def initiate_upload_for_repo(repo, uuid, location_name, storage_metadata):
- """ Initiates a blob upload for a specific repository object, in a specific location. """
- location = storage_model.get_image_location_for_name(location_name)
- return BlobUpload.create(repository=repo, location=location.id, uuid=uuid,
- storage_metadata=storage_metadata)
+ """ Initiates a blob upload for a specific repository object, in a specific location. """
+ location = storage_model.get_image_location_for_name(location_name)
+ return BlobUpload.create(
+ repository=repo,
+ location=location.id,
+ uuid=uuid,
+ storage_metadata=storage_metadata,
+ )
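initiate_upload / initiate_upload_for_repo only record the upload's uuid, storage location and engine metadata; the row is later recovered with get_blob_upload (which raises InvalidBlobUpload when missing) or get_blob_upload_by_uuid (which returns None). A hedged round-trip sketch; the uuid, location name and metadata are placeholders:

    upload = initiate_upload(namespace, repo_name, upload_uuid, location_name, {})
    # ... the client streams chunks via the storage engine ...
    in_progress = get_blob_upload(namespace, repo_name, upload_uuid)
    assert in_progress.uuid == upload.uuid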
def get_shared_blob(digest):
- """ Returns the ImageStorage blob with the given digest or, if not present,
+ """ Returns the ImageStorage blob with the given digest or, if not present,
returns None. This method is *only* to be used for shared blobs that are
globally accessible, such as the special empty gzipped tar layer that Docker
no longer pushes to us.
"""
- assert digest
- try:
- return ImageStorage.get(content_checksum=digest, uploading=False)
- except ImageStorage.DoesNotExist:
- return None
+ assert digest
+ try:
+ return ImageStorage.get(content_checksum=digest, uploading=False)
+ except ImageStorage.DoesNotExist:
+ return None
def get_or_create_shared_blob(digest, byte_data, storage):
- """ Returns the ImageStorage blob with the given digest or, if not present,
+ """ Returns the ImageStorage blob with the given digest or, if not present,
adds a row and writes the given byte data to the storage engine.
This method is *only* to be used for shared blobs that are globally
accessible, such as the special empty gzipped tar layer that Docker
no longer pushes to us.
"""
- assert digest
- assert byte_data is not None
- assert storage
+ assert digest
+ assert byte_data is not None
+ assert storage
- try:
- return ImageStorage.get(content_checksum=digest, uploading=False)
- except ImageStorage.DoesNotExist:
- record = ImageStorage.create(image_size=len(byte_data), content_checksum=digest,
- cas_path=True, uploading=True)
- preferred = storage.preferred_locations[0]
- location_obj = ImageStorageLocation.get(name=preferred)
try:
- storage.put_content([preferred], storage_model.get_layer_path(record), byte_data)
- ImageStoragePlacement.create(storage=record, location=location_obj)
+ return ImageStorage.get(content_checksum=digest, uploading=False)
+ except ImageStorage.DoesNotExist:
+ record = ImageStorage.create(
+ image_size=len(byte_data),
+ content_checksum=digest,
+ cas_path=True,
+ uploading=True,
+ )
+ preferred = storage.preferred_locations[0]
+ location_obj = ImageStorageLocation.get(name=preferred)
+ try:
+ storage.put_content(
+ [preferred], storage_model.get_layer_path(record), byte_data
+ )
+ ImageStoragePlacement.create(storage=record, location=location_obj)
- record.uploading = False
- record.save()
- except:
- logger.exception('Exception when trying to write special layer %s', digest)
- record.delete_instance()
- raise
+ record.uploading = False
+ record.save()
+ except:
+ logger.exception("Exception when trying to write special layer %s", digest)
+ record.delete_instance()
+ raise
- return record
+ return record
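get_or_create_shared_blob is the only path meant for globally shared blobs (such as the special empty gzipped tar layer named in the docstring): it writes the bytes through the storage engine first, records the placement, and only then clears uploading; on any failure the half-created row is deleted and the exception re-raised. Hedged sketch with stand-in constants:

    # EMPTY_LAYER_DIGEST / EMPTY_LAYER_BYTES are stand-ins for the well-known
    # empty-layer constants; `storage` is the app's storage engine (the object
    # exposing preferred_locations and put_content).
    record = get_or_create_shared_blob(EMPTY_LAYER_DIGEST, EMPTY_LAYER_BYTES, storage)
    assert record.uploading is False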
diff --git a/data/model/build.py b/data/model/build.py
index 79e282509..3732a631e 100644
--- a/data/model/build.py
+++ b/data/model/build.py
@@ -5,59 +5,88 @@ from datetime import timedelta, datetime
from peewee import JOIN
from active_migration import ActiveDataMigration, ERTMigrationFlags
-from data.database import (BuildTriggerService, RepositoryBuildTrigger, Repository, Namespace, User,
- RepositoryBuild, BUILD_PHASE, db_random_func, UseThenDisconnect,
- TRIGGER_DISABLE_REASON)
-from data.model import (InvalidBuildTriggerException, InvalidRepositoryBuildException,
- db_transaction, user as user_model, config)
+from data.database import (
+ BuildTriggerService,
+ RepositoryBuildTrigger,
+ Repository,
+ Namespace,
+ User,
+ RepositoryBuild,
+ BUILD_PHASE,
+ db_random_func,
+ UseThenDisconnect,
+ TRIGGER_DISABLE_REASON,
+)
+from data.model import (
+ InvalidBuildTriggerException,
+ InvalidRepositoryBuildException,
+ db_transaction,
+ user as user_model,
+ config,
+)
from data.fields import DecryptedValue
PRESUMED_DEAD_BUILD_AGE = timedelta(days=15)
-PHASES_NOT_ALLOWED_TO_CANCEL_FROM = (BUILD_PHASE.PUSHING, BUILD_PHASE.COMPLETE,
- BUILD_PHASE.ERROR, BUILD_PHASE.INTERNAL_ERROR)
+PHASES_NOT_ALLOWED_TO_CANCEL_FROM = (
+ BUILD_PHASE.PUSHING,
+ BUILD_PHASE.COMPLETE,
+ BUILD_PHASE.ERROR,
+ BUILD_PHASE.INTERNAL_ERROR,
+)
-ARCHIVABLE_BUILD_PHASES = [BUILD_PHASE.COMPLETE, BUILD_PHASE.ERROR, BUILD_PHASE.CANCELLED]
+ARCHIVABLE_BUILD_PHASES = [
+ BUILD_PHASE.COMPLETE,
+ BUILD_PHASE.ERROR,
+ BUILD_PHASE.CANCELLED,
+]
def update_build_trigger(trigger, config, auth_token=None, write_token=None):
- trigger.config = json.dumps(config or {})
+ trigger.config = json.dumps(config or {})
- # TODO(remove-unenc): Remove legacy field.
- if auth_token is not None:
+ # TODO(remove-unenc): Remove legacy field.
+ if auth_token is not None:
+ if ActiveDataMigration.has_flag(ERTMigrationFlags.WRITE_OLD_FIELDS):
+ trigger.auth_token = auth_token
+
+ trigger.secure_auth_token = auth_token
+
+ if write_token is not None:
+ trigger.write_token = write_token
+
+ trigger.save()
+
+
+def create_build_trigger(
+ repo, service_name, auth_token, user, pull_robot=None, config=None
+):
+ service = BuildTriggerService.get(name=service_name)
+
+ # TODO(remove-unenc): Remove legacy field.
+ old_auth_token = None
if ActiveDataMigration.has_flag(ERTMigrationFlags.WRITE_OLD_FIELDS):
- trigger.auth_token = auth_token
+ old_auth_token = auth_token
- trigger.secure_auth_token = auth_token
-
- if write_token is not None:
- trigger.write_token = write_token
-
- trigger.save()
-
-
-def create_build_trigger(repo, service_name, auth_token, user, pull_robot=None, config=None):
- service = BuildTriggerService.get(name=service_name)
-
- # TODO(remove-unenc): Remove legacy field.
- old_auth_token = None
- if ActiveDataMigration.has_flag(ERTMigrationFlags.WRITE_OLD_FIELDS):
- old_auth_token = auth_token
-
- secure_auth_token = DecryptedValue(auth_token) if auth_token else None
- trigger = RepositoryBuildTrigger.create(repository=repo, service=service,
- auth_token=old_auth_token,
- secure_auth_token=secure_auth_token,
- connected_user=user,
- pull_robot=pull_robot,
- config=json.dumps(config or {}))
- return trigger
+ secure_auth_token = DecryptedValue(auth_token) if auth_token else None
+ trigger = RepositoryBuildTrigger.create(
+ repository=repo,
+ service=service,
+ auth_token=old_auth_token,
+ secure_auth_token=secure_auth_token,
+ connected_user=user,
+ pull_robot=pull_robot,
+ config=json.dumps(config or {}),
+ )
+ return trigger
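The TODO(remove-unenc) block keeps the dual-write behaviour intact: while the WRITE_OLD_FIELDS migration flag is active, the plaintext auth_token column is still populated alongside the encrypted secure_auth_token (wrapped in DecryptedValue). A hedged caller sketch; the service name and token are example values only:

    from data.model import build

    trigger = build.create_build_trigger(
        repo,
        "github",                   # example BuildTriggerService name
        auth_token="<oauth token>", # placeholder credential
        user=current_user,          # placeholder for the acting user row
    )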
def get_build_trigger(trigger_uuid):
- try:
- return (RepositoryBuildTrigger
- .select(RepositoryBuildTrigger, BuildTriggerService, Repository, Namespace)
+ try:
+ return (
+ RepositoryBuildTrigger.select(
+ RepositoryBuildTrigger, BuildTriggerService, Repository, Namespace
+ )
.join(BuildTriggerService)
.switch(RepositoryBuildTrigger)
.join(Repository)
@@ -65,259 +94,304 @@ def get_build_trigger(trigger_uuid):
.switch(RepositoryBuildTrigger)
.join(User, on=(RepositoryBuildTrigger.connected_user == User.id))
.where(RepositoryBuildTrigger.uuid == trigger_uuid)
- .get())
- except RepositoryBuildTrigger.DoesNotExist:
- msg = 'No build trigger with uuid: %s' % trigger_uuid
- raise InvalidBuildTriggerException(msg)
+ .get()
+ )
+ except RepositoryBuildTrigger.DoesNotExist:
+ msg = "No build trigger with uuid: %s" % trigger_uuid
+ raise InvalidBuildTriggerException(msg)
def list_build_triggers(namespace_name, repository_name):
- return (RepositoryBuildTrigger
- .select(RepositoryBuildTrigger, BuildTriggerService, Repository)
- .join(BuildTriggerService)
- .switch(RepositoryBuildTrigger)
- .join(Repository)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .where(Namespace.username == namespace_name, Repository.name == repository_name))
+ return (
+ RepositoryBuildTrigger.select(
+ RepositoryBuildTrigger, BuildTriggerService, Repository
+ )
+ .join(BuildTriggerService)
+ .switch(RepositoryBuildTrigger)
+ .join(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .where(Namespace.username == namespace_name, Repository.name == repository_name)
+ )
-def list_trigger_builds(namespace_name, repository_name, trigger_uuid,
- limit):
- return (list_repository_builds(namespace_name, repository_name, limit)
- .where(RepositoryBuildTrigger.uuid == trigger_uuid))
+def list_trigger_builds(namespace_name, repository_name, trigger_uuid, limit):
+ return list_repository_builds(namespace_name, repository_name, limit).where(
+ RepositoryBuildTrigger.uuid == trigger_uuid
+ )
def get_repository_for_resource(resource_key):
- try:
- return (Repository
- .select(Repository, Namespace)
+ try:
+ return (
+ Repository.select(Repository, Namespace)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.switch(Repository)
.join(RepositoryBuild)
.where(RepositoryBuild.resource_key == resource_key)
- .get())
- except Repository.DoesNotExist:
- return None
+ .get()
+ )
+ except Repository.DoesNotExist:
+ return None
def _get_build_base_query():
- return (RepositoryBuild
- .select(RepositoryBuild, RepositoryBuildTrigger, BuildTriggerService, Repository,
- Namespace, User)
- .join(Repository)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .switch(RepositoryBuild)
- .join(User, JOIN.LEFT_OUTER)
- .switch(RepositoryBuild)
- .join(RepositoryBuildTrigger, JOIN.LEFT_OUTER)
- .join(BuildTriggerService, JOIN.LEFT_OUTER)
- .order_by(RepositoryBuild.started.desc()))
+ return (
+ RepositoryBuild.select(
+ RepositoryBuild,
+ RepositoryBuildTrigger,
+ BuildTriggerService,
+ Repository,
+ Namespace,
+ User,
+ )
+ .join(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .switch(RepositoryBuild)
+ .join(User, JOIN.LEFT_OUTER)
+ .switch(RepositoryBuild)
+ .join(RepositoryBuildTrigger, JOIN.LEFT_OUTER)
+ .join(BuildTriggerService, JOIN.LEFT_OUTER)
+ .order_by(RepositoryBuild.started.desc())
+ )
def get_repository_build(build_uuid):
- try:
- return _get_build_base_query().where(RepositoryBuild.uuid == build_uuid).get()
+ try:
+ return _get_build_base_query().where(RepositoryBuild.uuid == build_uuid).get()
- except RepositoryBuild.DoesNotExist:
- msg = 'Unable to locate a build by id: %s' % build_uuid
- raise InvalidRepositoryBuildException(msg)
+ except RepositoryBuild.DoesNotExist:
+ msg = "Unable to locate a build by id: %s" % build_uuid
+ raise InvalidRepositoryBuildException(msg)
-def list_repository_builds(namespace_name, repository_name, limit,
- include_inactive=True, since=None):
- query = (_get_build_base_query()
- .where(Repository.name == repository_name, Namespace.username == namespace_name)
- .limit(limit))
+def list_repository_builds(
+ namespace_name, repository_name, limit, include_inactive=True, since=None
+):
+ query = (
+ _get_build_base_query()
+ .where(Repository.name == repository_name, Namespace.username == namespace_name)
+ .limit(limit)
+ )
- if since is not None:
- query = query.where(RepositoryBuild.started >= since)
+ if since is not None:
+ query = query.where(RepositoryBuild.started >= since)
- if not include_inactive:
- query = query.where(RepositoryBuild.phase != BUILD_PHASE.ERROR,
- RepositoryBuild.phase != BUILD_PHASE.COMPLETE)
+ if not include_inactive:
+ query = query.where(
+ RepositoryBuild.phase != BUILD_PHASE.ERROR,
+ RepositoryBuild.phase != BUILD_PHASE.COMPLETE,
+ )
- return query
+ return query
def get_recent_repository_build(namespace_name, repository_name):
- query = list_repository_builds(namespace_name, repository_name, 1)
- try:
- return query.get()
- except RepositoryBuild.DoesNotExist:
- return None
+ query = list_repository_builds(namespace_name, repository_name, 1)
+ try:
+ return query.get()
+ except RepositoryBuild.DoesNotExist:
+ return None
-def create_repository_build(repo, access_token, job_config_obj, dockerfile_id,
- display_name, trigger=None, pull_robot_name=None):
- pull_robot = None
- if pull_robot_name:
- pull_robot = user_model.lookup_robot(pull_robot_name)
+def create_repository_build(
+ repo,
+ access_token,
+ job_config_obj,
+ dockerfile_id,
+ display_name,
+ trigger=None,
+ pull_robot_name=None,
+):
+ pull_robot = None
+ if pull_robot_name:
+ pull_robot = user_model.lookup_robot(pull_robot_name)
- return RepositoryBuild.create(repository=repo, access_token=access_token,
- job_config=json.dumps(job_config_obj),
- display_name=display_name, trigger=trigger,
- resource_key=dockerfile_id,
- pull_robot=pull_robot)
+ return RepositoryBuild.create(
+ repository=repo,
+ access_token=access_token,
+ job_config=json.dumps(job_config_obj),
+ display_name=display_name,
+ trigger=trigger,
+ resource_key=dockerfile_id,
+ pull_robot=pull_robot,
+ )
def get_pull_robot_name(trigger):
- if not trigger.pull_robot:
- return None
+ if not trigger.pull_robot:
+ return None
- return trigger.pull_robot.username
+ return trigger.pull_robot.username
def _get_build_row(build_uuid):
- return RepositoryBuild.select().where(RepositoryBuild.uuid == build_uuid).get()
+ return RepositoryBuild.select().where(RepositoryBuild.uuid == build_uuid).get()
def update_phase_then_close(build_uuid, phase):
- """ A function to change the phase of a build """
- with UseThenDisconnect(config.app_config):
- try:
- build = _get_build_row(build_uuid)
- except RepositoryBuild.DoesNotExist:
- return False
+ """ A function to change the phase of a build """
+ with UseThenDisconnect(config.app_config):
+ try:
+ build = _get_build_row(build_uuid)
+ except RepositoryBuild.DoesNotExist:
+ return False
- # Can't update a cancelled build
- if build.phase == BUILD_PHASE.CANCELLED:
- return False
+ # Can't update a cancelled build
+ if build.phase == BUILD_PHASE.CANCELLED:
+ return False
- updated = (RepositoryBuild
- .update(phase=phase)
- .where(RepositoryBuild.id == build.id, RepositoryBuild.phase == build.phase)
- .execute())
+ updated = (
+ RepositoryBuild.update(phase=phase)
+ .where(RepositoryBuild.id == build.id, RepositoryBuild.phase == build.phase)
+ .execute()
+ )
- return updated > 0
+ return updated > 0
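update_phase_then_close is an optimistic, compare-and-set style update: the UPDATE is guarded on both the row id and the phase that was just read, so it returns True only when this caller actually moved the build, and cancelled builds are never moved at all. Sketch:

    # Hedged sketch using a phase constant that appears in this module.
    if not update_phase_then_close(build_uuid, BUILD_PHASE.COMPLETE):
        ...  # row missing, already cancelled, or another worker changed the phase first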
def create_cancel_build_in_queue(build_phase, build_queue_id, build_queue):
- """ A function to cancel a build before it leaves the queue """
+ """ A function to cancel a build before it leaves the queue """
- def cancel_build():
- cancelled = False
+ def cancel_build():
+ cancelled = False
- if build_queue_id is not None:
- cancelled = build_queue.cancel(build_queue_id)
+ if build_queue_id is not None:
+ cancelled = build_queue.cancel(build_queue_id)
- if build_phase != BUILD_PHASE.WAITING:
- return False
+ if build_phase != BUILD_PHASE.WAITING:
+ return False
- return cancelled
+ return cancelled
- return cancel_build
+ return cancel_build
def create_cancel_build_in_manager(build_phase, build_uuid, build_canceller):
- """ A function to cancel the build before it starts to push """
+ """ A function to cancel the build before it starts to push """
- def cancel_build():
- if build_phase in PHASES_NOT_ALLOWED_TO_CANCEL_FROM:
- return False
+ def cancel_build():
+ if build_phase in PHASES_NOT_ALLOWED_TO_CANCEL_FROM:
+ return False
- return build_canceller.try_cancel_build(build_uuid)
+ return build_canceller.try_cancel_build(build_uuid)
- return cancel_build
+ return cancel_build
def cancel_repository_build(build, build_queue):
- """ This tries to cancel the build returns true if request is successful false
+ """ This tries to cancel the build returns true if request is successful false
if it can't be cancelled """
- from app import build_canceller
- from buildman.jobutil.buildjob import BuildJobNotifier
+ from app import build_canceller
+ from buildman.jobutil.buildjob import BuildJobNotifier
- cancel_builds = [create_cancel_build_in_queue(build.phase, build.queue_id, build_queue),
- create_cancel_build_in_manager(build.phase, build.uuid, build_canceller), ]
- for cancelled in cancel_builds:
- if cancelled():
- updated = update_phase_then_close(build.uuid, BUILD_PHASE.CANCELLED)
- if updated:
- BuildJobNotifier(build.uuid).send_notification("build_cancelled")
+ cancel_builds = [
+ create_cancel_build_in_queue(build.phase, build.queue_id, build_queue),
+ create_cancel_build_in_manager(build.phase, build.uuid, build_canceller),
+ ]
+ for cancelled in cancel_builds:
+ if cancelled():
+ updated = update_phase_then_close(build.uuid, BUILD_PHASE.CANCELLED)
+ if updated:
+ BuildJobNotifier(build.uuid).send_notification("build_cancelled")
- return updated
+ return updated
- return False
+ return False
def get_archivable_build():
- presumed_dead_date = datetime.utcnow() - PRESUMED_DEAD_BUILD_AGE
+ presumed_dead_date = datetime.utcnow() - PRESUMED_DEAD_BUILD_AGE
- candidates = (RepositoryBuild
- .select(RepositoryBuild.id)
- .where((RepositoryBuild.phase << ARCHIVABLE_BUILD_PHASES) |
- (RepositoryBuild.started < presumed_dead_date),
- RepositoryBuild.logs_archived == False)
- .limit(50)
- .alias('candidates'))
+ candidates = (
+ RepositoryBuild.select(RepositoryBuild.id)
+ .where(
+ (RepositoryBuild.phase << ARCHIVABLE_BUILD_PHASES)
+ | (RepositoryBuild.started < presumed_dead_date),
+ RepositoryBuild.logs_archived == False,
+ )
+ .limit(50)
+ .alias("candidates")
+ )
- try:
- found_id = (RepositoryBuild
- .select(candidates.c.id)
- .from_(candidates)
- .order_by(db_random_func())
- .get())
- return RepositoryBuild.get(id=found_id)
- except RepositoryBuild.DoesNotExist:
- return None
+ try:
+ found_id = (
+ RepositoryBuild.select(candidates.c.id)
+ .from_(candidates)
+ .order_by(db_random_func())
+ .get()
+ )
+ return RepositoryBuild.get(id=found_id)
+ except RepositoryBuild.DoesNotExist:
+ return None
def mark_build_archived(build_uuid):
- """ Mark a build as archived, and return True if we were the ones who actually
+ """ Mark a build as archived, and return True if we were the ones who actually
updated the row. """
- return (RepositoryBuild
- .update(logs_archived=True)
- .where(RepositoryBuild.uuid == build_uuid,
- RepositoryBuild.logs_archived == False)
- .execute()) > 0
+ return (
+ RepositoryBuild.update(logs_archived=True)
+ .where(
+ RepositoryBuild.uuid == build_uuid, RepositoryBuild.logs_archived == False
+ )
+ .execute()
+ ) > 0
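get_archivable_build and mark_build_archived pair up into the log-archival worker pattern: pick a random finished (or presumed dead) build whose logs are not yet archived, archive, then flip logs_archived with a guard so only one worker wins the row. Hedged worker sketch; the archival step itself is hypothetical:

    build = get_archivable_build()
    while build is not None:
        upload_build_logs(build)         # hypothetical: push logs to archive storage
        mark_build_archived(build.uuid)  # True only for the worker that flipped the flag
        build = get_archivable_build()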
def toggle_build_trigger(trigger, enabled, reason=TRIGGER_DISABLE_REASON.USER_TOGGLED):
- """ Toggles the enabled status of a build trigger. """
- trigger.enabled = enabled
+ """ Toggles the enabled status of a build trigger. """
+ trigger.enabled = enabled
- if not enabled:
- trigger.disabled_reason = RepositoryBuildTrigger.disabled_reason.get_id(reason)
- trigger.disabled_datetime = datetime.utcnow()
+ if not enabled:
+ trigger.disabled_reason = RepositoryBuildTrigger.disabled_reason.get_id(reason)
+ trigger.disabled_datetime = datetime.utcnow()
- trigger.save()
+ trigger.save()
def update_trigger_disable_status(trigger, final_phase):
- """ Updates the disable status of the given build trigger. If the build trigger had a
+ """ Updates the disable status of the given build trigger. If the build trigger had a
failure, then the counter is increased and, if we've reached the limit, the trigger is
      automatically disabled. Otherwise, if the trigger succeeded, its counter is reset. This
ensures that triggers that continue to error are eventually automatically disabled.
"""
- with db_transaction():
- try:
- trigger = RepositoryBuildTrigger.get(id=trigger.id)
- except RepositoryBuildTrigger.DoesNotExist:
- # Already deleted.
- return
+ with db_transaction():
+ try:
+ trigger = RepositoryBuildTrigger.get(id=trigger.id)
+ except RepositoryBuildTrigger.DoesNotExist:
+ # Already deleted.
+ return
- # If the build completed successfully, then reset the successive counters.
- if final_phase == BUILD_PHASE.COMPLETE:
- trigger.successive_failure_count = 0
- trigger.successive_internal_error_count = 0
- trigger.save()
- return
+ # If the build completed successfully, then reset the successive counters.
+ if final_phase == BUILD_PHASE.COMPLETE:
+ trigger.successive_failure_count = 0
+ trigger.successive_internal_error_count = 0
+ trigger.save()
+ return
- # Otherwise, increment the counters and check for trigger disable.
- if final_phase == BUILD_PHASE.ERROR:
- trigger.successive_failure_count = trigger.successive_failure_count + 1
- trigger.successive_internal_error_count = 0
- elif final_phase == BUILD_PHASE.INTERNAL_ERROR:
- trigger.successive_internal_error_count = trigger.successive_internal_error_count + 1
+ # Otherwise, increment the counters and check for trigger disable.
+ if final_phase == BUILD_PHASE.ERROR:
+ trigger.successive_failure_count = trigger.successive_failure_count + 1
+ trigger.successive_internal_error_count = 0
+ elif final_phase == BUILD_PHASE.INTERNAL_ERROR:
+ trigger.successive_internal_error_count = (
+ trigger.successive_internal_error_count + 1
+ )
- # Check if we need to disable the trigger.
- failure_threshold = config.app_config.get('SUCCESSIVE_TRIGGER_FAILURE_DISABLE_THRESHOLD')
- error_threshold = config.app_config.get('SUCCESSIVE_TRIGGER_INTERNAL_ERROR_DISABLE_THRESHOLD')
+ # Check if we need to disable the trigger.
+ failure_threshold = config.app_config.get(
+ "SUCCESSIVE_TRIGGER_FAILURE_DISABLE_THRESHOLD"
+ )
+ error_threshold = config.app_config.get(
+ "SUCCESSIVE_TRIGGER_INTERNAL_ERROR_DISABLE_THRESHOLD"
+ )
- if failure_threshold and trigger.successive_failure_count >= failure_threshold:
- toggle_build_trigger(trigger, False, TRIGGER_DISABLE_REASON.BUILD_FALURES)
- elif (error_threshold and
- trigger.successive_internal_error_count >= error_threshold):
- toggle_build_trigger(trigger, False, TRIGGER_DISABLE_REASON.INTERNAL_ERRORS)
- else:
- # Save the trigger changes.
- trigger.save()
+ if failure_threshold and trigger.successive_failure_count >= failure_threshold:
+ toggle_build_trigger(trigger, False, TRIGGER_DISABLE_REASON.BUILD_FALURES)
+ elif (
+ error_threshold
+ and trigger.successive_internal_error_count >= error_threshold
+ ):
+ toggle_build_trigger(trigger, False, TRIGGER_DISABLE_REASON.INTERNAL_ERRORS)
+ else:
+ # Save the trigger changes.
+ trigger.save()
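The auto-disable logic is driven entirely by two optional config keys read with .get(); leaving either unset simply disables that path. A hedged config sketch with illustrative thresholds (the values are examples, not defaults taken from this change):

    app_config = {
        "SUCCESSIVE_TRIGGER_FAILURE_DISABLE_THRESHOLD": 100,
        "SUCCESSIVE_TRIGGER_INTERNAL_ERROR_DISABLE_THRESHOLD": 5,
    }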
diff --git a/data/model/gc.py b/data/model/gc.py
index 7f898bec8..4f3d72489 100644
--- a/data/model/gc.py
+++ b/data/model/gc.py
@@ -4,551 +4,601 @@ from data.model import config, db_transaction, storage, _basequery, tag as pre_o
from data.model.oci import tag as oci_tag
from data.database import Repository, db_for_update
from data.database import ApprTag
-from data.database import (Tag, Manifest, ManifestBlob, ManifestChild, ManifestLegacyImage,
- ManifestLabel, Label, TagManifestLabel)
+from data.database import (
+ Tag,
+ Manifest,
+ ManifestBlob,
+ ManifestChild,
+ ManifestLegacyImage,
+ ManifestLabel,
+ Label,
+ TagManifestLabel,
+)
from data.database import RepositoryTag, TagManifest, Image, DerivedStorageForImage
from data.database import TagManifestToManifest, TagToRepositoryTag, TagManifestLabelMap
logger = logging.getLogger(__name__)
+
class _GarbageCollectorContext(object):
- def __init__(self, repository):
- self.repository = repository
- self.manifest_ids = set()
- self.label_ids = set()
- self.blob_ids = set()
- self.legacy_image_ids = set()
+ def __init__(self, repository):
+ self.repository = repository
+ self.manifest_ids = set()
+ self.label_ids = set()
+ self.blob_ids = set()
+ self.legacy_image_ids = set()
- def add_manifest_id(self, manifest_id):
- self.manifest_ids.add(manifest_id)
+ def add_manifest_id(self, manifest_id):
+ self.manifest_ids.add(manifest_id)
- def add_label_id(self, label_id):
- self.label_ids.add(label_id)
+ def add_label_id(self, label_id):
+ self.label_ids.add(label_id)
- def add_blob_id(self, blob_id):
- self.blob_ids.add(blob_id)
+ def add_blob_id(self, blob_id):
+ self.blob_ids.add(blob_id)
- def add_legacy_image_id(self, legacy_image_id):
- self.legacy_image_ids.add(legacy_image_id)
+ def add_legacy_image_id(self, legacy_image_id):
+ self.legacy_image_ids.add(legacy_image_id)
- def mark_label_id_removed(self, label_id):
- self.label_ids.remove(label_id)
+ def mark_label_id_removed(self, label_id):
+ self.label_ids.remove(label_id)
- def mark_manifest_removed(self, manifest):
- self.manifest_ids.remove(manifest.id)
+ def mark_manifest_removed(self, manifest):
+ self.manifest_ids.remove(manifest.id)
- def mark_legacy_image_removed(self, legacy_image):
- self.legacy_image_ids.remove(legacy_image.id)
+ def mark_legacy_image_removed(self, legacy_image):
+ self.legacy_image_ids.remove(legacy_image.id)
- def mark_blob_id_removed(self, blob_id):
- self.blob_ids.remove(blob_id)
+ def mark_blob_id_removed(self, blob_id):
+ self.blob_ids.remove(blob_id)
def purge_repository(namespace_name, repository_name):
- """ Completely delete all traces of the repository. Will return True upon
+ """ Completely delete all traces of the repository. Will return True upon
complete success, and False upon partial or total failure. Garbage
collection is incremental and repeatable, so this return value does
not need to be checked or responded to.
"""
- try:
- repo = _basequery.get_existing_repository(namespace_name, repository_name)
- except Repository.DoesNotExist:
- return False
+ try:
+ repo = _basequery.get_existing_repository(namespace_name, repository_name)
+ except Repository.DoesNotExist:
+ return False
- assert repo.name == repository_name
+ assert repo.name == repository_name
- # Delete the repository of all Appr-referenced entries.
- # Note that new-model Tag's must be deleted in *two* passes, as they can reference parent tags,
- # and MySQL is... particular... about such relationships when deleting.
- if repo.kind.name == 'application':
- ApprTag.delete().where(ApprTag.repository == repo, ~(ApprTag.linked_tag >> None)).execute()
- ApprTag.delete().where(ApprTag.repository == repo).execute()
- else:
- # GC to remove the images and storage.
- _purge_repository_contents(repo)
+ # Delete the repository of all Appr-referenced entries.
+    # Note that new-model Tag rows must be deleted in *two* passes, as they can reference parent tags,
+ # and MySQL is... particular... about such relationships when deleting.
+ if repo.kind.name == "application":
+ ApprTag.delete().where(
+ ApprTag.repository == repo, ~(ApprTag.linked_tag >> None)
+ ).execute()
+ ApprTag.delete().where(ApprTag.repository == repo).execute()
+ else:
+ # GC to remove the images and storage.
+ _purge_repository_contents(repo)
- # Ensure there are no additional tags, manifests, images or blobs in the repository.
- assert ApprTag.select().where(ApprTag.repository == repo).count() == 0
- assert Tag.select().where(Tag.repository == repo).count() == 0
- assert RepositoryTag.select().where(RepositoryTag.repository == repo).count() == 0
- assert Manifest.select().where(Manifest.repository == repo).count() == 0
- assert ManifestBlob.select().where(ManifestBlob.repository == repo).count() == 0
- assert Image.select().where(Image.repository == repo).count() == 0
+ # Ensure there are no additional tags, manifests, images or blobs in the repository.
+ assert ApprTag.select().where(ApprTag.repository == repo).count() == 0
+ assert Tag.select().where(Tag.repository == repo).count() == 0
+ assert RepositoryTag.select().where(RepositoryTag.repository == repo).count() == 0
+ assert Manifest.select().where(Manifest.repository == repo).count() == 0
+ assert ManifestBlob.select().where(ManifestBlob.repository == repo).count() == 0
+ assert Image.select().where(Image.repository == repo).count() == 0
- # Delete the rest of the repository metadata.
- try:
- # Make sure the repository still exists.
- fetched = _basequery.get_existing_repository(namespace_name, repository_name)
- except Repository.DoesNotExist:
- return False
+ # Delete the rest of the repository metadata.
+ try:
+ # Make sure the repository still exists.
+ fetched = _basequery.get_existing_repository(namespace_name, repository_name)
+ except Repository.DoesNotExist:
+ return False
- fetched.delete_instance(recursive=True, delete_nullable=False)
+ fetched.delete_instance(recursive=True, delete_nullable=False)
- # Run callbacks
- for callback in config.repo_cleanup_callbacks:
- callback(namespace_name, repository_name)
+ # Run callbacks
+ for callback in config.repo_cleanup_callbacks:
+ callback(namespace_name, repository_name)
- return True
+ return True
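purge_repository deletes Appr tags in two passes (linked tags first) or walks the full OCI/pre-OCI GC for image repositories, asserts the repository is empty, and only then drops the row recursively and fires the cleanup callbacks. Hedged caller sketch with placeholder names:

    from data.model import gc

    deleted = gc.purge_repository("devtable", "simple")   # placeholder namespace/repo
    if not deleted:
        ...  # repository vanished mid-way; safe to retry, GC is incremental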
def _chunk_iterate_for_deletion(query, chunk_size=10):
- """ Returns an iterator that loads the rows returned by the given query in chunks. Note that
+ """ Returns an iterator that loads the rows returned by the given query in chunks. Note that
order is not guaranteed here, so this will only work (i.e. not return duplicates) if
the rows returned are being deleted between calls.
"""
- while True:
- results = list(query.limit(chunk_size))
- if not results:
- raise StopIteration
+ while True:
+ results = list(query.limit(chunk_size))
+ if not results:
+ raise StopIteration
- yield results
+ yield results
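One behavioural note the reformat preserves: raising StopIteration inside a generator is rejected under PEP 479 (Python 3.7 turns it into a RuntimeError), so once this module runs on Python 3 the chunking helper will need a plain return instead. A sketch of the equivalent, PEP 479-safe form (not part of this change):

    def _chunk_iterate_for_deletion(query, chunk_size=10):
        while True:
            results = list(query.limit(chunk_size))
            if not results:
                return          # ends the generator without tripping PEP 479
            yield results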
def _purge_repository_contents(repo):
- """ Purges all the contents of a repository, removing all of its tags,
+ """ Purges all the contents of a repository, removing all of its tags,
manifests and images.
"""
- logger.debug('Purging repository %s', repo)
+ logger.debug("Purging repository %s", repo)
- # Purge via all the tags.
- while True:
- found = False
- for tags in _chunk_iterate_for_deletion(Tag.select().where(Tag.repository == repo)):
- logger.debug('Found %s tags to GC under repository %s', len(tags), repo)
- found = True
- context = _GarbageCollectorContext(repo)
- for tag in tags:
- logger.debug('Deleting tag %s under repository %s', tag, repo)
- assert tag.repository_id == repo.id
- _purge_oci_tag(tag, context, allow_non_expired=True)
+ # Purge via all the tags.
+ while True:
+ found = False
+ for tags in _chunk_iterate_for_deletion(
+ Tag.select().where(Tag.repository == repo)
+ ):
+ logger.debug("Found %s tags to GC under repository %s", len(tags), repo)
+ found = True
+ context = _GarbageCollectorContext(repo)
+ for tag in tags:
+ logger.debug("Deleting tag %s under repository %s", tag, repo)
+ assert tag.repository_id == repo.id
+ _purge_oci_tag(tag, context, allow_non_expired=True)
- _run_garbage_collection(context)
+ _run_garbage_collection(context)
- if not found:
- break
+ if not found:
+ break
- # TODO: remove this once we're fully on the OCI data model.
- while True:
- found = False
- repo_tag_query = RepositoryTag.select().where(RepositoryTag.repository == repo)
- for tags in _chunk_iterate_for_deletion(repo_tag_query):
- logger.debug('Found %s tags to GC under repository %s', len(tags), repo)
- found = True
- context = _GarbageCollectorContext(repo)
+ # TODO: remove this once we're fully on the OCI data model.
+ while True:
+ found = False
+ repo_tag_query = RepositoryTag.select().where(RepositoryTag.repository == repo)
+ for tags in _chunk_iterate_for_deletion(repo_tag_query):
+ logger.debug("Found %s tags to GC under repository %s", len(tags), repo)
+ found = True
+ context = _GarbageCollectorContext(repo)
- for tag in tags:
- logger.debug('Deleting tag %s under repository %s', tag, repo)
- assert tag.repository_id == repo.id
- _purge_pre_oci_tag(tag, context, allow_non_expired=True)
+ for tag in tags:
+ logger.debug("Deleting tag %s under repository %s", tag, repo)
+ assert tag.repository_id == repo.id
+ _purge_pre_oci_tag(tag, context, allow_non_expired=True)
- _run_garbage_collection(context)
+ _run_garbage_collection(context)
- if not found:
- break
+ if not found:
+ break
- # Add all remaining images to a new context. We do this here to minimize the number of images
- # we need to load.
- while True:
- found_image = False
- image_context = _GarbageCollectorContext(repo)
- for image in Image.select().where(Image.repository == repo):
- found_image = True
- logger.debug('Deleting image %s under repository %s', image, repo)
- assert image.repository_id == repo.id
- image_context.add_legacy_image_id(image.id)
+ # Add all remaining images to a new context. We do this here to minimize the number of images
+ # we need to load.
+ while True:
+ found_image = False
+ image_context = _GarbageCollectorContext(repo)
+ for image in Image.select().where(Image.repository == repo):
+ found_image = True
+ logger.debug("Deleting image %s under repository %s", image, repo)
+ assert image.repository_id == repo.id
+ image_context.add_legacy_image_id(image.id)
- _run_garbage_collection(image_context)
+ _run_garbage_collection(image_context)
- if not found_image:
- break
+ if not found_image:
+ break
def garbage_collect_repo(repo):
- """ Performs garbage collection over the contents of a repository. """
- # Purge expired tags.
- had_changes = False
+ """ Performs garbage collection over the contents of a repository. """
+ # Purge expired tags.
+ had_changes = False
- for tags in _chunk_iterate_for_deletion(oci_tag.lookup_unrecoverable_tags(repo)):
- logger.debug('Found %s tags to GC under repository %s', len(tags), repo)
- context = _GarbageCollectorContext(repo)
- for tag in tags:
- logger.debug('Deleting tag %s under repository %s', tag, repo)
- assert tag.repository_id == repo.id
- assert tag.lifetime_end_ms is not None
- _purge_oci_tag(tag, context)
+ for tags in _chunk_iterate_for_deletion(oci_tag.lookup_unrecoverable_tags(repo)):
+ logger.debug("Found %s tags to GC under repository %s", len(tags), repo)
+ context = _GarbageCollectorContext(repo)
+ for tag in tags:
+ logger.debug("Deleting tag %s under repository %s", tag, repo)
+ assert tag.repository_id == repo.id
+ assert tag.lifetime_end_ms is not None
+ _purge_oci_tag(tag, context)
- _run_garbage_collection(context)
- had_changes = True
+ _run_garbage_collection(context)
+ had_changes = True
- for tags in _chunk_iterate_for_deletion(pre_oci_tag.lookup_unrecoverable_tags(repo)):
- logger.debug('Found %s tags to GC under repository %s', len(tags), repo)
- context = _GarbageCollectorContext(repo)
- for tag in tags:
- logger.debug('Deleting tag %s under repository %s', tag, repo)
- assert tag.repository_id == repo.id
- assert tag.lifetime_end_ts is not None
- _purge_pre_oci_tag(tag, context)
+ for tags in _chunk_iterate_for_deletion(
+ pre_oci_tag.lookup_unrecoverable_tags(repo)
+ ):
+ logger.debug("Found %s tags to GC under repository %s", len(tags), repo)
+ context = _GarbageCollectorContext(repo)
+ for tag in tags:
+ logger.debug("Deleting tag %s under repository %s", tag, repo)
+ assert tag.repository_id == repo.id
+ assert tag.lifetime_end_ts is not None
+ _purge_pre_oci_tag(tag, context)
- _run_garbage_collection(context)
- had_changes = True
+ _run_garbage_collection(context)
+ had_changes = True
- return had_changes
+ return had_changes
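garbage_collect_repo is the public entry point that feeds expired OCI and pre-OCI tags into the collector context; _run_garbage_collection then sweeps manifests, legacy images, labels and blobs to a fixed point, since removing one object can unreference another. Sketch:

    from data.model import gc

    changed = gc.garbage_collect_repo(repo)   # repo: a Repository row
    if changed:
        ...  # at least one tag, manifest, image or blob was collected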
def _run_garbage_collection(context):
- """ Runs the garbage collection loop, deleting manifests, images, labels and blobs
+ """ Runs the garbage collection loop, deleting manifests, images, labels and blobs
in an iterative fashion.
"""
- has_changes = True
+ has_changes = True
- while has_changes:
- has_changes = False
+ while has_changes:
+ has_changes = False
- # GC all manifests encountered.
- for manifest_id in list(context.manifest_ids):
- if _garbage_collect_manifest(manifest_id, context):
- has_changes = True
+ # GC all manifests encountered.
+ for manifest_id in list(context.manifest_ids):
+ if _garbage_collect_manifest(manifest_id, context):
+ has_changes = True
- # GC all images encountered.
- for image_id in list(context.legacy_image_ids):
- if _garbage_collect_legacy_image(image_id, context):
- has_changes = True
+ # GC all images encountered.
+ for image_id in list(context.legacy_image_ids):
+ if _garbage_collect_legacy_image(image_id, context):
+ has_changes = True
- # GC all labels encountered.
- for label_id in list(context.label_ids):
- if _garbage_collect_label(label_id, context):
- has_changes = True
+ # GC all labels encountered.
+ for label_id in list(context.label_ids):
+ if _garbage_collect_label(label_id, context):
+ has_changes = True
- # GC any blobs encountered.
- if context.blob_ids:
- storage_ids_removed = set(storage.garbage_collect_storage(context.blob_ids))
- for blob_removed_id in storage_ids_removed:
- context.mark_blob_id_removed(blob_removed_id)
- has_changes = True
+ # GC any blobs encountered.
+ if context.blob_ids:
+ storage_ids_removed = set(storage.garbage_collect_storage(context.blob_ids))
+ for blob_removed_id in storage_ids_removed:
+ context.mark_blob_id_removed(blob_removed_id)
+ has_changes = True
def _purge_oci_tag(tag, context, allow_non_expired=False):
- assert tag.repository_id == context.repository.id
+ assert tag.repository_id == context.repository.id
- if not allow_non_expired:
- assert tag.lifetime_end_ms is not None
- assert tag.lifetime_end_ms <= oci_tag.get_epoch_timestamp_ms()
+ if not allow_non_expired:
+ assert tag.lifetime_end_ms is not None
+ assert tag.lifetime_end_ms <= oci_tag.get_epoch_timestamp_ms()
- # Add the manifest to be GCed.
- context.add_manifest_id(tag.manifest_id)
+ # Add the manifest to be GCed.
+ context.add_manifest_id(tag.manifest_id)
- with db_transaction():
- # Reload the tag and verify its lifetime_end_ms has not changed.
- try:
- reloaded_tag = db_for_update(Tag.select().where(Tag.id == tag.id)).get()
- except Tag.DoesNotExist:
- return False
+ with db_transaction():
+ # Reload the tag and verify its lifetime_end_ms has not changed.
+ try:
+ reloaded_tag = db_for_update(Tag.select().where(Tag.id == tag.id)).get()
+ except Tag.DoesNotExist:
+ return False
- assert reloaded_tag.id == tag.id
- assert reloaded_tag.repository_id == context.repository.id
- if reloaded_tag.lifetime_end_ms != tag.lifetime_end_ms:
- return False
+ assert reloaded_tag.id == tag.id
+ assert reloaded_tag.repository_id == context.repository.id
+ if reloaded_tag.lifetime_end_ms != tag.lifetime_end_ms:
+ return False
- # Delete mapping rows.
- TagToRepositoryTag.delete().where(TagToRepositoryTag.tag == tag).execute()
+ # Delete mapping rows.
+ TagToRepositoryTag.delete().where(TagToRepositoryTag.tag == tag).execute()
- # Delete the tag.
- tag.delete_instance()
+ # Delete the tag.
+ tag.delete_instance()
def _purge_pre_oci_tag(tag, context, allow_non_expired=False):
- assert tag.repository_id == context.repository.id
+ assert tag.repository_id == context.repository.id
- if not allow_non_expired:
- assert tag.lifetime_end_ts is not None
- assert tag.lifetime_end_ts <= pre_oci_tag.get_epoch_timestamp()
+ if not allow_non_expired:
+ assert tag.lifetime_end_ts is not None
+ assert tag.lifetime_end_ts <= pre_oci_tag.get_epoch_timestamp()
- # If it exists, GC the tag manifest.
- try:
- tag_manifest = TagManifest.select().where(TagManifest.tag == tag).get()
- _garbage_collect_legacy_manifest(tag_manifest.id, context)
- except TagManifest.DoesNotExist:
- pass
-
- # Add the tag's legacy image to be GCed.
- context.add_legacy_image_id(tag.image_id)
-
- with db_transaction():
- # Reload the tag and verify its lifetime_end_ts has not changed.
+ # If it exists, GC the tag manifest.
try:
- reloaded_tag = db_for_update(RepositoryTag.select().where(RepositoryTag.id == tag.id)).get()
- except RepositoryTag.DoesNotExist:
- return False
+ tag_manifest = TagManifest.select().where(TagManifest.tag == tag).get()
+ _garbage_collect_legacy_manifest(tag_manifest.id, context)
+ except TagManifest.DoesNotExist:
+ pass
- assert reloaded_tag.id == tag.id
- assert reloaded_tag.repository_id == context.repository.id
- if reloaded_tag.lifetime_end_ts != tag.lifetime_end_ts:
- return False
+ # Add the tag's legacy image to be GCed.
+ context.add_legacy_image_id(tag.image_id)
- # Delete mapping rows.
- TagToRepositoryTag.delete().where(TagToRepositoryTag.repository_tag == reloaded_tag).execute()
+ with db_transaction():
+ # Reload the tag and verify its lifetime_end_ts has not changed.
+ try:
+ reloaded_tag = db_for_update(
+ RepositoryTag.select().where(RepositoryTag.id == tag.id)
+ ).get()
+ except RepositoryTag.DoesNotExist:
+ return False
- # Delete the tag.
- reloaded_tag.delete_instance()
+ assert reloaded_tag.id == tag.id
+ assert reloaded_tag.repository_id == context.repository.id
+ if reloaded_tag.lifetime_end_ts != tag.lifetime_end_ts:
+ return False
+
+ # Delete mapping rows.
+ TagToRepositoryTag.delete().where(
+ TagToRepositoryTag.repository_tag == reloaded_tag
+ ).execute()
+
+ # Delete the tag.
+ reloaded_tag.delete_instance()
def _check_manifest_used(manifest_id):
- assert manifest_id is not None
+ assert manifest_id is not None
- with db_transaction():
- # Check if the manifest is referenced by any other tag.
- try:
- Tag.select().where(Tag.manifest == manifest_id).get()
- return True
- except Tag.DoesNotExist:
- pass
+ with db_transaction():
+ # Check if the manifest is referenced by any other tag.
+ try:
+ Tag.select().where(Tag.manifest == manifest_id).get()
+ return True
+ except Tag.DoesNotExist:
+ pass
- # Check if the manifest is referenced as a child of another manifest.
- try:
- ManifestChild.select().where(ManifestChild.child_manifest == manifest_id).get()
- return True
- except ManifestChild.DoesNotExist:
- pass
+ # Check if the manifest is referenced as a child of another manifest.
+ try:
+ ManifestChild.select().where(
+ ManifestChild.child_manifest == manifest_id
+ ).get()
+ return True
+ except ManifestChild.DoesNotExist:
+ pass
- return False
+ return False
def _garbage_collect_manifest(manifest_id, context):
- assert manifest_id is not None
+ assert manifest_id is not None
- # Make sure the manifest isn't referenced.
- if _check_manifest_used(manifest_id):
- return False
-
- # Add the manifest's blobs to the context to be GCed.
- for manifest_blob in ManifestBlob.select().where(ManifestBlob.manifest == manifest_id):
- context.add_blob_id(manifest_blob.blob_id)
-
- # Retrieve the manifest's associated image, if any.
- try:
- legacy_image_id = ManifestLegacyImage.get(manifest=manifest_id).image_id
- context.add_legacy_image_id(legacy_image_id)
- except ManifestLegacyImage.DoesNotExist:
- legacy_image_id = None
-
- # Add child manifests to be GCed.
- for connector in ManifestChild.select().where(ManifestChild.manifest == manifest_id):
- context.add_manifest_id(connector.child_manifest_id)
-
- # Add the labels to be GCed.
- for manifest_label in ManifestLabel.select().where(ManifestLabel.manifest == manifest_id):
- context.add_label_id(manifest_label.label_id)
-
- # Delete the manifest.
- with db_transaction():
- try:
- manifest = Manifest.select().where(Manifest.id == manifest_id).get()
- except Manifest.DoesNotExist:
- return False
-
- assert manifest.id == manifest_id
- assert manifest.repository_id == context.repository.id
+ # Make sure the manifest isn't referenced.
if _check_manifest_used(manifest_id):
- return False
+ return False
- # Delete any label mappings.
- (TagManifestLabelMap
- .delete()
- .where(TagManifestLabelMap.manifest == manifest_id)
- .execute())
+ # Add the manifest's blobs to the context to be GCed.
+ for manifest_blob in ManifestBlob.select().where(
+ ManifestBlob.manifest == manifest_id
+ ):
+ context.add_blob_id(manifest_blob.blob_id)
- # Delete any mapping rows for the manifest.
- TagManifestToManifest.delete().where(TagManifestToManifest.manifest == manifest_id).execute()
+ # Retrieve the manifest's associated image, if any.
+ try:
+ legacy_image_id = ManifestLegacyImage.get(manifest=manifest_id).image_id
+ context.add_legacy_image_id(legacy_image_id)
+ except ManifestLegacyImage.DoesNotExist:
+ legacy_image_id = None
- # Delete any label rows.
- ManifestLabel.delete().where(ManifestLabel.manifest == manifest_id,
- ManifestLabel.repository == context.repository).execute()
+ # Add child manifests to be GCed.
+ for connector in ManifestChild.select().where(
+ ManifestChild.manifest == manifest_id
+ ):
+ context.add_manifest_id(connector.child_manifest_id)
- # Delete any child manifest rows.
- ManifestChild.delete().where(ManifestChild.manifest == manifest_id,
- ManifestChild.repository == context.repository).execute()
-
- # Delete the manifest blobs for the manifest.
- ManifestBlob.delete().where(ManifestBlob.manifest == manifest_id,
- ManifestBlob.repository == context.repository).execute()
-
- # Delete the manifest legacy image row.
- if legacy_image_id:
- (ManifestLegacyImage
- .delete()
- .where(ManifestLegacyImage.manifest == manifest_id,
- ManifestLegacyImage.repository == context.repository)
- .execute())
+ # Add the labels to be GCed.
+ for manifest_label in ManifestLabel.select().where(
+ ManifestLabel.manifest == manifest_id
+ ):
+ context.add_label_id(manifest_label.label_id)
# Delete the manifest.
- manifest.delete_instance()
+ with db_transaction():
+ try:
+ manifest = Manifest.select().where(Manifest.id == manifest_id).get()
+ except Manifest.DoesNotExist:
+ return False
- context.mark_manifest_removed(manifest)
- return True
+ assert manifest.id == manifest_id
+ assert manifest.repository_id == context.repository.id
+ if _check_manifest_used(manifest_id):
+ return False
+
+ # Delete any label mappings.
+ (
+ TagManifestLabelMap.delete()
+ .where(TagManifestLabelMap.manifest == manifest_id)
+ .execute()
+ )
+
+ # Delete any mapping rows for the manifest.
+ TagManifestToManifest.delete().where(
+ TagManifestToManifest.manifest == manifest_id
+ ).execute()
+
+ # Delete any label rows.
+ ManifestLabel.delete().where(
+ ManifestLabel.manifest == manifest_id,
+ ManifestLabel.repository == context.repository,
+ ).execute()
+
+ # Delete any child manifest rows.
+ ManifestChild.delete().where(
+ ManifestChild.manifest == manifest_id,
+ ManifestChild.repository == context.repository,
+ ).execute()
+
+ # Delete the manifest blobs for the manifest.
+ ManifestBlob.delete().where(
+ ManifestBlob.manifest == manifest_id,
+ ManifestBlob.repository == context.repository,
+ ).execute()
+
+ # Delete the manifest legacy image row.
+ if legacy_image_id:
+ (
+ ManifestLegacyImage.delete()
+ .where(
+ ManifestLegacyImage.manifest == manifest_id,
+ ManifestLegacyImage.repository == context.repository,
+ )
+ .execute()
+ )
+
+ # Delete the manifest.
+ manifest.delete_instance()
+
+ context.mark_manifest_removed(manifest)
+ return True
def _garbage_collect_legacy_manifest(legacy_manifest_id, context):
- assert legacy_manifest_id is not None
+ assert legacy_manifest_id is not None
- # Add the labels to be GCed.
- query = TagManifestLabel.select().where(TagManifestLabel.annotated == legacy_manifest_id)
- for manifest_label in query:
- context.add_label_id(manifest_label.label_id)
-
- # Delete the tag manifest.
- with db_transaction():
- try:
- tag_manifest = TagManifest.select().where(TagManifest.id == legacy_manifest_id).get()
- except TagManifest.DoesNotExist:
- return False
-
- assert tag_manifest.id == legacy_manifest_id
- assert tag_manifest.tag.repository_id == context.repository.id
-
- # Delete any label mapping rows.
- (TagManifestLabelMap
- .delete()
- .where(TagManifestLabelMap.tag_manifest == legacy_manifest_id)
- .execute())
-
- # Delete the label rows.
- TagManifestLabel.delete().where(TagManifestLabel.annotated == legacy_manifest_id).execute()
-
- # Delete the mapping row if it exists.
- try:
- tmt = (TagManifestToManifest
- .select()
- .where(TagManifestToManifest.tag_manifest == tag_manifest)
- .get())
- context.add_manifest_id(tmt.manifest_id)
- tmt.delete_instance()
- except TagManifestToManifest.DoesNotExist:
- pass
+ # Add the labels to be GCed.
+ query = TagManifestLabel.select().where(
+ TagManifestLabel.annotated == legacy_manifest_id
+ )
+ for manifest_label in query:
+ context.add_label_id(manifest_label.label_id)
# Delete the tag manifest.
- tag_manifest.delete_instance()
+ with db_transaction():
+ try:
+ tag_manifest = (
+ TagManifest.select().where(TagManifest.id == legacy_manifest_id).get()
+ )
+ except TagManifest.DoesNotExist:
+ return False
- return True
+ assert tag_manifest.id == legacy_manifest_id
+ assert tag_manifest.tag.repository_id == context.repository.id
+
+ # Delete any label mapping rows.
+ (
+ TagManifestLabelMap.delete()
+ .where(TagManifestLabelMap.tag_manifest == legacy_manifest_id)
+ .execute()
+ )
+
+ # Delete the label rows.
+ TagManifestLabel.delete().where(
+ TagManifestLabel.annotated == legacy_manifest_id
+ ).execute()
+
+ # Delete the mapping row if it exists.
+ try:
+ tmt = (
+ TagManifestToManifest.select()
+ .where(TagManifestToManifest.tag_manifest == tag_manifest)
+ .get()
+ )
+ context.add_manifest_id(tmt.manifest_id)
+ tmt.delete_instance()
+ except TagManifestToManifest.DoesNotExist:
+ pass
+
+ # Delete the tag manifest.
+ tag_manifest.delete_instance()
+
+ return True
def _check_image_used(legacy_image_id):
- assert legacy_image_id is not None
+ assert legacy_image_id is not None
- with db_transaction():
- # Check if the image is referenced by a manifest.
- try:
- ManifestLegacyImage.select().where(ManifestLegacyImage.image == legacy_image_id).get()
- return True
- except ManifestLegacyImage.DoesNotExist:
- pass
+ with db_transaction():
+ # Check if the image is referenced by a manifest.
+ try:
+ ManifestLegacyImage.select().where(
+ ManifestLegacyImage.image == legacy_image_id
+ ).get()
+ return True
+ except ManifestLegacyImage.DoesNotExist:
+ pass
- # Check if the image is referenced by a tag.
- try:
- RepositoryTag.select().where(RepositoryTag.image == legacy_image_id).get()
- return True
- except RepositoryTag.DoesNotExist:
- pass
+ # Check if the image is referenced by a tag.
+ try:
+ RepositoryTag.select().where(RepositoryTag.image == legacy_image_id).get()
+ return True
+ except RepositoryTag.DoesNotExist:
+ pass
- # Check if the image is referenced by another image.
- try:
- Image.select().where(Image.parent == legacy_image_id).get()
- return True
- except Image.DoesNotExist:
- pass
+ # Check if the image is referenced by another image.
+ try:
+ Image.select().where(Image.parent == legacy_image_id).get()
+ return True
+ except Image.DoesNotExist:
+ pass
- return False
+ return False
def _garbage_collect_legacy_image(legacy_image_id, context):
- assert legacy_image_id is not None
+ assert legacy_image_id is not None
- # Check if the image is referenced.
- if _check_image_used(legacy_image_id):
- return False
-
- # We have an unreferenced image. We can now delete it.
- # Grab any derived storage for the image.
- for derived in (DerivedStorageForImage
- .select()
- .where(DerivedStorageForImage.source_image == legacy_image_id)):
- context.add_blob_id(derived.derivative_id)
-
- try:
- image = Image.select().where(Image.id == legacy_image_id).get()
- except Image.DoesNotExist:
- return False
-
- assert image.repository_id == context.repository.id
-
- # Add the image's blob to be GCed.
- context.add_blob_id(image.storage_id)
-
- # If the image has a parent ID, add the parent for GC.
- if image.parent_id is not None:
- context.add_legacy_image_id(image.parent_id)
-
- # Delete the image.
- with db_transaction():
+ # Check if the image is referenced.
if _check_image_used(legacy_image_id):
- return False
+ return False
+
+ # We have an unreferenced image. We can now delete it.
+ # Grab any derived storage for the image.
+ for derived in DerivedStorageForImage.select().where(
+ DerivedStorageForImage.source_image == legacy_image_id
+ ):
+ context.add_blob_id(derived.derivative_id)
try:
- image = Image.select().where(Image.id == legacy_image_id).get()
+ image = Image.select().where(Image.id == legacy_image_id).get()
except Image.DoesNotExist:
- return False
+ return False
- assert image.id == legacy_image_id
assert image.repository_id == context.repository.id
- # Delete any derived storage for the image.
- (DerivedStorageForImage
- .delete()
- .where(DerivedStorageForImage.source_image == legacy_image_id)
- .execute())
+ # Add the image's blob to be GCed.
+ context.add_blob_id(image.storage_id)
- # Delete the image itself.
- image.delete_instance()
+ # If the image has a parent ID, add the parent for GC.
+ if image.parent_id is not None:
+ context.add_legacy_image_id(image.parent_id)
- context.mark_legacy_image_removed(image)
+ # Delete the image.
+ with db_transaction():
+ if _check_image_used(legacy_image_id):
+ return False
- if config.image_cleanup_callbacks:
- for callback in config.image_cleanup_callbacks:
- callback([image])
+ try:
+ image = Image.select().where(Image.id == legacy_image_id).get()
+ except Image.DoesNotExist:
+ return False
- return True
+ assert image.id == legacy_image_id
+ assert image.repository_id == context.repository.id
+
+ # Delete any derived storage for the image.
+ (
+ DerivedStorageForImage.delete()
+ .where(DerivedStorageForImage.source_image == legacy_image_id)
+ .execute()
+ )
+
+ # Delete the image itself.
+ image.delete_instance()
+
+ context.mark_legacy_image_removed(image)
+
+ if config.image_cleanup_callbacks:
+ for callback in config.image_cleanup_callbacks:
+ callback([image])
+
+ return True
def _check_label_used(label_id):
- assert label_id is not None
+ assert label_id is not None
- with db_transaction():
- # Check if the label is referenced by another manifest or tag manifest.
- try:
- ManifestLabel.select().where(ManifestLabel.label == label_id).get()
- return True
- except ManifestLabel.DoesNotExist:
- pass
+ with db_transaction():
+ # Check if the label is referenced by another manifest or tag manifest.
+ try:
+ ManifestLabel.select().where(ManifestLabel.label == label_id).get()
+ return True
+ except ManifestLabel.DoesNotExist:
+ pass
- try:
- TagManifestLabel.select().where(TagManifestLabel.label == label_id).get()
- return True
- except TagManifestLabel.DoesNotExist:
- pass
+ try:
+ TagManifestLabel.select().where(TagManifestLabel.label == label_id).get()
+ return True
+ except TagManifestLabel.DoesNotExist:
+ pass
- return False
+ return False
def _garbage_collect_label(label_id, context):
- assert label_id is not None
+ assert label_id is not None
- # We can now delete the label.
- with db_transaction():
- if _check_label_used(label_id):
- return False
+ # We can now delete the label.
+ with db_transaction():
+ if _check_label_used(label_id):
+ return False
- result = Label.delete().where(Label.id == label_id).execute() == 1
+ result = Label.delete().where(Label.id == label_id).execute() == 1
- if result:
- context.mark_label_id_removed(label_id)
+ if result:
+ context.mark_label_id_removed(label_id)
- return result
+ return result
diff --git a/data/model/health.py b/data/model/health.py
index b40cee025..1f0471d19 100644
--- a/data/model/health.py
+++ b/data/model/health.py
@@ -4,19 +4,20 @@ from data.database import TeamRole, validate_database_url
logger = logging.getLogger(__name__)
-def check_health(app_config):
- # Attempt to connect to the database first. If the DB is not responding,
- # using the validate_database_url will timeout quickly, as opposed to
- # making a normal connect which will just hang (thus breaking the health
- # check).
- try:
- validate_database_url(app_config['DB_URI'], {}, connect_timeout=3)
- except Exception as ex:
- return (False, 'Could not connect to the database: %s' % ex.message)
- # We will connect to the db, check that it contains some team role kinds
- try:
- okay = bool(list(TeamRole.select().limit(1)))
- return (okay, 'Could not connect to the database' if not okay else None)
- except Exception as ex:
- return (False, 'Could not connect to the database: %s' % ex.message)
+def check_health(app_config):
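+    """ Returns a (healthy, message) tuple; message is None when healthy. """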
+ # Attempt to connect to the database first. If the DB is not responding,
+    # validate_database_url will time out quickly, as opposed to a normal
+    # connect, which would just hang (and thus break the health check).
+ try:
+ validate_database_url(app_config["DB_URI"], {}, connect_timeout=3)
+ except Exception as ex:
+ return (False, "Could not connect to the database: %s" % ex.message)
+
+    # Now connect to the DB and check that it contains some team role kinds.
+ try:
+ okay = bool(list(TeamRole.select().limit(1)))
+ return (okay, "Could not connect to the database" if not okay else None)
+ except Exception as ex:
+ return (False, "Could not connect to the database: %s" % ex.message)
diff --git a/data/model/image.py b/data/model/image.py
index 1c6f1b952..f87f83643 100644
--- a/data/model/image.py
+++ b/data/model/image.py
@@ -8,509 +8,640 @@ import dateutil.parser
from peewee import JOIN, IntegrityError, fn
-from data.model import (DataModelException, db_transaction, _basequery, storage,
- InvalidImageException)
-from data.database import (Image, Repository, ImageStoragePlacement, Namespace, ImageStorage,
- ImageStorageLocation, RepositoryPermission, DerivedStorageForImage,
- ImageStorageTransformation, User)
+from data.model import (
+ DataModelException,
+ db_transaction,
+ _basequery,
+ storage,
+ InvalidImageException,
+)
+from data.database import (
+ Image,
+ Repository,
+ ImageStoragePlacement,
+ Namespace,
+ ImageStorage,
+ ImageStorageLocation,
+ RepositoryPermission,
+ DerivedStorageForImage,
+ ImageStorageTransformation,
+ User,
+)
from util.canonicaljson import canonicalize
logger = logging.getLogger(__name__)
+
def _namespace_id_for_username(username):
- try:
- return User.get(username=username).id
- except User.DoesNotExist:
- return None
+ try:
+ return User.get(username=username).id
+ except User.DoesNotExist:
+ return None
def get_image_with_storage(docker_image_id, storage_uuid):
- """ Returns the image with the given docker image ID and storage uuid or None if none.
+ """ Returns the image with the given docker image ID and storage uuid or None if none.
"""
- try:
- return (Image
- .select(Image, ImageStorage)
+ try:
+ return (
+ Image.select(Image, ImageStorage)
.join(ImageStorage)
- .where(Image.docker_image_id == docker_image_id,
- ImageStorage.uuid == storage_uuid)
- .get())
- except Image.DoesNotExist:
- return None
+ .where(
+ Image.docker_image_id == docker_image_id,
+ ImageStorage.uuid == storage_uuid,
+ )
+ .get()
+ )
+ except Image.DoesNotExist:
+ return None
def get_parent_images(namespace_name, repository_name, image_obj):
- """ Returns a list of parent Image objects starting with the most recent parent
+ """ Returns a list of parent Image objects starting with the most recent parent
and ending with the base layer. The images in this query will include the storage.
"""
- parents = image_obj.ancestors
+ parents = image_obj.ancestors
-  # Ancestors are in the format /{id}/{id}/.../{id}/, with each path section
-  # containing the database Id of the image row.
- parent_db_ids = parents.strip('/').split('/')
- if parent_db_ids == ['']:
- return []
+    # Ancestors are in the format /{id}/{id}/.../{id}/, with each path section
+    # containing the database Id of the image row.
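+    # For example (illustrative ids), an ancestors value of "/1/42/77/" means
+    # row 1 is the base layer and row 77 is this image's immediate parent.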
+ parent_db_ids = parents.strip("/").split("/")
+ if parent_db_ids == [""]:
+ return []
- def filter_to_parents(query):
- return query.where(Image.id << parent_db_ids)
+ def filter_to_parents(query):
+ return query.where(Image.id << parent_db_ids)
- parents = _get_repository_images_and_storages(namespace_name, repository_name,
- filter_to_parents)
- id_to_image = {unicode(image.id): image for image in parents}
- try:
- return [id_to_image[parent_id] for parent_id in reversed(parent_db_ids)]
- except KeyError as ke:
- logger.exception('Could not find an expected parent image for image %s', image_obj.id)
- raise DataModelException('Unknown parent image')
+ parents = _get_repository_images_and_storages(
+ namespace_name, repository_name, filter_to_parents
+ )
+ id_to_image = {unicode(image.id): image for image in parents}
+ try:
+ return [id_to_image[parent_id] for parent_id in reversed(parent_db_ids)]
+ except KeyError as ke:
+ logger.exception(
+ "Could not find an expected parent image for image %s", image_obj.id
+ )
+ raise DataModelException("Unknown parent image")
def get_placements_for_images(images):
- """ Returns the placements for the given images, as a map from image storage ID to placements. """
- if not images:
- return {}
+ """ Returns the placements for the given images, as a map from image storage ID to placements. """
+ if not images:
+ return {}
- query = (ImageStoragePlacement
- .select(ImageStoragePlacement, ImageStorageLocation, ImageStorage)
- .join(ImageStorageLocation)
- .switch(ImageStoragePlacement)
- .join(ImageStorage)
- .where(ImageStorage.id << [image.storage_id for image in images]))
+ query = (
+ ImageStoragePlacement.select(
+ ImageStoragePlacement, ImageStorageLocation, ImageStorage
+ )
+ .join(ImageStorageLocation)
+ .switch(ImageStoragePlacement)
+ .join(ImageStorage)
+ .where(ImageStorage.id << [image.storage_id for image in images])
+ )
- placement_map = defaultdict(list)
- for placement in query:
- placement_map[placement.storage.id].append(placement)
+ placement_map = defaultdict(list)
+ for placement in query:
+ placement_map[placement.storage.id].append(placement)
- return dict(placement_map)
+ return dict(placement_map)
def get_image_and_placements(namespace_name, repo_name, docker_image_id):
- """ Returns the repo image (with a storage object) and storage placements for the image
+ """ Returns the repo image (with a storage object) and storage placements for the image
       or (None, None) if none found.
"""
- repo_image = get_repo_image_and_storage(namespace_name, repo_name, docker_image_id)
- if repo_image is None:
- return (None, None)
+ repo_image = get_repo_image_and_storage(namespace_name, repo_name, docker_image_id)
+ if repo_image is None:
+ return (None, None)
- query = (ImageStoragePlacement
- .select(ImageStoragePlacement, ImageStorageLocation)
- .join(ImageStorageLocation)
- .switch(ImageStoragePlacement)
- .join(ImageStorage)
- .where(ImageStorage.id == repo_image.storage_id))
+ query = (
+ ImageStoragePlacement.select(ImageStoragePlacement, ImageStorageLocation)
+ .join(ImageStorageLocation)
+ .switch(ImageStoragePlacement)
+ .join(ImageStorage)
+ .where(ImageStorage.id == repo_image.storage_id)
+ )
- return repo_image, list(query)
+ return repo_image, list(query)
def get_repo_image(namespace_name, repository_name, docker_image_id):
- """ Returns the repository image with the given Docker image ID or None if none.
+ """ Returns the repository image with the given Docker image ID or None if none.
Does not include the storage object.
"""
- def limit_to_image_id(query):
- return query.where(Image.docker_image_id == docker_image_id).limit(1)
- query = _get_repository_images(namespace_name, repository_name, limit_to_image_id)
- try:
- return query.get()
- except Image.DoesNotExist:
- return None
+ def limit_to_image_id(query):
+ return query.where(Image.docker_image_id == docker_image_id).limit(1)
+
+ query = _get_repository_images(namespace_name, repository_name, limit_to_image_id)
+ try:
+ return query.get()
+ except Image.DoesNotExist:
+ return None
def get_repo_image_and_storage(namespace_name, repository_name, docker_image_id):
- """ Returns the repository image with the given Docker image ID or None if none.
+ """ Returns the repository image with the given Docker image ID or None if none.
Includes the storage object.
"""
- def limit_to_image_id(query):
- return query.where(Image.docker_image_id == docker_image_id)
- images = _get_repository_images_and_storages(namespace_name, repository_name, limit_to_image_id)
- if not images:
- return None
+ def limit_to_image_id(query):
+ return query.where(Image.docker_image_id == docker_image_id)
- return images[0]
+ images = _get_repository_images_and_storages(
+ namespace_name, repository_name, limit_to_image_id
+ )
+ if not images:
+ return None
+
+ return images[0]
def get_image_by_id(namespace_name, repository_name, docker_image_id):
- """ Returns the repository image with the given Docker image ID or raises if not found.
+ """ Returns the repository image with the given Docker image ID or raises if not found.
Includes the storage object.
"""
- image = get_repo_image_and_storage(namespace_name, repository_name, docker_image_id)
- if not image:
- raise InvalidImageException('Unable to find image \'%s\' for repo \'%s/%s\'' %
- (docker_image_id, namespace_name, repository_name))
- return image
+ image = get_repo_image_and_storage(namespace_name, repository_name, docker_image_id)
+ if not image:
+ raise InvalidImageException(
+ "Unable to find image '%s' for repo '%s/%s'"
+ % (docker_image_id, namespace_name, repository_name)
+ )
+ return image
-def _get_repository_images_and_storages(namespace_name, repository_name, query_modifier):
- query = (Image
- .select(Image, ImageStorage)
- .join(ImageStorage)
- .switch(Image)
- .join(Repository)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .where(Repository.name == repository_name, Namespace.username == namespace_name))
+def _get_repository_images_and_storages(
+ namespace_name, repository_name, query_modifier
+):
+ query = (
+ Image.select(Image, ImageStorage)
+ .join(ImageStorage)
+ .switch(Image)
+ .join(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .where(Repository.name == repository_name, Namespace.username == namespace_name)
+ )
- query = query_modifier(query)
- return query
+ query = query_modifier(query)
+ return query
def _get_repository_images(namespace_name, repository_name, query_modifier):
- query = (Image
- .select()
- .join(Repository)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .where(Repository.name == repository_name, Namespace.username == namespace_name))
+ query = (
+ Image.select()
+ .join(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .where(Repository.name == repository_name, Namespace.username == namespace_name)
+ )
- query = query_modifier(query)
- return query
+ query = query_modifier(query)
+ return query
def lookup_repository_images(repo, docker_image_ids):
- return (Image
- .select(Image, ImageStorage)
- .join(ImageStorage)
- .where(Image.repository == repo, Image.docker_image_id << docker_image_ids))
+ return (
+ Image.select(Image, ImageStorage)
+ .join(ImageStorage)
+ .where(Image.repository == repo, Image.docker_image_id << docker_image_ids)
+ )
def get_repository_images_without_placements(repo_obj, with_ancestor=None):
- query = (Image
- .select(Image, ImageStorage)
- .join(ImageStorage)
- .where(Image.repository == repo_obj))
+ query = (
+ Image.select(Image, ImageStorage)
+ .join(ImageStorage)
+ .where(Image.repository == repo_obj)
+ )
- if with_ancestor:
- ancestors_string = '%s%s/' % (with_ancestor.ancestors, with_ancestor.id)
- query = query.where((Image.ancestors ** (ancestors_string + '%')) |
- (Image.id == with_ancestor.id))
+ if with_ancestor:
+ ancestors_string = "%s%s/" % (with_ancestor.ancestors, with_ancestor.id)
+ query = query.where(
+ (Image.ancestors ** (ancestors_string + "%"))
+ | (Image.id == with_ancestor.id)
+ )
- return query
+ return query
def get_repository_images(namespace_name, repository_name):
- """ Returns all the repository images in the repository. Does not include storage objects. """
- return _get_repository_images(namespace_name, repository_name, lambda q: q)
+ """ Returns all the repository images in the repository. Does not include storage objects. """
+ return _get_repository_images(namespace_name, repository_name, lambda q: q)
-def __translate_ancestry(old_ancestry, translations, repo_obj, username, preferred_location):
- if old_ancestry == '/':
- return '/'
+def __translate_ancestry(
+ old_ancestry, translations, repo_obj, username, preferred_location
+):
+ if old_ancestry == "/":
+ return "/"
- def translate_id(old_id, docker_image_id):
- logger.debug('Translating id: %s', old_id)
- if old_id not in translations:
- image_in_repo = find_create_or_link_image(docker_image_id, repo_obj, username, translations,
- preferred_location)
- translations[old_id] = image_in_repo.id
- return translations[old_id]
+ def translate_id(old_id, docker_image_id):
+ logger.debug("Translating id: %s", old_id)
+ if old_id not in translations:
+ image_in_repo = find_create_or_link_image(
+ docker_image_id, repo_obj, username, translations, preferred_location
+ )
+ translations[old_id] = image_in_repo.id
+ return translations[old_id]
- # Select all the ancestor Docker IDs in a single query.
- old_ids = [int(id_str) for id_str in old_ancestry.split('/')[1:-1]]
- query = Image.select(Image.id, Image.docker_image_id).where(Image.id << old_ids)
- old_images = {i.id: i.docker_image_id for i in query}
+ # Select all the ancestor Docker IDs in a single query.
+ old_ids = [int(id_str) for id_str in old_ancestry.split("/")[1:-1]]
+ query = Image.select(Image.id, Image.docker_image_id).where(Image.id << old_ids)
+ old_images = {i.id: i.docker_image_id for i in query}
- # Translate the old images into new ones.
- new_ids = [str(translate_id(old_id, old_images[old_id])) for old_id in old_ids]
- return '/%s/' % '/'.join(new_ids)
+ # Translate the old images into new ones.
+ new_ids = [str(translate_id(old_id, old_images[old_id])) for old_id in old_ids]
+ return "/%s/" % "/".join(new_ids)
-def _find_or_link_image(existing_image, repo_obj, username, translations, preferred_location):
- with db_transaction():
- # Check for an existing image, under the transaction, to make sure it doesn't already exist.
- repo_image = get_repo_image(repo_obj.namespace_user.username, repo_obj.name,
- existing_image.docker_image_id)
+def _find_or_link_image(
+ existing_image, repo_obj, username, translations, preferred_location
+):
+ with db_transaction():
+ # Check for an existing image, under the transaction, to make sure it doesn't already exist.
+ repo_image = get_repo_image(
+ repo_obj.namespace_user.username,
+ repo_obj.name,
+ existing_image.docker_image_id,
+ )
+ if repo_image:
+ return repo_image
+
+ # Make sure the existing base image still exists.
+ try:
+ to_copy = (
+ Image.select()
+ .join(ImageStorage)
+ .where(Image.id == existing_image.id)
+ .get()
+ )
+
+ msg = "Linking image to existing storage with docker id: %s and uuid: %s"
+ logger.debug(msg, existing_image.docker_image_id, to_copy.storage.uuid)
+
+ new_image_ancestry = __translate_ancestry(
+ to_copy.ancestors, translations, repo_obj, username, preferred_location
+ )
+
+ copied_storage = to_copy.storage
+
+ translated_parent_id = None
+ if new_image_ancestry != "/":
+ translated_parent_id = int(new_image_ancestry.split("/")[-2])
+
+ new_image = Image.create(
+ docker_image_id=existing_image.docker_image_id,
+ repository=repo_obj,
+ storage=copied_storage,
+ ancestors=new_image_ancestry,
+ command=existing_image.command,
+ created=existing_image.created,
+ comment=existing_image.comment,
+ v1_json_metadata=existing_image.v1_json_metadata,
+ aggregate_size=existing_image.aggregate_size,
+ parent=translated_parent_id,
+ v1_checksum=existing_image.v1_checksum,
+ )
+
+ logger.debug(
+ "Storing translation %s -> %s", existing_image.id, new_image.id
+ )
+ translations[existing_image.id] = new_image.id
+ return new_image
+ except Image.DoesNotExist:
+ return None
+
+
+def find_create_or_link_image(
+ docker_image_id, repo_obj, username, translations, preferred_location
+):
+
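+    # Resolution order below: (1) an image already present in this repository,
+    # (2) an existing storage visible to the user, linked in via
+    # _find_or_link_image, and (3) a brand new V1 storage as a fallback.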
+ # First check for the image existing in the repository. If found, we simply return it.
+ repo_image = get_repo_image(
+ repo_obj.namespace_user.username, repo_obj.name, docker_image_id
+ )
if repo_image:
- return repo_image
+ return repo_image
- # Make sure the existing base image still exists.
+ # We next check to see if there is an existing storage the new image can link to.
+ existing_image_query = (
+ Image.select(Image, ImageStorage)
+ .distinct()
+ .join(ImageStorage)
+ .switch(Image)
+ .join(Repository)
+ .join(RepositoryPermission, JOIN.LEFT_OUTER)
+ .switch(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .where(
+ ImageStorage.uploading == False, Image.docker_image_id == docker_image_id
+ )
+ )
+
+ existing_image_query = _basequery.filter_to_repos_for_user(
+ existing_image_query, _namespace_id_for_username(username)
+ )
+
+ # If there is an existing image, we try to translate its ancestry and copy its storage.
+ new_image = None
try:
- to_copy = Image.select().join(ImageStorage).where(Image.id == existing_image.id).get()
+ logger.debug("Looking up existing image for ID: %s", docker_image_id)
+ existing_image = existing_image_query.get()
- msg = 'Linking image to existing storage with docker id: %s and uuid: %s'
- logger.debug(msg, existing_image.docker_image_id, to_copy.storage.uuid)
-
- new_image_ancestry = __translate_ancestry(to_copy.ancestors, translations, repo_obj,
- username, preferred_location)
-
- copied_storage = to_copy.storage
-
- translated_parent_id = None
- if new_image_ancestry != '/':
- translated_parent_id = int(new_image_ancestry.split('/')[-2])
-
- new_image = Image.create(docker_image_id=existing_image.docker_image_id,
- repository=repo_obj,
- storage=copied_storage,
- ancestors=new_image_ancestry,
- command=existing_image.command,
- created=existing_image.created,
- comment=existing_image.comment,
- v1_json_metadata=existing_image.v1_json_metadata,
- aggregate_size=existing_image.aggregate_size,
- parent=translated_parent_id,
- v1_checksum=existing_image.v1_checksum)
-
-
- logger.debug('Storing translation %s -> %s', existing_image.id, new_image.id)
- translations[existing_image.id] = new_image.id
- return new_image
+ logger.debug(
+ "Existing image %s found for ID: %s", existing_image.id, docker_image_id
+ )
+ new_image = _find_or_link_image(
+ existing_image, repo_obj, username, translations, preferred_location
+ )
+ if new_image:
+ return new_image
except Image.DoesNotExist:
- return None
+ logger.debug("No existing image found for ID: %s", docker_image_id)
+
+ # Otherwise, create a new storage directly.
+ with db_transaction():
+ # Final check for an existing image, under the transaction.
+ repo_image = get_repo_image(
+ repo_obj.namespace_user.username, repo_obj.name, docker_image_id
+ )
+ if repo_image:
+ return repo_image
+
+ logger.debug("Creating new storage for docker id: %s", docker_image_id)
+ new_storage = storage.create_v1_storage(preferred_location)
+
+ return Image.create(
+ docker_image_id=docker_image_id,
+ repository=repo_obj,
+ storage=new_storage,
+ ancestors="/",
+ )
-def find_create_or_link_image(docker_image_id, repo_obj, username, translations,
- preferred_location):
-
- # First check for the image existing in the repository. If found, we simply return it.
- repo_image = get_repo_image(repo_obj.namespace_user.username, repo_obj.name,
- docker_image_id)
- if repo_image:
- return repo_image
-
- # We next check to see if there is an existing storage the new image can link to.
- existing_image_query = (Image
- .select(Image, ImageStorage)
- .distinct()
- .join(ImageStorage)
- .switch(Image)
- .join(Repository)
- .join(RepositoryPermission, JOIN.LEFT_OUTER)
- .switch(Repository)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .where(ImageStorage.uploading == False,
- Image.docker_image_id == docker_image_id))
-
- existing_image_query = _basequery.filter_to_repos_for_user(existing_image_query,
- _namespace_id_for_username(username))
-
- # If there is an existing image, we try to translate its ancestry and copy its storage.
- new_image = None
- try:
- logger.debug('Looking up existing image for ID: %s', docker_image_id)
- existing_image = existing_image_query.get()
-
- logger.debug('Existing image %s found for ID: %s', existing_image.id, docker_image_id)
- new_image = _find_or_link_image(existing_image, repo_obj, username, translations,
- preferred_location)
- if new_image:
- return new_image
- except Image.DoesNotExist:
- logger.debug('No existing image found for ID: %s', docker_image_id)
-
- # Otherwise, create a new storage directly.
- with db_transaction():
- # Final check for an existing image, under the transaction.
- repo_image = get_repo_image(repo_obj.namespace_user.username, repo_obj.name,
- docker_image_id)
- if repo_image:
- return repo_image
-
- logger.debug('Creating new storage for docker id: %s', docker_image_id)
- new_storage = storage.create_v1_storage(preferred_location)
-
- return Image.create(docker_image_id=docker_image_id,
- repository=repo_obj, storage=new_storage,
- ancestors='/')
-
-
-def set_image_metadata(docker_image_id, namespace_name, repository_name, created_date_str, comment,
- command, v1_json_metadata, parent=None):
- """ Sets metadata that is specific to how a binary piece of storage fits into the layer tree.
+def set_image_metadata(
+ docker_image_id,
+ namespace_name,
+ repository_name,
+ created_date_str,
+ comment,
+ command,
+ v1_json_metadata,
+ parent=None,
+):
+ """ Sets metadata that is specific to how a binary piece of storage fits into the layer tree.
"""
- with db_transaction():
- try:
- fetched = (Image
- .select(Image, ImageStorage)
- .join(Repository)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .switch(Image)
- .join(ImageStorage)
- .where(Repository.name == repository_name, Namespace.username == namespace_name,
- Image.docker_image_id == docker_image_id)
- .get())
- except Image.DoesNotExist:
- raise DataModelException('No image with specified id and repository')
+ with db_transaction():
+ try:
+ fetched = (
+ Image.select(Image, ImageStorage)
+ .join(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .switch(Image)
+ .join(ImageStorage)
+ .where(
+ Repository.name == repository_name,
+ Namespace.username == namespace_name,
+ Image.docker_image_id == docker_image_id,
+ )
+ .get()
+ )
+ except Image.DoesNotExist:
+ raise DataModelException("No image with specified id and repository")
- fetched.created = datetime.now()
- if created_date_str is not None:
- try:
- fetched.created = dateutil.parser.parse(created_date_str).replace(tzinfo=None)
- except:
- # parse raises different exceptions, so we cannot use a specific kind of handler here.
- pass
+ fetched.created = datetime.now()
+ if created_date_str is not None:
+ try:
+ fetched.created = dateutil.parser.parse(created_date_str).replace(
+ tzinfo=None
+ )
+            except Exception:
+                # parse raises several different exception types, so a broad
+                # handler is used here rather than a specific one.
+                pass
- # We cleanup any old checksum in case it's a retry after a fail
- fetched.v1_checksum = None
- fetched.comment = comment
- fetched.command = command
- fetched.v1_json_metadata = v1_json_metadata
+        # We clean up any old checksum in case this is a retry after a failure.
+ fetched.v1_checksum = None
+ fetched.comment = comment
+ fetched.command = command
+ fetched.v1_json_metadata = v1_json_metadata
- if parent:
- fetched.ancestors = '%s%s/' % (parent.ancestors, parent.id)
- fetched.parent = parent
+ if parent:
+ fetched.ancestors = "%s%s/" % (parent.ancestors, parent.id)
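+            # e.g. a parent with ancestors "/1/42/" and id 77 yields "/1/42/77/".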
+ fetched.parent = parent
- fetched.save()
- return fetched
+ fetched.save()
+ return fetched
def get_image(repo, docker_image_id):
- try:
- return (Image
- .select(Image, ImageStorage)
+ try:
+ return (
+ Image.select(Image, ImageStorage)
.join(ImageStorage)
.where(Image.docker_image_id == docker_image_id, Image.repository == repo)
- .get())
- except Image.DoesNotExist:
- return None
+ .get()
+ )
+ except Image.DoesNotExist:
+ return None
def get_image_by_db_id(id):
- try:
- return Image.get(id=id)
- except Image.DoesNotExist:
- return None
+ try:
+ return Image.get(id=id)
+ except Image.DoesNotExist:
+ return None
-def synthesize_v1_image(repo, image_storage_id, storage_image_size, docker_image_id,
- created_date_str, comment, command, v1_json_metadata, parent_image=None):
- """ Find an existing image with this docker image id, and if none exists, write one with the
+def synthesize_v1_image(
+ repo,
+ image_storage_id,
+ storage_image_size,
+ docker_image_id,
+ created_date_str,
+ comment,
+ command,
+ v1_json_metadata,
+ parent_image=None,
+):
+ """ Find an existing image with this docker image id, and if none exists, write one with the
specified metadata.
"""
- ancestors = '/'
- if parent_image is not None:
- ancestors = '{0}{1}/'.format(parent_image.ancestors, parent_image.id)
+ ancestors = "/"
+ if parent_image is not None:
+ ancestors = "{0}{1}/".format(parent_image.ancestors, parent_image.id)
+
+ created = None
+ if created_date_str is not None:
+ try:
+ created = dateutil.parser.parse(created_date_str).replace(tzinfo=None)
+        except Exception:
+            # parse raises several different exception types, so a broad
+            # handler is used here rather than a specific one.
+            pass
+
+ # Get the aggregate size for the image.
+ aggregate_size = _basequery.calculate_image_aggregate_size(
+ ancestors, storage_image_size, parent_image
+ )
- created = None
- if created_date_str is not None:
try:
- created = dateutil.parser.parse(created_date_str).replace(tzinfo=None)
- except:
- # parse raises different exceptions, so we cannot use a specific kind of handler here.
- pass
-
- # Get the aggregate size for the image.
- aggregate_size = _basequery.calculate_image_aggregate_size(ancestors, storage_image_size,
- parent_image)
-
- try:
- return Image.create(docker_image_id=docker_image_id, ancestors=ancestors, comment=comment,
- command=command, v1_json_metadata=v1_json_metadata, created=created,
- storage=image_storage_id, repository=repo, parent=parent_image,
- aggregate_size=aggregate_size)
- except IntegrityError:
- return Image.get(docker_image_id=docker_image_id, repository=repo)
+ return Image.create(
+ docker_image_id=docker_image_id,
+ ancestors=ancestors,
+ comment=comment,
+ command=command,
+ v1_json_metadata=v1_json_metadata,
+ created=created,
+ storage=image_storage_id,
+ repository=repo,
+ parent=parent_image,
+ aggregate_size=aggregate_size,
+ )
+ except IntegrityError:
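+        # The image row already exists (e.g. created by a concurrent request),
+        # so return that row instead of failing.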
+ return Image.get(docker_image_id=docker_image_id, repository=repo)
def ensure_image_locations(*names):
- with db_transaction():
- locations = ImageStorageLocation.select().where(ImageStorageLocation.name << names)
+ with db_transaction():
+ locations = ImageStorageLocation.select().where(
+ ImageStorageLocation.name << names
+ )
- insert_names = list(names)
+ insert_names = list(names)
- for location in locations:
- insert_names.remove(location.name)
+ for location in locations:
+ insert_names.remove(location.name)
- if not insert_names:
- return
+ if not insert_names:
+ return
- data = [{'name': name} for name in insert_names]
- ImageStorageLocation.insert_many(data).execute()
+ data = [{"name": name} for name in insert_names]
+ ImageStorageLocation.insert_many(data).execute()
def get_max_id_for_sec_scan():
- """ Gets the maximum id for a clair sec scan """
- return Image.select(fn.Max(Image.id)).scalar()
+ """ Gets the maximum id for a clair sec scan """
+ return Image.select(fn.Max(Image.id)).scalar()
def get_min_id_for_sec_scan(version):
- """ Gets the minimum id for a clair sec scan """
- return (Image
- .select(fn.Min(Image.id))
- .where(Image.security_indexed_engine < version)
- .scalar())
+ """ Gets the minimum id for a clair sec scan """
+ return (
+ Image.select(fn.Min(Image.id))
+ .where(Image.security_indexed_engine < version)
+ .scalar()
+ )
def total_image_count():
- """ Returns the total number of images in DB """
- return Image.select().count()
+ """ Returns the total number of images in DB """
+ return Image.select().count()
def get_image_pk_field():
- """ Returns the primary key for Image DB model """
- return Image.id
+ """ Returns the primary key for Image DB model """
+ return Image.id
def get_images_eligible_for_scan(clair_version):
- """ Returns a query that gives all images eligible for a clair scan """
- return (get_image_with_storage_and_parent_base()
- .where(Image.security_indexed_engine < clair_version)
- .where(ImageStorage.uploading == False))
+ """ Returns a query that gives all images eligible for a clair scan """
+ return (
+ get_image_with_storage_and_parent_base()
+ .where(Image.security_indexed_engine < clair_version)
+ .where(ImageStorage.uploading == False)
+ )
def get_image_with_storage_and_parent_base():
- Parent = Image.alias()
- ParentImageStorage = ImageStorage.alias()
+ Parent = Image.alias()
+ ParentImageStorage = ImageStorage.alias()
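+    # The aliases allow a self-join: each image is joined to its parent row and
+    # that parent's storage within a single query.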
- return (Image
- .select(Image, ImageStorage, Parent, ParentImageStorage)
- .join(ImageStorage)
- .switch(Image)
- .join(Parent, JOIN.LEFT_OUTER, on=(Image.parent == Parent.id))
- .join(ParentImageStorage, JOIN.LEFT_OUTER, on=(ParentImageStorage.id == Parent.storage)))
+ return (
+ Image.select(Image, ImageStorage, Parent, ParentImageStorage)
+ .join(ImageStorage)
+ .switch(Image)
+ .join(Parent, JOIN.LEFT_OUTER, on=(Image.parent == Parent.id))
+ .join(
+ ParentImageStorage,
+ JOIN.LEFT_OUTER,
+ on=(ParentImageStorage.id == Parent.storage),
+ )
+ )
def set_secscan_status(image, indexed, version):
- return (Image
- .update(security_indexed=indexed, security_indexed_engine=version)
- .where(Image.id == image.id)
- .where((Image.security_indexed_engine != version) | (Image.security_indexed != indexed))
- .execute()) != 0
+ return (
+ Image.update(security_indexed=indexed, security_indexed_engine=version)
+ .where(Image.id == image.id)
+ .where(
+ (Image.security_indexed_engine != version)
+ | (Image.security_indexed != indexed)
+ )
+ .execute()
+ ) != 0
def _get_uniqueness_hash(varying_metadata):
- if not varying_metadata:
- return None
+ if not varying_metadata:
+ return None
- return hashlib.sha256(json.dumps(canonicalize(varying_metadata))).hexdigest()
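+    # Canonicalizing first makes the hash independent of key ordering, so
+    # logically identical metadata always maps to the same uniqueness hash.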
+ return hashlib.sha256(json.dumps(canonicalize(varying_metadata))).hexdigest()
-def find_or_create_derived_storage(source_image, transformation_name, preferred_location,
- varying_metadata=None):
- existing = find_derived_storage_for_image(source_image, transformation_name, varying_metadata)
- if existing is not None:
- return existing
+def find_or_create_derived_storage(
+ source_image, transformation_name, preferred_location, varying_metadata=None
+):
+ existing = find_derived_storage_for_image(
+ source_image, transformation_name, varying_metadata
+ )
+ if existing is not None:
+ return existing
- uniqueness_hash = _get_uniqueness_hash(varying_metadata)
- trans = ImageStorageTransformation.get(name=transformation_name)
- new_storage = storage.create_v1_storage(preferred_location)
+ uniqueness_hash = _get_uniqueness_hash(varying_metadata)
+ trans = ImageStorageTransformation.get(name=transformation_name)
+ new_storage = storage.create_v1_storage(preferred_location)
- try:
- derived = DerivedStorageForImage.create(source_image=source_image, derivative=new_storage,
- transformation=trans, uniqueness_hash=uniqueness_hash)
- except IntegrityError:
- # Storage was created while this method executed. Just return the existing.
- ImageStoragePlacement.delete().where(ImageStoragePlacement.storage == new_storage).execute()
- new_storage.delete_instance()
- return find_derived_storage_for_image(source_image, transformation_name, varying_metadata)
+ try:
+ derived = DerivedStorageForImage.create(
+ source_image=source_image,
+ derivative=new_storage,
+ transformation=trans,
+ uniqueness_hash=uniqueness_hash,
+ )
+ except IntegrityError:
+ # Storage was created while this method executed. Just return the existing.
+ ImageStoragePlacement.delete().where(
+ ImageStoragePlacement.storage == new_storage
+ ).execute()
+ new_storage.delete_instance()
+ return find_derived_storage_for_image(
+ source_image, transformation_name, varying_metadata
+ )
- return derived
+ return derived
-def find_derived_storage_for_image(source_image, transformation_name, varying_metadata=None):
- uniqueness_hash = _get_uniqueness_hash(varying_metadata)
+def find_derived_storage_for_image(
+ source_image, transformation_name, varying_metadata=None
+):
+ uniqueness_hash = _get_uniqueness_hash(varying_metadata)
- try:
- found = (DerivedStorageForImage
- .select(ImageStorage, DerivedStorageForImage)
- .join(ImageStorage)
- .switch(DerivedStorageForImage)
- .join(ImageStorageTransformation)
- .where(DerivedStorageForImage.source_image == source_image,
- ImageStorageTransformation.name == transformation_name,
- DerivedStorageForImage.uniqueness_hash == uniqueness_hash)
- .get())
- return found
- except DerivedStorageForImage.DoesNotExist:
- return None
+ try:
+ found = (
+ DerivedStorageForImage.select(ImageStorage, DerivedStorageForImage)
+ .join(ImageStorage)
+ .switch(DerivedStorageForImage)
+ .join(ImageStorageTransformation)
+ .where(
+ DerivedStorageForImage.source_image == source_image,
+ ImageStorageTransformation.name == transformation_name,
+ DerivedStorageForImage.uniqueness_hash == uniqueness_hash,
+ )
+ .get()
+ )
+ return found
+ except DerivedStorageForImage.DoesNotExist:
+ return None
def delete_derived_storage(derived_storage):
- derived_storage.derivative.delete_instance(recursive=True)
+ derived_storage.derivative.delete_instance(recursive=True)
diff --git a/data/model/label.py b/data/model/label.py
index fce7479ba..0268363b2 100644
--- a/data/model/label.py
+++ b/data/model/label.py
@@ -2,9 +2,21 @@ import logging
from cachetools.func import lru_cache
-from data.database import (Label, TagManifestLabel, MediaType, LabelSourceType, db_transaction,
- ManifestLabel, TagManifestLabelMap, TagManifestToManifest)
-from data.model import InvalidLabelKeyException, InvalidMediaTypeException, DataModelException
+from data.database import (
+ Label,
+ TagManifestLabel,
+ MediaType,
+ LabelSourceType,
+ db_transaction,
+ ManifestLabel,
+ TagManifestLabelMap,
+ TagManifestToManifest,
+)
+from data.model import (
+ InvalidLabelKeyException,
+ InvalidMediaTypeException,
+ DataModelException,
+)
from data.text import prefix_search
from util.validation import validate_label_key
from util.validation import is_json
@@ -14,130 +26,147 @@ logger = logging.getLogger(__name__)
@lru_cache(maxsize=1)
def get_label_source_types():
- source_type_map = {}
- for kind in LabelSourceType.select():
- source_type_map[kind.id] = kind.name
- source_type_map[kind.name] = kind.id
+ source_type_map = {}
+ for kind in LabelSourceType.select():
+ source_type_map[kind.id] = kind.name
+ source_type_map[kind.name] = kind.id
- return source_type_map
+ return source_type_map
@lru_cache(maxsize=1)
def get_media_types():
- media_type_map = {}
- for kind in MediaType.select():
- media_type_map[kind.id] = kind.name
- media_type_map[kind.name] = kind.id
+ media_type_map = {}
+ for kind in MediaType.select():
+ media_type_map[kind.id] = kind.name
+ media_type_map[kind.name] = kind.id
- return media_type_map
+ return media_type_map
def _get_label_source_type_id(name):
- kinds = get_label_source_types()
- return kinds[name]
+ kinds = get_label_source_types()
+ return kinds[name]
def _get_media_type_id(name):
- kinds = get_media_types()
- return kinds[name]
+ kinds = get_media_types()
+ return kinds[name]
-def create_manifest_label(tag_manifest, key, value, source_type_name, media_type_name=None):
- """ Creates a new manifest label on a specific tag manifest. """
- if not key:
- raise InvalidLabelKeyException()
+def create_manifest_label(
+ tag_manifest, key, value, source_type_name, media_type_name=None
+):
+ """ Creates a new manifest label on a specific tag manifest. """
+ if not key:
+ raise InvalidLabelKeyException()
- # Note that we don't prevent invalid label names coming from the manifest to be stored, as Docker
- # does not currently prevent them from being put into said manifests.
- if not validate_label_key(key) and source_type_name != 'manifest':
- raise InvalidLabelKeyException()
+    # Note that we don't prevent invalid label names coming from the manifest
+    # from being stored, as Docker does not currently prevent them from being
+    # put into such manifests.
+ if not validate_label_key(key) and source_type_name != "manifest":
+ raise InvalidLabelKeyException()
- # Find the matching media type. If none specified, we infer.
- if media_type_name is None:
- media_type_name = 'text/plain'
- if is_json(value):
- media_type_name = 'application/json'
+ # Find the matching media type. If none specified, we infer.
+ if media_type_name is None:
+ media_type_name = "text/plain"
+ if is_json(value):
+ media_type_name = "application/json"
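+    # e.g. a value of '{"severity": "low"}' is stored as application/json,
+    # while a plain value such as "low" falls back to text/plain.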
- media_type_id = _get_media_type_id(media_type_name)
- if media_type_id is None:
- raise InvalidMediaTypeException()
+ media_type_id = _get_media_type_id(media_type_name)
+ if media_type_id is None:
+ raise InvalidMediaTypeException()
- source_type_id = _get_label_source_type_id(source_type_name)
+ source_type_id = _get_label_source_type_id(source_type_name)
- with db_transaction():
- label = Label.create(key=key, value=value, source_type=source_type_id, media_type=media_type_id)
- tag_manifest_label = TagManifestLabel.create(annotated=tag_manifest, label=label,
- repository=tag_manifest.tag.repository)
- try:
- mapping_row = TagManifestToManifest.get(tag_manifest=tag_manifest)
- if mapping_row.manifest:
- manifest_label = ManifestLabel.create(manifest=mapping_row.manifest, label=label,
- repository=tag_manifest.tag.repository)
- TagManifestLabelMap.create(manifest_label=manifest_label,
- tag_manifest_label=tag_manifest_label,
- label=label,
- manifest=mapping_row.manifest,
- tag_manifest=tag_manifest)
- except TagManifestToManifest.DoesNotExist:
- pass
+ with db_transaction():
+ label = Label.create(
+ key=key, value=value, source_type=source_type_id, media_type=media_type_id
+ )
+ tag_manifest_label = TagManifestLabel.create(
+ annotated=tag_manifest, label=label, repository=tag_manifest.tag.repository
+ )
+ try:
+ mapping_row = TagManifestToManifest.get(tag_manifest=tag_manifest)
+ if mapping_row.manifest:
+ manifest_label = ManifestLabel.create(
+ manifest=mapping_row.manifest,
+ label=label,
+ repository=tag_manifest.tag.repository,
+ )
+ TagManifestLabelMap.create(
+ manifest_label=manifest_label,
+ tag_manifest_label=tag_manifest_label,
+ label=label,
+ manifest=mapping_row.manifest,
+ tag_manifest=tag_manifest,
+ )
+ except TagManifestToManifest.DoesNotExist:
+ pass
- return label
+ return label
def list_manifest_labels(tag_manifest, prefix_filter=None):
- """ Lists all labels found on the given tag manifest. """
- query = (Label.select(Label, MediaType)
- .join(MediaType)
- .switch(Label)
- .join(LabelSourceType)
- .switch(Label)
- .join(TagManifestLabel)
- .where(TagManifestLabel.annotated == tag_manifest))
+ """ Lists all labels found on the given tag manifest. """
+ query = (
+ Label.select(Label, MediaType)
+ .join(MediaType)
+ .switch(Label)
+ .join(LabelSourceType)
+ .switch(Label)
+ .join(TagManifestLabel)
+ .where(TagManifestLabel.annotated == tag_manifest)
+ )
- if prefix_filter is not None:
- query = query.where(prefix_search(Label.key, prefix_filter))
+ if prefix_filter is not None:
+ query = query.where(prefix_search(Label.key, prefix_filter))
- return query
+ return query
def get_manifest_label(label_uuid, tag_manifest):
- """ Retrieves the manifest label on the tag manifest with the given ID. """
- try:
- return (Label.select(Label, LabelSourceType)
+ """ Retrieves the manifest label on the tag manifest with the given ID. """
+ try:
+ return (
+ Label.select(Label, LabelSourceType)
.join(LabelSourceType)
.where(Label.uuid == label_uuid)
.switch(Label)
.join(TagManifestLabel)
.where(TagManifestLabel.annotated == tag_manifest)
- .get())
- except Label.DoesNotExist:
- return None
+ .get()
+ )
+ except Label.DoesNotExist:
+ return None
def delete_manifest_label(label_uuid, tag_manifest):
- """ Deletes the manifest label on the tag manifest with the given ID. """
+ """ Deletes the manifest label on the tag manifest with the given ID. """
- # Find the label itself.
- label = get_manifest_label(label_uuid, tag_manifest)
- if label is None:
- return None
+ # Find the label itself.
+ label = get_manifest_label(label_uuid, tag_manifest)
+ if label is None:
+ return None
- if not label.source_type.mutable:
- raise DataModelException('Cannot delete immutable label')
+ if not label.source_type.mutable:
+ raise DataModelException("Cannot delete immutable label")
- # Delete the mapping records and label.
- (TagManifestLabelMap
- .delete()
- .where(TagManifestLabelMap.label == label)
- .execute())
+ # Delete the mapping records and label.
+ (TagManifestLabelMap.delete().where(TagManifestLabelMap.label == label).execute())
- deleted_count = TagManifestLabel.delete().where(TagManifestLabel.label == label).execute()
- if deleted_count != 1:
- logger.warning('More than a single label deleted for matching label %s', label_uuid)
+ deleted_count = (
+ TagManifestLabel.delete().where(TagManifestLabel.label == label).execute()
+ )
+ if deleted_count != 1:
+ logger.warning(
+            "Unexpected number of labels deleted for matching label %s", label_uuid
+ )
- deleted_count = ManifestLabel.delete().where(ManifestLabel.label == label).execute()
- if deleted_count != 1:
- logger.warning('More than a single label deleted for matching label %s', label_uuid)
+ deleted_count = ManifestLabel.delete().where(ManifestLabel.label == label).execute()
+ if deleted_count != 1:
+ logger.warning(
+            "Unexpected number of labels deleted for matching label %s", label_uuid
+ )
- label.delete_instance(recursive=False)
- return label
+ label.delete_instance(recursive=False)
+ return label
diff --git a/data/model/log.py b/data/model/log.py
index e78ec4b1b..da51a590c 100644
--- a/data/model/log.py
+++ b/data/model/log.py
@@ -12,288 +12,393 @@ from data.model import config, user, DataModelException
logger = logging.getLogger(__name__)
-ACTIONS_ALLOWED_WITHOUT_AUDIT_LOGGING = ['pull_repo']
+ACTIONS_ALLOWED_WITHOUT_AUDIT_LOGGING = ["pull_repo"]
-def _logs_query(selections, start_time=None, end_time=None, performer=None, repository=None,
- namespace=None, ignore=None, model=LogEntry3, id_range=None):
- """ Returns a query for selecting logs from the table, with various options and filters. """
- assert (start_time is not None and end_time is not None) or (id_range is not None)
- joined = (model.select(*selections).switch(model))
+def _logs_query(
+ selections,
+ start_time=None,
+ end_time=None,
+ performer=None,
+ repository=None,
+ namespace=None,
+ ignore=None,
+ model=LogEntry3,
+ id_range=None,
+):
+ """ Returns a query for selecting logs from the table, with various options and filters. """
+ assert (start_time is not None and end_time is not None) or (id_range is not None)
+ joined = model.select(*selections).switch(model)
- if id_range is not None:
- joined = joined.where(model.id >= id_range[0], model.id <= id_range[1])
- else:
- joined = joined.where(model.datetime >= start_time, model.datetime < end_time)
+ if id_range is not None:
+ joined = joined.where(model.id >= id_range[0], model.id <= id_range[1])
+ else:
+ joined = joined.where(model.datetime >= start_time, model.datetime < end_time)
- if repository:
- joined = joined.where(model.repository == repository)
+ if repository:
+ joined = joined.where(model.repository == repository)
- if performer:
- joined = joined.where(model.performer == performer)
+ if performer:
+ joined = joined.where(model.performer == performer)
- if namespace and not repository:
- namespace_user = user.get_user_or_org(namespace)
- if namespace_user is None:
- raise DataModelException('Invalid namespace requested')
+ if namespace and not repository:
+ namespace_user = user.get_user_or_org(namespace)
+ if namespace_user is None:
+ raise DataModelException("Invalid namespace requested")
- joined = joined.where(model.account == namespace_user.id)
+ joined = joined.where(model.account == namespace_user.id)
- if ignore:
- kind_map = get_log_entry_kinds()
- ignore_ids = [kind_map[kind_name] for kind_name in ignore]
- joined = joined.where(~(model.kind << ignore_ids))
+ if ignore:
+ kind_map = get_log_entry_kinds()
+ ignore_ids = [kind_map[kind_name] for kind_name in ignore]
+ joined = joined.where(~(model.kind << ignore_ids))
- return joined
+ return joined
-def _latest_logs_query(selections, performer=None, repository=None, namespace=None, ignore=None,
- model=LogEntry3, size=None):
- """ Returns a query for selecting the latest logs from the table, with various options and
+def _latest_logs_query(
+ selections,
+ performer=None,
+ repository=None,
+ namespace=None,
+ ignore=None,
+ model=LogEntry3,
+ size=None,
+):
+ """ Returns a query for selecting the latest logs from the table, with various options and
filters. """
- query = (model.select(*selections).switch(model))
+ query = model.select(*selections).switch(model)
- if repository:
- query = query.where(model.repository == repository)
+ if repository:
+ query = query.where(model.repository == repository)
- if performer:
- query = query.where(model.repository == repository)
+    if performer:
+        # Filter on performer; the original code filtered on repository here,
+        # which appears to be a copy/paste slip.
+        query = query.where(model.performer == performer)
- if namespace and not repository:
- namespace_user = user.get_user_or_org(namespace)
- if namespace_user is None:
- raise DataModelException('Invalid namespace requested')
+ if namespace and not repository:
+ namespace_user = user.get_user_or_org(namespace)
+ if namespace_user is None:
+ raise DataModelException("Invalid namespace requested")
- query = query.where(model.account == namespace_user.id)
+ query = query.where(model.account == namespace_user.id)
- if ignore:
- kind_map = get_log_entry_kinds()
- ignore_ids = [kind_map[kind_name] for kind_name in ignore]
- query = query.where(~(model.kind << ignore_ids))
+ if ignore:
+ kind_map = get_log_entry_kinds()
+ ignore_ids = [kind_map[kind_name] for kind_name in ignore]
+ query = query.where(~(model.kind << ignore_ids))
- query = query.order_by(model.datetime.desc(), model.id)
+ query = query.order_by(model.datetime.desc(), model.id)
- if size:
- query = query.limit(size)
+ if size:
+ query = query.limit(size)
- return query
+ return query
@lru_cache(maxsize=1)
def get_log_entry_kinds():
- kind_map = {}
- for kind in LogEntryKind.select():
- kind_map[kind.id] = kind.name
- kind_map[kind.name] = kind.id
+ kind_map = {}
+ for kind in LogEntryKind.select():
+ kind_map[kind.id] = kind.name
+ kind_map[kind.name] = kind.id
- return kind_map
+ return kind_map
def _get_log_entry_kind(name):
- kinds = get_log_entry_kinds()
- return kinds[name]
+ kinds = get_log_entry_kinds()
+ return kinds[name]
-def get_aggregated_logs(start_time, end_time, performer=None, repository=None, namespace=None,
- ignore=None, model=LogEntry3):
- """ Returns the count of logs, by kind and day, for the logs matching the given filters. """
- date = db.extract_date('day', model.datetime)
- selections = [model.kind, date.alias('day'), fn.Count(model.id).alias('count')]
- query = _logs_query(selections, start_time, end_time, performer, repository, namespace, ignore,
- model=model)
- return query.group_by(date, model.kind)
+def get_aggregated_logs(
+ start_time,
+ end_time,
+ performer=None,
+ repository=None,
+ namespace=None,
+ ignore=None,
+ model=LogEntry3,
+):
+ """ Returns the count of logs, by kind and day, for the logs matching the given filters. """
+ date = db.extract_date("day", model.datetime)
+ selections = [model.kind, date.alias("day"), fn.Count(model.id).alias("count")]
+ query = _logs_query(
+ selections,
+ start_time,
+ end_time,
+ performer,
+ repository,
+ namespace,
+ ignore,
+ model=model,
+ )
+ return query.group_by(date, model.kind)
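+    # Each row of the grouped query carries the log kind plus the aliased "day"
+    # and "count" values, one row per (kind, day) pair in the requested window.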
-def get_logs_query(start_time=None, end_time=None, performer=None, repository=None, namespace=None,
- ignore=None, model=LogEntry3, id_range=None):
- """ Returns the logs matching the given filters. """
- Performer = User.alias()
- Account = User.alias()
- selections = [model, Performer]
+def get_logs_query(
+ start_time=None,
+ end_time=None,
+ performer=None,
+ repository=None,
+ namespace=None,
+ ignore=None,
+ model=LogEntry3,
+ id_range=None,
+):
+ """ Returns the logs matching the given filters. """
+ Performer = User.alias()
+ Account = User.alias()
+ selections = [model, Performer]
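+    # Two User aliases are needed because a log row can reference User twice:
+    # once for the performer and once for the account the action applied to.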
- if namespace is None and repository is None:
- selections.append(Account)
+ if namespace is None and repository is None:
+ selections.append(Account)
- query = _logs_query(selections, start_time, end_time, performer, repository, namespace, ignore,
- model=model, id_range=id_range)
- query = (query.switch(model).join(Performer, JOIN.LEFT_OUTER,
- on=(model.performer == Performer.id).alias('performer')))
+ query = _logs_query(
+ selections,
+ start_time,
+ end_time,
+ performer,
+ repository,
+ namespace,
+ ignore,
+ model=model,
+ id_range=id_range,
+ )
+ query = query.switch(model).join(
+ Performer,
+ JOIN.LEFT_OUTER,
+ on=(model.performer == Performer.id).alias("performer"),
+ )
- if namespace is None and repository is None:
- query = (query.switch(model).join(Account, JOIN.LEFT_OUTER,
- on=(model.account == Account.id).alias('account')))
+ if namespace is None and repository is None:
+ query = query.switch(model).join(
+ Account, JOIN.LEFT_OUTER, on=(model.account == Account.id).alias("account")
+ )
- return query
+ return query
-def get_latest_logs_query(performer=None, repository=None, namespace=None, ignore=None,
- model=LogEntry3, size=None):
- """ Returns the latest logs matching the given filters. """
- Performer = User.alias()
- Account = User.alias()
- selections = [model, Performer]
+def get_latest_logs_query(
+ performer=None,
+ repository=None,
+ namespace=None,
+ ignore=None,
+ model=LogEntry3,
+ size=None,
+):
+ """ Returns the latest logs matching the given filters. """
+ Performer = User.alias()
+ Account = User.alias()
+ selections = [model, Performer]
- if namespace is None and repository is None:
- selections.append(Account)
+ if namespace is None and repository is None:
+ selections.append(Account)
- query = _latest_logs_query(selections, performer, repository, namespace, ignore, model=model,
- size=size)
- query = (query.switch(model).join(Performer, JOIN.LEFT_OUTER,
- on=(model.performer == Performer.id).alias('performer')))
+ query = _latest_logs_query(
+ selections, performer, repository, namespace, ignore, model=model, size=size
+ )
+ query = query.switch(model).join(
+ Performer,
+ JOIN.LEFT_OUTER,
+ on=(model.performer == Performer.id).alias("performer"),
+ )
- if namespace is None and repository is None:
- query = (query.switch(model).join(Account, JOIN.LEFT_OUTER,
- on=(model.account == Account.id).alias('account')))
+ if namespace is None and repository is None:
+ query = query.switch(model).join(
+ Account, JOIN.LEFT_OUTER, on=(model.account == Account.id).alias("account")
+ )
- return query
+ return query
def _json_serialize(obj):
- if isinstance(obj, datetime):
- return timegm(obj.utctimetuple())
+ if isinstance(obj, datetime):
+ return timegm(obj.utctimetuple())
- return obj
+ return obj
-def log_action(kind_name, user_or_organization_name, performer=None, repository=None, ip=None,
- metadata={}, timestamp=None):
- """ Logs an entry in the LogEntry table. """
- if not timestamp:
- timestamp = datetime.today()
+def log_action(
+ kind_name,
+ user_or_organization_name,
+ performer=None,
+ repository=None,
+ ip=None,
+ metadata={},
+ timestamp=None,
+):
+ """ Logs an entry in the LogEntry table. """
+ if not timestamp:
+ timestamp = datetime.today()
- account = None
- if user_or_organization_name is not None:
- account = User.get(User.username == user_or_organization_name).id
- else:
- account = config.app_config.get('SERVICE_LOG_ACCOUNT_ID')
- if account is None:
- account = user.get_minimum_user_id()
-
- if performer is not None:
- performer = performer.id
-
- if repository is not None:
- repository = repository.id
-
- kind = _get_log_entry_kind(kind_name)
- metadata_json = json.dumps(metadata, default=_json_serialize)
- log_data = {
- 'kind': kind,
- 'account': account,
- 'performer': performer,
- 'repository': repository,
- 'ip': ip,
- 'metadata_json': metadata_json,
- 'datetime': timestamp
- }
-
- try:
- LogEntry3.create(**log_data)
- except PeeweeException as ex:
- strict_logging_disabled = config.app_config.get('ALLOW_PULLS_WITHOUT_STRICT_LOGGING')
- if strict_logging_disabled and kind_name in ACTIONS_ALLOWED_WITHOUT_AUDIT_LOGGING:
- logger.exception('log_action failed', extra=({'exception': ex}).update(log_data))
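+ # Resolve the audit log account: the named user or org if given, otherwise the configured
+ # service log account, falling back to the minimum user ID.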
+ account = None
+ if user_or_organization_name is not None:
+ account = User.get(User.username == user_or_organization_name).id
else:
- raise
+ account = config.app_config.get("SERVICE_LOG_ACCOUNT_ID")
+ if account is None:
+ account = user.get_minimum_user_id()
+
+ if performer is not None:
+ performer = performer.id
+
+ if repository is not None:
+ repository = repository.id
+
+ kind = _get_log_entry_kind(kind_name)
+ metadata_json = json.dumps(metadata, default=_json_serialize)
+ log_data = {
+ "kind": kind,
+ "account": account,
+ "performer": performer,
+ "repository": repository,
+ "ip": ip,
+ "metadata_json": metadata_json,
+ "datetime": timestamp,
+ }
+
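+ # Failures to write the entry are only swallowed when strict logging is disabled and the
+ # action kind is explicitly allowed to skip audit logging; otherwise the error is re-raised.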
+ try:
+ LogEntry3.create(**log_data)
+ except PeeweeException as ex:
+ strict_logging_disabled = config.app_config.get(
+ "ALLOW_PULLS_WITHOUT_STRICT_LOGGING"
+ )
+ if (
+ strict_logging_disabled
+ and kind_name in ACTIONS_ALLOWED_WITHOUT_AUDIT_LOGGING
+ ):
+ # Note: dict.update() returns None, so build the extra payload explicitly.
+ extra_data = dict(log_data)
+ extra_data["exception"] = ex
+ logger.exception("log_action failed", extra=extra_data)
+ else:
+ raise
def get_stale_logs_start_id(model):
- """ Gets the oldest log entry. """
- try:
- return (model.select(fn.Min(model.id)).tuples())[0][0]
- except IndexError:
- return None
+ """ Gets the ID of the oldest log entry. """
+ try:
+ return (model.select(fn.Min(model.id)).tuples())[0][0]
+ except IndexError:
+ return None
def get_stale_logs(start_id, end_id, model, cutoff_date):
- """ Returns all the logs with IDs between start_id and end_id inclusively. """
- return model.select().where((model.id >= start_id),
- (model.id <= end_id),
- model.datetime <= cutoff_date)
+ """ Returns all the logs with IDs between start_id and end_id inclusively. """
+ return model.select().where(
+ (model.id >= start_id), (model.id <= end_id), model.datetime <= cutoff_date
+ )
def delete_stale_logs(start_id, end_id, model):
- """ Deletes all the logs with IDs between start_id and end_id. """
- model.delete().where((model.id >= start_id), (model.id <= end_id)).execute()
+ """ Deletes all the logs with IDs between start_id and end_id. """
+ model.delete().where((model.id >= start_id), (model.id <= end_id)).execute()
def get_repository_action_counts(repo, start_date):
- """ Returns the daily aggregated action counts for the given repository, starting at the given
+ """ Returns the daily aggregated action counts for the given repository, starting at the given
start date.
"""
- return RepositoryActionCount.select().where(RepositoryActionCount.repository == repo,
- RepositoryActionCount.date >= start_date)
+ return RepositoryActionCount.select().where(
+ RepositoryActionCount.repository == repo,
+ RepositoryActionCount.date >= start_date,
+ )
def get_repositories_action_sums(repository_ids):
- """ Returns a map from repository ID to total actions within that repository in the last week. """
- if not repository_ids:
- return {}
+ """ Returns a map from repository ID to total actions within that repository in the last week. """
+ if not repository_ids:
+ return {}
- # Filter the join to recent entries only.
- last_week = datetime.now() - timedelta(weeks=1)
- tuples = (RepositoryActionCount.select(RepositoryActionCount.repository,
- fn.Sum(RepositoryActionCount.count))
- .where(RepositoryActionCount.repository << repository_ids)
- .where(RepositoryActionCount.date >= last_week)
- .group_by(RepositoryActionCount.repository).tuples())
+ # Filter the join to recent entries only.
+ last_week = datetime.now() - timedelta(weeks=1)
+ tuples = (
+ RepositoryActionCount.select(
+ RepositoryActionCount.repository, fn.Sum(RepositoryActionCount.count)
+ )
+ .where(RepositoryActionCount.repository << repository_ids)
+ .where(RepositoryActionCount.date >= last_week)
+ .group_by(RepositoryActionCount.repository)
+ .tuples()
+ )
- action_count_map = {}
- for record in tuples:
- action_count_map[record[0]] = record[1]
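+ # Each returned tuple is (repository_id, summed_action_count).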
+ action_count_map = {}
+ for record in tuples:
+ action_count_map[record[0]] = record[1]
- return action_count_map
+ return action_count_map
-def get_minimum_id_for_logs(start_time, repository_id=None, namespace_id=None, model=LogEntry3):
- """ Returns the minimum ID for logs matching the given repository or namespace in
+def get_minimum_id_for_logs(
+ start_time, repository_id=None, namespace_id=None, model=LogEntry3
+):
+ """ Returns the minimum ID for logs matching the given repository or namespace in
the logs table, starting at the given start time.
"""
- # First try bounded by a day. Most repositories will meet this criteria, and therefore
- # can make a much faster query.
- day_after = start_time + timedelta(days=1)
- result = _get_bounded_id(fn.Min, model.datetime >= start_time,
- repository_id, namespace_id, model.datetime < day_after, model=model)
- if result is not None:
- return result
+ # First try bounded by a day. Most repositories will meet this criterion and can
+ # therefore use a much faster query.
+ day_after = start_time + timedelta(days=1)
+ result = _get_bounded_id(
+ fn.Min,
+ model.datetime >= start_time,
+ repository_id,
+ namespace_id,
+ model.datetime < day_after,
+ model=model,
+ )
+ if result is not None:
+ return result
- return _get_bounded_id(fn.Min, model.datetime >= start_time, repository_id, namespace_id,
- model=model)
+ return _get_bounded_id(
+ fn.Min, model.datetime >= start_time, repository_id, namespace_id, model=model
+ )
-def get_maximum_id_for_logs(end_time, repository_id=None, namespace_id=None, model=LogEntry3):
- """ Returns the maximum ID for logs matching the given repository or namespace in
+def get_maximum_id_for_logs(
+ end_time, repository_id=None, namespace_id=None, model=LogEntry3
+):
+ """ Returns the maximum ID for logs matching the given repository or namespace in
the logs table, ending at the given end time.
"""
- # First try bounded by a day. Most repositories will meet this criteria, and therefore
- # can make a much faster query.
- day_before = end_time - timedelta(days=1)
- result = _get_bounded_id(fn.Max, model.datetime <= end_time,
- repository_id, namespace_id, model.datetime > day_before, model=model)
- if result is not None:
- return result
+ # First try bounded by a day. Most repositories will meet this criterion and can
+ # therefore use a much faster query.
+ day_before = end_time - timedelta(days=1)
+ result = _get_bounded_id(
+ fn.Max,
+ model.datetime <= end_time,
+ repository_id,
+ namespace_id,
+ model.datetime > day_before,
+ model=model,
+ )
+ if result is not None:
+ return result
- return _get_bounded_id(fn.Max, model.datetime <= end_time, repository_id, namespace_id,
- model=model)
+ return _get_bounded_id(
+ fn.Max, model.datetime <= end_time, repository_id, namespace_id, model=model
+ )
-def _get_bounded_id(fn, filter_clause, repository_id, namespace_id, reduction_clause=None,
- model=LogEntry3):
- assert (namespace_id is not None) or (repository_id is not None)
- query = (model
- .select(fn(model.id))
- .where(filter_clause))
+def _get_bounded_id(
+ fn,
+ filter_clause,
+ repository_id,
+ namespace_id,
+ reduction_clause=None,
+ model=LogEntry3,
+):
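+ # Applies the given aggregate (fn.Min or fn.Max) to the log ID column under the filter
+ # clause, scoped to either a repository or a namespace.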
+ assert (namespace_id is not None) or (repository_id is not None)
+ query = model.select(fn(model.id)).where(filter_clause)
- if reduction_clause is not None:
- query = query.where(reduction_clause)
+ if reduction_clause is not None:
+ query = query.where(reduction_clause)
- if repository_id is not None:
- query = query.where(model.repository == repository_id)
- else:
- query = query.where(model.account == namespace_id)
+ if repository_id is not None:
+ query = query.where(model.repository == repository_id)
+ else:
+ query = query.where(model.account == namespace_id)
- row = query.tuples()[0]
- if not row:
- return None
+ row = query.tuples()[0]
+ if not row:
+ return None
- return row[0]
+ return row[0]
diff --git a/data/model/message.py b/data/model/message.py
index 24df4d0ba..911fb974e 100644
--- a/data/model/message.py
+++ b/data/model/message.py
@@ -2,23 +2,28 @@ from data.database import Messages, MediaType
def get_messages():
- """Query the data base for messages and returns a container of database message objects"""
- return Messages.select(Messages, MediaType).join(MediaType)
+ """Queries the database for messages and returns a container of database message objects"""
+ return Messages.select(Messages, MediaType).join(MediaType)
+
def create(messages):
- """Insert messages into the database."""
- inserted = []
- for message in messages:
- severity = message['severity']
- media_type_name = message['media_type']
- media_type = MediaType.get(name=media_type_name)
+ """Insert messages into the database."""
+ inserted = []
+ for message in messages:
+ severity = message["severity"]
+ media_type_name = message["media_type"]
+ media_type = MediaType.get(name=media_type_name)
+
+ inserted.append(
+ Messages.create(
+ content=message["content"], media_type=media_type, severity=severity
+ )
+ )
+ return inserted
- inserted.append(Messages.create(content=message['content'], media_type=media_type,
- severity=severity))
- return inserted
def delete_message(uuids):
- """Delete message from the database"""
- if not uuids:
- return
- Messages.delete().where(Messages.uuid << uuids).execute()
+ """Delete message from the database"""
+ if not uuids:
+ return
+ Messages.delete().where(Messages.uuid << uuids).execute()
diff --git a/data/model/modelutil.py b/data/model/modelutil.py
index 4048e4eff..52830fc39 100644
--- a/data/model/modelutil.py
+++ b/data/model/modelutil.py
@@ -5,73 +5,82 @@ from datetime import datetime
from peewee import SQL
-def paginate(query, model, descending=False, page_token=None, limit=50, sort_field_alias=None,
- max_page=None, sort_field_name=None):
- """ Paginates the given query using an field range, starting at the optional page_token.
+def paginate(
+ query,
+ model,
+ descending=False,
+ page_token=None,
+ limit=50,
+ sort_field_alias=None,
+ max_page=None,
+ sort_field_name=None,
+):
+ """ Paginates the given query using a field range, starting at the optional page_token.
Returns a *list* of matching results along with an unencrypted page_token for the
next page, if any. If descending is set to True, orders by the field descending rather
than ascending.
"""
- # Note: We use the sort_field_alias for the order_by, but not the where below. The alias is
- # necessary for certain queries that use unions in MySQL, as it gets confused on which field
- # to order by. The where clause, on the other hand, cannot use the alias because Postgres does
- # not allow aliases in where clauses.
- sort_field_name = sort_field_name or 'id'
- sort_field = getattr(model, sort_field_name)
+ # Note: We use the sort_field_alias for the order_by, but not the where below. The alias is
+ # necessary for certain queries that use unions in MySQL, as MySQL otherwise gets confused about
+ # which field to order by. The where clause, on the other hand, cannot use the alias because
+ # Postgres does not allow aliases in where clauses.
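+ # Illustrative usage (SomeModel is a placeholder peewee model):
+ #   results, next_page_token = paginate(SomeModel.select(), SomeModel, limit=50)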
+ sort_field_name = sort_field_name or "id"
+ sort_field = getattr(model, sort_field_name)
- if sort_field_alias is not None:
- sort_field_name = sort_field_alias
- sort_field = SQL(sort_field_alias)
+ if sort_field_alias is not None:
+ sort_field_name = sort_field_alias
+ sort_field = SQL(sort_field_alias)
- if descending:
- query = query.order_by(sort_field.desc())
- else:
- query = query.order_by(sort_field)
-
- start_index = pagination_start(page_token)
- if start_index is not None:
if descending:
- query = query.where(sort_field <= start_index)
+ query = query.order_by(sort_field.desc())
else:
- query = query.where(sort_field >= start_index)
+ query = query.order_by(sort_field)
- query = query.limit(limit + 1)
+ start_index = pagination_start(page_token)
+ if start_index is not None:
+ if descending:
+ query = query.where(sort_field <= start_index)
+ else:
+ query = query.where(sort_field >= start_index)
- page_number = (page_token.get('page_number') or None) if page_token else None
- if page_number is not None and max_page is not None and page_number > max_page:
- return [], None
+ query = query.limit(limit + 1)
- return paginate_query(query, limit=limit, sort_field_name=sort_field_name,
- page_number=page_number)
+ page_number = (page_token.get("page_number") or None) if page_token else None
+ if page_number is not None and max_page is not None and page_number > max_page:
+ return [], None
+
+ return paginate_query(
+ query, limit=limit, sort_field_name=sort_field_name, page_number=page_number
+ )
def pagination_start(page_token=None):
- """ Returns the start index for pagination for the given page token. Will return None if None. """
- if page_token is not None:
- start_index = page_token.get('start_index')
- if page_token.get('is_datetime'):
- start_index = dateutil.parser.parse(start_index)
- return start_index
- return None
+ """ Returns the start index for pagination for the given page token, or None if the token is None. """
+ if page_token is not None:
+ start_index = page_token.get("start_index")
+ if page_token.get("is_datetime"):
+ start_index = dateutil.parser.parse(start_index)
+ return start_index
+ return None
def paginate_query(query, limit=50, sort_field_name=None, page_number=None):
- """ Executes the given query and returns a page's worth of results, as well as the page token
+ """ Executes the given query and returns a page's worth of results, as well as the page token
for the next page (if any).
"""
- results = list(query)
- page_token = None
- if len(results) > limit:
- start_index = getattr(results[limit], sort_field_name or 'id')
- is_datetime = False
- if isinstance(start_index, datetime):
- start_index = start_index.isoformat() + "Z"
- is_datetime = True
+ results = list(query)
+ page_token = None
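+ # A result set larger than `limit` signals another page; the extra row's sort field value
+ # seeds the next page token.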
+ if len(results) > limit:
+ start_index = getattr(results[limit], sort_field_name or "id")
+ is_datetime = False
+ if isinstance(start_index, datetime):
+ start_index = start_index.isoformat() + "Z"
+ is_datetime = True
- page_token = {
- 'start_index': start_index,
- 'page_number': page_number + 1 if page_number else 1,
- 'is_datetime': is_datetime,
- }
+ page_token = {
+ "start_index": start_index,
+ "page_number": page_number + 1 if page_number else 1,
+ "is_datetime": is_datetime,
+ }
- return results[0:limit], page_token
+ return results[0:limit], page_token
diff --git a/data/model/notification.py b/data/model/notification.py
index 11a84fea7..a8ed15c3e 100644
--- a/data/model/notification.py
+++ b/data/model/notification.py
@@ -2,219 +2,268 @@ import json
from peewee import SQL
-from data.database import (Notification, NotificationKind, User, Team, TeamMember, TeamRole,
- RepositoryNotification, ExternalNotificationEvent, Repository,
- ExternalNotificationMethod, Namespace, db_for_update)
+from data.database import (
+ Notification,
+ NotificationKind,
+ User,
+ Team,
+ TeamMember,
+ TeamRole,
+ RepositoryNotification,
+ ExternalNotificationEvent,
+ Repository,
+ ExternalNotificationMethod,
+ Namespace,
+ db_for_update,
+)
from data.model import InvalidNotificationException, db_transaction
def create_notification(kind_name, target, metadata={}, lookup_path=None):
- kind_ref = NotificationKind.get(name=kind_name)
- notification = Notification.create(kind=kind_ref, target=target,
- metadata_json=json.dumps(metadata),
- lookup_path=lookup_path)
- return notification
+ kind_ref = NotificationKind.get(name=kind_name)
+ notification = Notification.create(
+ kind=kind_ref,
+ target=target,
+ metadata_json=json.dumps(metadata),
+ lookup_path=lookup_path,
+ )
+ return notification
def create_unique_notification(kind_name, target, metadata={}):
- with db_transaction():
- if list_notifications(target, kind_name).count() == 0:
- create_notification(kind_name, target, metadata)
+ with db_transaction():
+ if list_notifications(target, kind_name).count() == 0:
+ create_notification(kind_name, target, metadata)
def lookup_notification(user, uuid):
- results = list(list_notifications(user, id_filter=uuid, include_dismissed=True, limit=1))
- if not results:
- return None
+ results = list(
+ list_notifications(user, id_filter=uuid, include_dismissed=True, limit=1)
+ )
+ if not results:
+ return None
- return results[0]
+ return results[0]
def lookup_notifications_by_path_prefix(prefix):
- return list((Notification
- .select()
- .where(Notification.lookup_path % prefix)))
+ return list((Notification.select().where(Notification.lookup_path % prefix)))
-def list_notifications(user, kind_name=None, id_filter=None, include_dismissed=False,
- page=None, limit=None):
+def list_notifications(
+ user, kind_name=None, id_filter=None, include_dismissed=False, page=None, limit=None
+):
- base_query = (Notification
- .select(Notification.id,
- Notification.uuid,
- Notification.kind,
- Notification.metadata_json,
- Notification.dismissed,
- Notification.lookup_path,
- Notification.created,
- Notification.created.alias('cd'),
- Notification.target)
- .join(NotificationKind))
+ base_query = Notification.select(
+ Notification.id,
+ Notification.uuid,
+ Notification.kind,
+ Notification.metadata_json,
+ Notification.dismissed,
+ Notification.lookup_path,
+ Notification.created,
+ Notification.created.alias("cd"),
+ Notification.target,
+ ).join(NotificationKind)
- if kind_name is not None:
- base_query = base_query.where(NotificationKind.name == kind_name)
+ if kind_name is not None:
+ base_query = base_query.where(NotificationKind.name == kind_name)
- if id_filter is not None:
- base_query = base_query.where(Notification.uuid == id_filter)
+ if id_filter is not None:
+ base_query = base_query.where(Notification.uuid == id_filter)
- if not include_dismissed:
- base_query = base_query.where(Notification.dismissed == False)
+ if not include_dismissed:
+ base_query = base_query.where(Notification.dismissed == False)
- # Lookup directly for the user.
- user_direct = base_query.clone().where(Notification.target == user)
+ # Lookup directly for the user.
+ user_direct = base_query.clone().where(Notification.target == user)
- # Lookup via organizations admined by the user.
- Org = User.alias()
- AdminTeam = Team.alias()
- AdminTeamMember = TeamMember.alias()
- AdminUser = User.alias()
+ # Lookup via organizations admined by the user.
+ Org = User.alias()
+ AdminTeam = Team.alias()
+ AdminTeamMember = TeamMember.alias()
+ AdminUser = User.alias()
- via_orgs = (base_query.clone()
- .join(Org, on=(Org.id == Notification.target))
- .join(AdminTeam, on=(Org.id == AdminTeam.organization))
- .join(TeamRole, on=(AdminTeam.role == TeamRole.id))
- .switch(AdminTeam)
- .join(AdminTeamMember, on=(AdminTeam.id == AdminTeamMember.team))
- .join(AdminUser, on=(AdminTeamMember.user == AdminUser.id))
- .where((AdminUser.id == user) & (TeamRole.name == 'admin')))
+ via_orgs = (
+ base_query.clone()
+ .join(Org, on=(Org.id == Notification.target))
+ .join(AdminTeam, on=(Org.id == AdminTeam.organization))
+ .join(TeamRole, on=(AdminTeam.role == TeamRole.id))
+ .switch(AdminTeam)
+ .join(AdminTeamMember, on=(AdminTeam.id == AdminTeamMember.team))
+ .join(AdminUser, on=(AdminTeamMember.user == AdminUser.id))
+ .where((AdminUser.id == user) & (TeamRole.name == "admin"))
+ )
- query = user_direct | via_orgs
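+ # Union of notifications targeted directly at the user with those targeted at organizations
+ # the user administers.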
+ query = user_direct | via_orgs
- if page:
- query = query.paginate(page, limit)
- elif limit:
- query = query.limit(limit)
+ if page:
+ query = query.paginate(page, limit)
+ elif limit:
+ query = query.limit(limit)
- return query.order_by(SQL('cd desc'))
+ return query.order_by(SQL("cd desc"))
def delete_all_notifications_by_path_prefix(prefix):
- (Notification
- .delete()
- .where(Notification.lookup_path ** (prefix + '%'))
- .execute())
+ (Notification.delete().where(Notification.lookup_path ** (prefix + "%")).execute())
def delete_all_notifications_by_kind(kind_name):
- kind_ref = NotificationKind.get(name=kind_name)
- (Notification
- .delete()
- .where(Notification.kind == kind_ref)
- .execute())
+ kind_ref = NotificationKind.get(name=kind_name)
+ (Notification.delete().where(Notification.kind == kind_ref).execute())
def delete_notifications_by_kind(target, kind_name):
- kind_ref = NotificationKind.get(name=kind_name)
- Notification.delete().where(Notification.target == target,
- Notification.kind == kind_ref).execute()
+ kind_ref = NotificationKind.get(name=kind_name)
+ Notification.delete().where(
+ Notification.target == target, Notification.kind == kind_ref
+ ).execute()
def delete_matching_notifications(target, kind_name, **kwargs):
- kind_ref = NotificationKind.get(name=kind_name)
+ kind_ref = NotificationKind.get(name=kind_name)
- # Load all notifications for the user with the given kind.
- notifications = (Notification
- .select()
- .where(Notification.target == target,
- Notification.kind == kind_ref))
+ # Load all notifications for the user with the given kind.
+ notifications = Notification.select().where(
+ Notification.target == target, Notification.kind == kind_ref
+ )
- # For each, match the metadata to the specified values.
- for notification in notifications:
- matches = True
- try:
- metadata = json.loads(notification.metadata_json)
- except:
- continue
+ # For each, match the metadata to the specified values.
+ for notification in notifications:
+ matches = True
+ try:
+ metadata = json.loads(notification.metadata_json)
+ except:
+ continue
- for (key, value) in kwargs.iteritems():
- if not key in metadata or metadata[key] != value:
- matches = False
- break
+ for (key, value) in kwargs.iteritems():
+ if key not in metadata or metadata[key] != value:
+ matches = False
+ break
- if not matches:
- continue
+ if not matches:
+ continue
- notification.delete_instance()
+ notification.delete_instance()
def increment_notification_failure_count(uuid):
- """ This increments the number of failures by one """
- (RepositoryNotification
- .update(number_of_failures=RepositoryNotification.number_of_failures + 1)
- .where(RepositoryNotification.uuid == uuid)
- .execute())
+ """ This increments the number of failures by one """
+ (
+ RepositoryNotification.update(
+ number_of_failures=RepositoryNotification.number_of_failures + 1
+ )
+ .where(RepositoryNotification.uuid == uuid)
+ .execute()
+ )
def reset_notification_number_of_failures(namespace_name, repository_name, uuid):
- """ This resets the number of failures for a repo notification to 0 """
- try:
- notification = RepositoryNotification.select().where(RepositoryNotification.uuid == uuid).get()
- if (notification.repository.namespace_user.username != namespace_name or
- notification.repository.name != repository_name):
- raise InvalidNotificationException('No repository notification found with uuid: %s' % uuid)
- reset_number_of_failures_to_zero(notification.id)
- return notification
- except RepositoryNotification.DoesNotExist:
- return None
+ """ This resets the number of failures for a repo notification to 0 """
+ try:
+ notification = (
+ RepositoryNotification.select()
+ .where(RepositoryNotification.uuid == uuid)
+ .get()
+ )
+ if (
+ notification.repository.namespace_user.username != namespace_name
+ or notification.repository.name != repository_name
+ ):
+ raise InvalidNotificationException(
+ "No repository notification found with uuid: %s" % uuid
+ )
+ reset_number_of_failures_to_zero(notification.id)
+ return notification
+ except RepositoryNotification.DoesNotExist:
+ return None
def reset_number_of_failures_to_zero(notification_id):
- """ This resets the number of failures for a repo notification to 0 """
- RepositoryNotification.update(number_of_failures=0).where(RepositoryNotification.id == notification_id).execute()
+ """ This resets the number of failures for a repo notification to 0 """
+ RepositoryNotification.update(number_of_failures=0).where(
+ RepositoryNotification.id == notification_id
+ ).execute()
-def create_repo_notification(repo, event_name, method_name, method_config, event_config, title=None):
- event = ExternalNotificationEvent.get(ExternalNotificationEvent.name == event_name)
- method = ExternalNotificationMethod.get(ExternalNotificationMethod.name == method_name)
+def create_repo_notification(
+ repo, event_name, method_name, method_config, event_config, title=None
+):
+ event = ExternalNotificationEvent.get(ExternalNotificationEvent.name == event_name)
+ method = ExternalNotificationMethod.get(
+ ExternalNotificationMethod.name == method_name
+ )
- return RepositoryNotification.create(repository=repo, event=event, method=method,
- config_json=json.dumps(method_config), title=title,
- event_config_json=json.dumps(event_config))
+ return RepositoryNotification.create(
+ repository=repo,
+ event=event,
+ method=method,
+ config_json=json.dumps(method_config),
+ title=title,
+ event_config_json=json.dumps(event_config),
+ )
def _base_get_notification(uuid):
- """ This is a base query for get statements """
- return (RepositoryNotification
- .select(RepositoryNotification, Repository, Namespace)
- .join(Repository)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .where(RepositoryNotification.uuid == uuid))
+ """ This is a base query for the repository notification lookup functions """
+ return (
+ RepositoryNotification.select(RepositoryNotification, Repository, Namespace)
+ .join(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .where(RepositoryNotification.uuid == uuid)
+ )
def get_enabled_notification(uuid):
- """ This returns a notification with less than 3 failures """
- try:
- return _base_get_notification(uuid).where(RepositoryNotification.number_of_failures < 3).get()
- except RepositoryNotification.DoesNotExist:
- raise InvalidNotificationException('No repository notification found with uuid: %s' % uuid)
+ """ This returns a notification with less than 3 failures """
+ try:
+ return (
+ _base_get_notification(uuid)
+ .where(RepositoryNotification.number_of_failures < 3)
+ .get()
+ )
+ except RepositoryNotification.DoesNotExist:
+ raise InvalidNotificationException(
+ "No repository notification found with uuid: %s" % uuid
+ )
def get_repo_notification(uuid):
- try:
- return _base_get_notification(uuid).get()
- except RepositoryNotification.DoesNotExist:
- raise InvalidNotificationException('No repository notification found with uuid: %s' % uuid)
+ try:
+ return _base_get_notification(uuid).get()
+ except RepositoryNotification.DoesNotExist:
+ raise InvalidNotificationException(
+ "No repository notification found with uuid: %s" % uuid
+ )
def delete_repo_notification(namespace_name, repository_name, uuid):
- found = get_repo_notification(uuid)
- if found.repository.namespace_user.username != namespace_name or found.repository.name != repository_name:
- raise InvalidNotificationException('No repository notifiation found with uuid: %s' % uuid)
- found.delete_instance()
- return found
+ found = get_repo_notification(uuid)
+ if (
+ found.repository.namespace_user.username != namespace_name
+ or found.repository.name != repository_name
+ ):
+ raise InvalidNotificationException(
+ "No repository notification found with uuid: %s" % uuid
+ )
+ found.delete_instance()
+ return found
def list_repo_notifications(namespace_name, repository_name, event_name=None):
- query = (RepositoryNotification
- .select(RepositoryNotification, Repository, Namespace)
- .join(Repository)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .where(Namespace.username == namespace_name, Repository.name == repository_name))
+ query = (
+ RepositoryNotification.select(RepositoryNotification, Repository, Namespace)
+ .join(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .where(Namespace.username == namespace_name, Repository.name == repository_name)
+ )
- if event_name:
- query = (query
- .switch(RepositoryNotification)
- .join(ExternalNotificationEvent)
- .where(ExternalNotificationEvent.name == event_name))
+ if event_name:
+ query = (
+ query.switch(RepositoryNotification)
+ .join(ExternalNotificationEvent)
+ .where(ExternalNotificationEvent.name == event_name)
+ )
- return query
+ return query
diff --git a/data/model/oauth.py b/data/model/oauth.py
index 182c08f32..f0febe832 100644
--- a/data/model/oauth.py
+++ b/data/model/oauth.py
@@ -7,8 +7,13 @@ from oauth2lib.provider import AuthorizationProvider
from oauth2lib import utils
from active_migration import ActiveDataMigration, ERTMigrationFlags
-from data.database import (OAuthApplication, OAuthAuthorizationCode, OAuthAccessToken, User,
- random_string_generator)
+from data.database import (
+ OAuthApplication,
+ OAuthAuthorizationCode,
+ OAuthAccessToken,
+ User,
+ random_string_generator,
+)
from data.fields import DecryptedValue, Credential
from data.model import user, config
from auth import scopes
@@ -23,412 +28,465 @@ AUTHORIZATION_CODE_PREFIX_LENGTH = 20
class DatabaseAuthorizationProvider(AuthorizationProvider):
- def get_authorized_user(self):
- raise NotImplementedError('Subclasses must fill in the ability to get the authorized_user.')
+ def get_authorized_user(self):
+ raise NotImplementedError(
+ "Subclasses must fill in the ability to get the authorized_user."
+ )
- def _generate_data_string(self):
- return json.dumps({'username': self.get_authorized_user().username})
+ def _generate_data_string(self):
+ return json.dumps({"username": self.get_authorized_user().username})
- @property
- def token_expires_in(self):
- """Property method to get the token expiration time in seconds.
+ @property
+ def token_expires_in(self):
+ """Property method to get the token expiration time in seconds.
"""
- return int(60*60*24*365.25*10) # 10 Years
+ return int(60 * 60 * 24 * 365.25 * 10) # 10 Years
- def validate_client_id(self, client_id):
- return self.get_application_for_client_id(client_id) is not None
+ def validate_client_id(self, client_id):
+ return self.get_application_for_client_id(client_id) is not None
- def get_application_for_client_id(self, client_id):
- try:
- return OAuthApplication.get(client_id=client_id)
- except OAuthApplication.DoesNotExist:
- return None
-
- def validate_client_secret(self, client_id, client_secret):
- try:
- application = OAuthApplication.get(client_id=client_id)
-
- # TODO(remove-unenc): Remove legacy check.
- if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
- if application.secure_client_secret is None:
- return application.client_secret == client_secret
-
- assert application.secure_client_secret is not None
- return application.secure_client_secret.matches(client_secret)
- except OAuthApplication.DoesNotExist:
- return False
-
- def validate_redirect_uri(self, client_id, redirect_uri):
- internal_redirect_url = '%s%s' % (get_app_url(config.app_config),
- url_for('web.oauth_local_handler'))
-
- if redirect_uri == internal_redirect_url:
- return True
-
- try:
- oauth_app = OAuthApplication.get(client_id=client_id)
- if (oauth_app.redirect_uri and redirect_uri and
- redirect_uri.startswith(oauth_app.redirect_uri)):
- return True
- return False
- except OAuthApplication.DoesNotExist:
- return False
-
- def validate_scope(self, client_id, scopes_string):
- return scopes.validate_scope_string(scopes_string)
-
- def validate_access(self):
- return self.get_authorized_user() is not None
-
- def load_authorized_scope_string(self, client_id, username):
- found = (OAuthAccessToken
- .select()
- .join(OAuthApplication)
- .switch(OAuthAccessToken)
- .join(User)
- .where(OAuthApplication.client_id == client_id, User.username == username,
- OAuthAccessToken.expires_at > datetime.utcnow()))
- found = list(found)
- logger.debug('Found %s matching tokens.', len(found))
- long_scope_string = ','.join([token.scope for token in found])
- logger.debug('Computed long scope string: %s', long_scope_string)
- return long_scope_string
-
- def validate_has_scopes(self, client_id, username, scope):
- long_scope_string = self.load_authorized_scope_string(client_id, username)
-
- # Make sure the token contains the given scopes (at least).
- return scopes.is_subset_string(long_scope_string, scope)
-
- def from_authorization_code(self, client_id, full_code, scope):
- code_name = full_code[:AUTHORIZATION_CODE_PREFIX_LENGTH]
- code_credential = full_code[AUTHORIZATION_CODE_PREFIX_LENGTH:]
-
- try:
- found = (OAuthAuthorizationCode
- .select()
- .join(OAuthApplication)
- .where(OAuthApplication.client_id == client_id,
- OAuthAuthorizationCode.code_name == code_name,
- OAuthAuthorizationCode.scope == scope)
- .get())
- if not found.code_credential.matches(code_credential):
- return None
-
- logger.debug('Returning data: %s', found.data)
- return found.data
- except OAuthAuthorizationCode.DoesNotExist:
- # Fallback to the legacy lookup of the full code.
- # TODO(remove-unenc): Remove legacy fallback.
- if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
+ def get_application_for_client_id(self, client_id):
try:
- found = (OAuthAuthorizationCode
- .select()
- .join(OAuthApplication)
- .where(OAuthApplication.client_id == client_id,
- OAuthAuthorizationCode.code == full_code,
- OAuthAuthorizationCode.scope == scope)
- .get())
- logger.debug('Returning data: %s', found.data)
- return found.data
+ return OAuthApplication.get(client_id=client_id)
+ except OAuthApplication.DoesNotExist:
+ return None
+
+ def validate_client_secret(self, client_id, client_secret):
+ try:
+ application = OAuthApplication.get(client_id=client_id)
+
+ # TODO(remove-unenc): Remove legacy check.
+ if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
+ if application.secure_client_secret is None:
+ return application.client_secret == client_secret
+
+ assert application.secure_client_secret is not None
+ return application.secure_client_secret.matches(client_secret)
+ except OAuthApplication.DoesNotExist:
+ return False
+
+ def validate_redirect_uri(self, client_id, redirect_uri):
+ internal_redirect_url = "%s%s" % (
+ get_app_url(config.app_config),
+ url_for("web.oauth_local_handler"),
+ )
+
+ if redirect_uri == internal_redirect_url:
+ return True
+
+ try:
+ oauth_app = OAuthApplication.get(client_id=client_id)
+ if (
+ oauth_app.redirect_uri
+ and redirect_uri
+ and redirect_uri.startswith(oauth_app.redirect_uri)
+ ):
+ return True
+ return False
+ except OAuthApplication.DoesNotExist:
+ return False
+
+ def validate_scope(self, client_id, scopes_string):
+ return scopes.validate_scope_string(scopes_string)
+
+ def validate_access(self):
+ return self.get_authorized_user() is not None
+
+ def load_authorized_scope_string(self, client_id, username):
+ found = (
+ OAuthAccessToken.select()
+ .join(OAuthApplication)
+ .switch(OAuthAccessToken)
+ .join(User)
+ .where(
+ OAuthApplication.client_id == client_id,
+ User.username == username,
+ OAuthAccessToken.expires_at > datetime.utcnow(),
+ )
+ )
+ found = list(found)
+ logger.debug("Found %s matching tokens.", len(found))
+ long_scope_string = ",".join([token.scope for token in found])
+ logger.debug("Computed long scope string: %s", long_scope_string)
+ return long_scope_string
+
+ def validate_has_scopes(self, client_id, username, scope):
+ long_scope_string = self.load_authorized_scope_string(client_id, username)
+
+ # Make sure the token contains the given scopes (at least).
+ return scopes.is_subset_string(long_scope_string, scope)
+
+ def from_authorization_code(self, client_id, full_code, scope):
+ code_name = full_code[:AUTHORIZATION_CODE_PREFIX_LENGTH]
+ code_credential = full_code[AUTHORIZATION_CODE_PREFIX_LENGTH:]
+
+ try:
+ found = (
+ OAuthAuthorizationCode.select()
+ .join(OAuthApplication)
+ .where(
+ OAuthApplication.client_id == client_id,
+ OAuthAuthorizationCode.code_name == code_name,
+ OAuthAuthorizationCode.scope == scope,
+ )
+ .get()
+ )
+ if not found.code_credential.matches(code_credential):
+ return None
+
+ logger.debug("Returning data: %s", found.data)
+ return found.data
except OAuthAuthorizationCode.DoesNotExist:
- return None
- else:
- return None
+ # Fallback to the legacy lookup of the full code.
+ # TODO(remove-unenc): Remove legacy fallback.
+ if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
+ try:
+ found = (
+ OAuthAuthorizationCode.select()
+ .join(OAuthApplication)
+ .where(
+ OAuthApplication.client_id == client_id,
+ OAuthAuthorizationCode.code == full_code,
+ OAuthAuthorizationCode.scope == scope,
+ )
+ .get()
+ )
+ logger.debug("Returning data: %s", found.data)
+ return found.data
+ except OAuthAuthorizationCode.DoesNotExist:
+ return None
+ else:
+ return None
- def persist_authorization_code(self, client_id, full_code, scope):
- oauth_app = OAuthApplication.get(client_id=client_id)
- data = self._generate_data_string()
+ def persist_authorization_code(self, client_id, full_code, scope):
+ oauth_app = OAuthApplication.get(client_id=client_id)
+ data = self._generate_data_string()
- assert len(full_code) >= (AUTHORIZATION_CODE_PREFIX_LENGTH * 2)
- code_name = full_code[:AUTHORIZATION_CODE_PREFIX_LENGTH]
- code_credential = full_code[AUTHORIZATION_CODE_PREFIX_LENGTH:]
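+ # The full code must be long enough to split into a lookup prefix and a non-empty secret suffix.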
+ assert len(full_code) >= (AUTHORIZATION_CODE_PREFIX_LENGTH * 2)
+ code_name = full_code[:AUTHORIZATION_CODE_PREFIX_LENGTH]
+ code_credential = full_code[AUTHORIZATION_CODE_PREFIX_LENGTH:]
- # TODO(remove-unenc): Remove legacy fallback.
- full_code = None
- if ActiveDataMigration.has_flag(ERTMigrationFlags.WRITE_OLD_FIELDS):
- full_code = code_name + code_credential
+ # TODO(remove-unenc): Remove legacy fallback.
+ full_code = None
+ if ActiveDataMigration.has_flag(ERTMigrationFlags.WRITE_OLD_FIELDS):
+ full_code = code_name + code_credential
- OAuthAuthorizationCode.create(application=oauth_app,
- code=full_code,
- scope=scope,
- code_name=code_name,
- code_credential=Credential.from_string(code_credential),
- data=data)
+ OAuthAuthorizationCode.create(
+ application=oauth_app,
+ code=full_code,
+ scope=scope,
+ code_name=code_name,
+ code_credential=Credential.from_string(code_credential),
+ data=data,
+ )
- def persist_token_information(self, client_id, scope, access_token, token_type,
- expires_in, refresh_token, data):
- assert not refresh_token
- found = user.get_user(json.loads(data)['username'])
- if not found:
- raise RuntimeError('Username must be in the data field')
+ def persist_token_information(
+ self,
+ client_id,
+ scope,
+ access_token,
+ token_type,
+ expires_in,
+ refresh_token,
+ data,
+ ):
+ assert not refresh_token
+ found = user.get_user(json.loads(data)["username"])
+ if not found:
+ raise RuntimeError("Username must be in the data field")
- token_name = access_token[:ACCESS_TOKEN_PREFIX_LENGTH]
- token_code = access_token[ACCESS_TOKEN_PREFIX_LENGTH:]
+ token_name = access_token[:ACCESS_TOKEN_PREFIX_LENGTH]
+ token_code = access_token[ACCESS_TOKEN_PREFIX_LENGTH:]
- assert token_name
- assert token_code
- assert len(token_name) == ACCESS_TOKEN_PREFIX_LENGTH
- assert len(token_code) >= ACCESS_TOKEN_MINIMUM_CODE_LENGTH
+ assert token_name
+ assert token_code
+ assert len(token_name) == ACCESS_TOKEN_PREFIX_LENGTH
+ assert len(token_code) >= ACCESS_TOKEN_MINIMUM_CODE_LENGTH
- oauth_app = OAuthApplication.get(client_id=client_id)
- expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
- OAuthAccessToken.create(application=oauth_app,
- authorized_user=found,
- scope=scope,
- token_name=token_name,
- token_code=Credential.from_string(token_code),
- access_token='',
- token_type=token_type,
- expires_at=expires_at,
- data=data)
+ oauth_app = OAuthApplication.get(client_id=client_id)
+ expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
+ OAuthAccessToken.create(
+ application=oauth_app,
+ authorized_user=found,
+ scope=scope,
+ token_name=token_name,
+ token_code=Credential.from_string(token_code),
+ access_token="",
+ token_type=token_type,
+ expires_at=expires_at,
+ data=data,
+ )
- def get_auth_denied_response(self, response_type, client_id, redirect_uri, **params):
- # Ensure proper response_type
- if response_type != 'token':
- err = 'unsupported_response_type'
- return self._make_redirect_error_response(redirect_uri, err)
+ def get_auth_denied_response(
+ self, response_type, client_id, redirect_uri, **params
+ ):
+ # Ensure proper response_type
+ if response_type != "token":
+ err = "unsupported_response_type"
+ return self._make_redirect_error_response(redirect_uri, err)
- # Check redirect URI
- is_valid_redirect_uri = self.validate_redirect_uri(client_id, redirect_uri)
- if not is_valid_redirect_uri:
- return self._invalid_redirect_uri_response()
+ # Check redirect URI
+ is_valid_redirect_uri = self.validate_redirect_uri(client_id, redirect_uri)
+ if not is_valid_redirect_uri:
+ return self._invalid_redirect_uri_response()
- return self._make_redirect_error_response(redirect_uri, 'authorization_denied')
+ return self._make_redirect_error_response(redirect_uri, "authorization_denied")
- def get_token_response(self, response_type, client_id, redirect_uri, **params):
- # Ensure proper response_type
- if response_type != 'token':
- err = 'unsupported_response_type'
- return self._make_redirect_error_response(redirect_uri, err)
+ def get_token_response(self, response_type, client_id, redirect_uri, **params):
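+ # Implicit (token) grant flow: only response_type == "token" is supported here.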
+ # Ensure proper response_type
+ if response_type != "token":
+ err = "unsupported_response_type"
+ return self._make_redirect_error_response(redirect_uri, err)
- # Check for a valid client ID.
- is_valid_client_id = self.validate_client_id(client_id)
- if not is_valid_client_id:
- err = 'unauthorized_client'
- return self._make_redirect_error_response(redirect_uri, err)
+ # Check for a valid client ID.
+ is_valid_client_id = self.validate_client_id(client_id)
+ if not is_valid_client_id:
+ err = "unauthorized_client"
+ return self._make_redirect_error_response(redirect_uri, err)
- # Check for a valid redirect URI.
- is_valid_redirect_uri = self.validate_redirect_uri(client_id, redirect_uri)
- if not is_valid_redirect_uri:
- return self._invalid_redirect_uri_response()
+ # Check for a valid redirect URI.
+ is_valid_redirect_uri = self.validate_redirect_uri(client_id, redirect_uri)
+ if not is_valid_redirect_uri:
+ return self._invalid_redirect_uri_response()
- # Check conditions
- is_valid_access = self.validate_access()
- scope = params.get('scope', '')
- are_valid_scopes = self.validate_scope(client_id, scope)
+ # Check conditions
+ is_valid_access = self.validate_access()
+ scope = params.get("scope", "")
+ are_valid_scopes = self.validate_scope(client_id, scope)
- # Return proper error responses on invalid conditions
- if not is_valid_access:
- err = 'access_denied'
- return self._make_redirect_error_response(redirect_uri, err)
+ # Return proper error responses on invalid conditions
+ if not is_valid_access:
+ err = "access_denied"
+ return self._make_redirect_error_response(redirect_uri, err)
- if not are_valid_scopes:
- err = 'invalid_scope'
- return self._make_redirect_error_response(redirect_uri, err)
+ if not are_valid_scopes:
+ err = "invalid_scope"
+ return self._make_redirect_error_response(redirect_uri, err)
- # Make sure we have enough random data in the token to have a public
- # prefix and a private encrypted suffix.
- access_token = str(self.generate_access_token())
- assert len(access_token) - ACCESS_TOKEN_PREFIX_LENGTH >= 20
+ # Make sure we have enough random data in the token to have a public
+ # prefix and a private encrypted suffix.
+ access_token = str(self.generate_access_token())
+ assert len(access_token) - ACCESS_TOKEN_PREFIX_LENGTH >= 20
- token_type = self.token_type
- expires_in = self.token_expires_in
+ token_type = self.token_type
+ expires_in = self.token_expires_in
- data = self._generate_data_string()
- self.persist_token_information(client_id=client_id,
- scope=scope,
- access_token=access_token,
- token_type=token_type,
- expires_in=expires_in,
- refresh_token=None,
- data=data)
+ data = self._generate_data_string()
+ self.persist_token_information(
+ client_id=client_id,
+ scope=scope,
+ access_token=access_token,
+ token_type=token_type,
+ expires_in=expires_in,
+ refresh_token=None,
+ data=data,
+ )
- url = utils.build_url(redirect_uri, params)
- url += '#access_token=%s&token_type=%s&expires_in=%s' % (access_token, token_type, expires_in)
+ url = utils.build_url(redirect_uri, params)
+ url += "#access_token=%s&token_type=%s&expires_in=%s" % (
+ access_token,
+ token_type,
+ expires_in,
+ )
- return self._make_response(headers={'Location': url}, status_code=302)
+ return self._make_response(headers={"Location": url}, status_code=302)
- def from_refresh_token(self, client_id, refresh_token, scope):
- raise NotImplementedError()
+ def from_refresh_token(self, client_id, refresh_token, scope):
+ raise NotImplementedError()
- def discard_authorization_code(self, client_id, full_code):
- code_name = full_code[:AUTHORIZATION_CODE_PREFIX_LENGTH]
- try:
- found = (OAuthAuthorizationCode
- .select()
- .join(OAuthApplication)
- .where(OAuthApplication.client_id == client_id,
- OAuthAuthorizationCode.code_name == code_name)
- .get())
- found.delete_instance()
- return
- except OAuthAuthorizationCode.DoesNotExist:
- pass
+ def discard_authorization_code(self, client_id, full_code):
+ code_name = full_code[:AUTHORIZATION_CODE_PREFIX_LENGTH]
+ try:
+ found = (
+ OAuthAuthorizationCode.select()
+ .join(OAuthApplication)
+ .where(
+ OAuthApplication.client_id == client_id,
+ OAuthAuthorizationCode.code_name == code_name,
+ )
+ .get()
+ )
+ found.delete_instance()
+ return
+ except OAuthAuthorizationCode.DoesNotExist:
+ pass
- # Legacy: full code.
- # TODO(remove-unenc): Remove legacy fallback.
- if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
- try:
- found = (OAuthAuthorizationCode
- .select()
- .join(OAuthApplication)
- .where(OAuthApplication.client_id == client_id,
- OAuthAuthorizationCode.code == full_code)
- .get())
- found.delete_instance()
- except OAuthAuthorizationCode.DoesNotExist:
- pass
+ # Legacy: full code.
+ # TODO(remove-unenc): Remove legacy fallback.
+ if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
+ try:
+ found = (
+ OAuthAuthorizationCode.select()
+ .join(OAuthApplication)
+ .where(
+ OAuthApplication.client_id == client_id,
+ OAuthAuthorizationCode.code == full_code,
+ )
+ .get()
+ )
+ found.delete_instance()
+ except OAuthAuthorizationCode.DoesNotExist:
+ pass
- def discard_refresh_token(self, client_id, refresh_token):
- raise NotImplementedError()
+ def discard_refresh_token(self, client_id, refresh_token):
+ raise NotImplementedError()
def create_application(org, name, application_uri, redirect_uri, **kwargs):
- client_secret = kwargs.pop('client_secret', random_string_generator(length=40)())
+ client_secret = kwargs.pop("client_secret", random_string_generator(length=40)())
- # TODO(remove-unenc): Remove legacy field.
- old_client_secret = None
- if ActiveDataMigration.has_flag(ERTMigrationFlags.WRITE_OLD_FIELDS):
- old_client_secret = client_secret
+ # TODO(remove-unenc): Remove legacy field.
+ old_client_secret = None
+ if ActiveDataMigration.has_flag(ERTMigrationFlags.WRITE_OLD_FIELDS):
+ old_client_secret = client_secret
- return OAuthApplication.create(organization=org,
- name=name,
- application_uri=application_uri,
- redirect_uri=redirect_uri,
- client_secret=old_client_secret,
- secure_client_secret=DecryptedValue(client_secret),
- **kwargs)
+ return OAuthApplication.create(
+ organization=org,
+ name=name,
+ application_uri=application_uri,
+ redirect_uri=redirect_uri,
+ client_secret=old_client_secret,
+ secure_client_secret=DecryptedValue(client_secret),
+ **kwargs
+ )
def validate_access_token(access_token):
- assert isinstance(access_token, basestring)
- token_name = access_token[:ACCESS_TOKEN_PREFIX_LENGTH]
- if not token_name:
- return None
+ assert isinstance(access_token, basestring)
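+ # Access tokens split into a public lookup prefix (token_name) and a secret suffix that is
+ # verified against the stored credential.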
+ token_name = access_token[:ACCESS_TOKEN_PREFIX_LENGTH]
+ if not token_name:
+ return None
- token_code = access_token[ACCESS_TOKEN_PREFIX_LENGTH:]
- if not token_code:
- return None
+ token_code = access_token[ACCESS_TOKEN_PREFIX_LENGTH:]
+ if not token_code:
+ return None
- try:
- found = (OAuthAccessToken
- .select(OAuthAccessToken, User)
- .join(User)
- .where(OAuthAccessToken.token_name == token_name)
- .get())
-
- if found.token_code is None or not found.token_code.matches(token_code):
- return None
-
- return found
- except OAuthAccessToken.DoesNotExist:
- pass
-
- # Legacy lookup.
- # TODO(remove-unenc): Remove this once migrated.
- if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
try:
- assert access_token
- found = (OAuthAccessToken
- .select(OAuthAccessToken, User)
- .join(User)
- .where(OAuthAccessToken.access_token == access_token)
- .get())
- return found
- except OAuthAccessToken.DoesNotExist:
- return None
+ found = (
+ OAuthAccessToken.select(OAuthAccessToken, User)
+ .join(User)
+ .where(OAuthAccessToken.token_name == token_name)
+ .get()
+ )
- return None
+ if found.token_code is None or not found.token_code.matches(token_code):
+ return None
+
+ return found
+ except OAuthAccessToken.DoesNotExist:
+ pass
+
+ # Legacy lookup.
+ # TODO(remove-unenc): Remove this once migrated.
+ if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
+ try:
+ assert access_token
+ found = (
+ OAuthAccessToken.select(OAuthAccessToken, User)
+ .join(User)
+ .where(OAuthAccessToken.access_token == access_token)
+ .get()
+ )
+ return found
+ except OAuthAccessToken.DoesNotExist:
+ return None
+
+ return None
def get_application_for_client_id(client_id):
- try:
- return OAuthApplication.get(client_id=client_id)
- except OAuthApplication.DoesNotExist:
- return None
+ try:
+ return OAuthApplication.get(client_id=client_id)
+ except OAuthApplication.DoesNotExist:
+ return None
def reset_client_secret(application):
- client_secret = random_string_generator(length=40)()
+ client_secret = random_string_generator(length=40)()
- # TODO(remove-unenc): Remove legacy field.
- if ActiveDataMigration.has_flag(ERTMigrationFlags.WRITE_OLD_FIELDS):
- application.client_secret = client_secret
+ # TODO(remove-unenc): Remove legacy field.
+ if ActiveDataMigration.has_flag(ERTMigrationFlags.WRITE_OLD_FIELDS):
+ application.client_secret = client_secret
- application.secure_client_secret = DecryptedValue(client_secret)
- application.save()
- return application
+ application.secure_client_secret = DecryptedValue(client_secret)
+ application.save()
+ return application
def lookup_application(org, client_id):
- try:
- return OAuthApplication.get(organization=org, client_id=client_id)
- except OAuthApplication.DoesNotExist:
- return None
+ try:
+ return OAuthApplication.get(organization=org, client_id=client_id)
+ except OAuthApplication.DoesNotExist:
+ return None
def delete_application(org, client_id):
- application = lookup_application(org, client_id)
- if not application:
- return
+ application = lookup_application(org, client_id)
+ if not application:
+ return
- application.delete_instance(recursive=True, delete_nullable=True)
- return application
+ application.delete_instance(recursive=True, delete_nullable=True)
+ return application
def lookup_access_token_by_uuid(token_uuid):
- try:
- return OAuthAccessToken.get(OAuthAccessToken.uuid == token_uuid)
- except OAuthAccessToken.DoesNotExist:
- return None
+ try:
+ return OAuthAccessToken.get(OAuthAccessToken.uuid == token_uuid)
+ except OAuthAccessToken.DoesNotExist:
+ return None
def lookup_access_token_for_user(user_obj, token_uuid):
- try:
- return OAuthAccessToken.get(OAuthAccessToken.authorized_user == user_obj,
- OAuthAccessToken.uuid == token_uuid)
- except OAuthAccessToken.DoesNotExist:
- return None
+ try:
+ return OAuthAccessToken.get(
+ OAuthAccessToken.authorized_user == user_obj,
+ OAuthAccessToken.uuid == token_uuid,
+ )
+ except OAuthAccessToken.DoesNotExist:
+ return None
def list_access_tokens_for_user(user_obj):
- query = (OAuthAccessToken
- .select()
- .join(OAuthApplication)
- .switch(OAuthAccessToken)
- .join(User)
- .where(OAuthAccessToken.authorized_user == user_obj))
+ query = (
+ OAuthAccessToken.select()
+ .join(OAuthApplication)
+ .switch(OAuthAccessToken)
+ .join(User)
+ .where(OAuthAccessToken.authorized_user == user_obj)
+ )
- return query
+ return query
def list_applications_for_org(org):
- query = (OAuthApplication
- .select()
- .join(User)
- .where(OAuthApplication.organization == org))
+ query = (
+ OAuthApplication.select().join(User).where(OAuthApplication.organization == org)
+ )
- return query
+ return query
-def create_access_token_for_testing(user_obj, client_id, scope, access_token=None, expires_in=9000):
- access_token = access_token or random_string_generator(length=40)()
- token_name = access_token[:ACCESS_TOKEN_PREFIX_LENGTH]
- token_code = access_token[ACCESS_TOKEN_PREFIX_LENGTH:]
+def create_access_token_for_testing(
+ user_obj, client_id, scope, access_token=None, expires_in=9000
+):
+ access_token = access_token or random_string_generator(length=40)()
+ token_name = access_token[:ACCESS_TOKEN_PREFIX_LENGTH]
+ token_code = access_token[ACCESS_TOKEN_PREFIX_LENGTH:]
- assert len(token_name) == ACCESS_TOKEN_PREFIX_LENGTH
- assert len(token_code) >= ACCESS_TOKEN_MINIMUM_CODE_LENGTH
+ assert len(token_name) == ACCESS_TOKEN_PREFIX_LENGTH
+ assert len(token_code) >= ACCESS_TOKEN_MINIMUM_CODE_LENGTH
- expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
- application = get_application_for_client_id(client_id)
- created = OAuthAccessToken.create(application=application,
- authorized_user=user_obj,
- scope=scope,
- token_type='token',
- access_token='',
- token_code=Credential.from_string(token_code),
- token_name=token_name,
- expires_at=expires_at,
- data='')
- return created, access_token
+ expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
+ application = get_application_for_client_id(client_id)
+ created = OAuthAccessToken.create(
+ application=application,
+ authorized_user=user_obj,
+ scope=scope,
+ token_type="token",
+ access_token="",
+ token_code=Credential.from_string(token_code),
+ token_name=token_name,
+ expires_at=expires_at,
+ data="",
+ )
+ return created, access_token
diff --git a/data/model/oci/__init__.py b/data/model/oci/__init__.py
index 39bcef2eb..9bd769121 100644
--- a/data/model/oci/__init__.py
+++ b/data/model/oci/__init__.py
@@ -1,9 +1,3 @@
# There MUST NOT be any circular dependencies between these subsections. If there are, fix it by
# moving the minimal number of things to shared.
-from data.model.oci import (
- blob,
- label,
- manifest,
- shared,
- tag,
-)
+from data.model.oci import blob, label, manifest, shared, tag
diff --git a/data/model/oci/blob.py b/data/model/oci/blob.py
index f7739c21b..6d04ff561 100644
--- a/data/model/oci/blob.py
+++ b/data/model/oci/blob.py
@@ -3,24 +3,28 @@ from data.model import BlobDoesNotExist
from data.model.storage import get_storage_by_uuid, InvalidImageException
from data.model.blob import get_repository_blob_by_digest as legacy_get
+
def get_repository_blob_by_digest(repository, blob_digest):
- """ Find the content-addressable blob linked to the specified repository and
+ """ Finds the content-addressable blob linked to the specified repository and
returns it or None if none.
"""
- try:
- storage = (ImageStorage
- .select(ImageStorage.uuid)
- .join(ManifestBlob)
- .where(ManifestBlob.repository == repository,
- ImageStorage.content_checksum == blob_digest,
- ImageStorage.uploading == False)
- .get())
-
- return get_storage_by_uuid(storage.uuid)
- except (ImageStorage.DoesNotExist, InvalidImageException):
- # TODO: Remove once we are no longer using the legacy tables.
- # Try the legacy call.
try:
- return legacy_get(repository, blob_digest)
- except BlobDoesNotExist:
- return None
+ storage = (
+ ImageStorage.select(ImageStorage.uuid)
+ .join(ManifestBlob)
+ .where(
+ ManifestBlob.repository == repository,
+ ImageStorage.content_checksum == blob_digest,
+ ImageStorage.uploading == False,
+ )
+ .get()
+ )
+
+ return get_storage_by_uuid(storage.uuid)
+ except (ImageStorage.DoesNotExist, InvalidImageException):
+ # TODO: Remove once we are no longer using the legacy tables.
+ # Try the legacy call.
+ try:
+ return legacy_get(repository, blob_digest)
+ except BlobDoesNotExist:
+ return None
diff --git a/data/model/oci/label.py b/data/model/oci/label.py
index d019e6d2d..850db9a31 100644
--- a/data/model/oci/label.py
+++ b/data/model/oci/label.py
@@ -1,142 +1,183 @@
import logging
-from data.model import InvalidLabelKeyException, InvalidMediaTypeException, DataModelException
-from data.database import (Label, Manifest, TagManifestLabel, MediaType, LabelSourceType,
- db_transaction, ManifestLabel, TagManifestLabelMap,
- TagManifestToManifest, Repository, TagManifest)
+from data.model import (
+ InvalidLabelKeyException,
+ InvalidMediaTypeException,
+ DataModelException,
+)
+from data.database import (
+ Label,
+ Manifest,
+ TagManifestLabel,
+ MediaType,
+ LabelSourceType,
+ db_transaction,
+ ManifestLabel,
+ TagManifestLabelMap,
+ TagManifestToManifest,
+ Repository,
+ TagManifest,
+)
from data.text import prefix_search
from util.validation import validate_label_key
from util.validation import is_json
logger = logging.getLogger(__name__)
+
def list_manifest_labels(manifest_id, prefix_filter=None):
- """ Lists all labels found on the given manifest, with an optional filter by key prefix. """
- query = (Label
- .select(Label, MediaType)
- .join(MediaType)
- .switch(Label)
- .join(LabelSourceType)
- .switch(Label)
- .join(ManifestLabel)
- .where(ManifestLabel.manifest == manifest_id))
+ """ Lists all labels found on the given manifest, with an optional filter by key prefix. """
+ query = (
+ Label.select(Label, MediaType)
+ .join(MediaType)
+ .switch(Label)
+ .join(LabelSourceType)
+ .switch(Label)
+ .join(ManifestLabel)
+ .where(ManifestLabel.manifest == manifest_id)
+ )
- if prefix_filter is not None:
- query = query.where(prefix_search(Label.key, prefix_filter))
+ if prefix_filter is not None:
+ query = query.where(prefix_search(Label.key, prefix_filter))
- return query
+ return query
def get_manifest_label(label_uuid, manifest):
- """ Retrieves the manifest label on the manifest with the given UUID or None if none. """
- try:
- return (Label
- .select(Label, LabelSourceType)
+ """ Retrieves the manifest label on the manifest with the given UUID or None if none. """
+ try:
+ return (
+ Label.select(Label, LabelSourceType)
.join(LabelSourceType)
.where(Label.uuid == label_uuid)
.switch(Label)
.join(ManifestLabel)
.where(ManifestLabel.manifest == manifest)
- .get())
- except Label.DoesNotExist:
- return None
+ .get()
+ )
+ except Label.DoesNotExist:
+ return None
-def create_manifest_label(manifest_id, key, value, source_type_name, media_type_name=None,
- adjust_old_model=True):
- """ Creates a new manifest label on a specific tag manifest. """
- if not key:
- raise InvalidLabelKeyException()
+def create_manifest_label(
+ manifest_id,
+ key,
+ value,
+ source_type_name,
+ media_type_name=None,
+ adjust_old_model=True,
+):
+ """ Creates a new manifest label on a specific tag manifest. """
+ if not key:
+ raise InvalidLabelKeyException()
- # Note that we don't prevent invalid label names coming from the manifest to be stored, as Docker
- # does not currently prevent them from being put into said manifests.
- if not validate_label_key(key) and source_type_name != 'manifest':
- raise InvalidLabelKeyException('Key `%s` is invalid' % key)
+ # Note that we don't prevent invalid label names coming from the manifest to be stored, as Docker
+ # does not currently prevent them from being put into said manifests.
+ if not validate_label_key(key) and source_type_name != "manifest":
+ raise InvalidLabelKeyException("Key `%s` is invalid" % key)
- # Find the matching media type. If none specified, we infer.
- if media_type_name is None:
- media_type_name = 'text/plain'
- if is_json(value):
- media_type_name = 'application/json'
+ # Find the matching media type. If none specified, we infer.
+ if media_type_name is None:
+ media_type_name = "text/plain"
+ if is_json(value):
+ media_type_name = "application/json"
- try:
- media_type_id = Label.media_type.get_id(media_type_name)
- except MediaType.DoesNotExist:
- raise InvalidMediaTypeException()
-
- source_type_id = Label.source_type.get_id(source_type_name)
-
- # Ensure the manifest exists.
- try:
- manifest = (Manifest
- .select(Manifest, Repository)
- .join(Repository)
- .where(Manifest.id == manifest_id)
- .get())
- except Manifest.DoesNotExist:
- return None
-
- repository = manifest.repository
-
- # TODO: Remove this code once the TagManifest table is gone.
- tag_manifest = None
- if adjust_old_model:
try:
- mapping_row = (TagManifestToManifest
- .select(TagManifestToManifest, TagManifest)
- .join(TagManifest)
- .where(TagManifestToManifest.manifest == manifest)
- .get())
- tag_manifest = mapping_row.tag_manifest
- except TagManifestToManifest.DoesNotExist:
- tag_manifest = None
+ media_type_id = Label.media_type.get_id(media_type_name)
+ except MediaType.DoesNotExist:
+ raise InvalidMediaTypeException()
- with db_transaction():
- label = Label.create(key=key, value=value, source_type=source_type_id, media_type=media_type_id)
- manifest_label = ManifestLabel.create(manifest=manifest_id, label=label, repository=repository)
+ source_type_id = Label.source_type.get_id(source_type_name)
+
+ # Ensure the manifest exists.
+ try:
+ manifest = (
+ Manifest.select(Manifest, Repository)
+ .join(Repository)
+ .where(Manifest.id == manifest_id)
+ .get()
+ )
+ except Manifest.DoesNotExist:
+ return None
+
+ repository = manifest.repository
- # If there exists a mapping to a TagManifest, add the old-style label.
# TODO: Remove this code once the TagManifest table is gone.
- if tag_manifest:
- tag_manifest_label = TagManifestLabel.create(annotated=tag_manifest, label=label,
- repository=repository)
- TagManifestLabelMap.create(manifest_label=manifest_label,
- tag_manifest_label=tag_manifest_label,
- label=label,
- manifest=manifest,
- tag_manifest=tag_manifest)
+ tag_manifest = None
+ if adjust_old_model:
+ try:
+ mapping_row = (
+ TagManifestToManifest.select(TagManifestToManifest, TagManifest)
+ .join(TagManifest)
+ .where(TagManifestToManifest.manifest == manifest)
+ .get()
+ )
+ tag_manifest = mapping_row.tag_manifest
+ except TagManifestToManifest.DoesNotExist:
+ tag_manifest = None
- return label
+ with db_transaction():
+ label = Label.create(
+ key=key, value=value, source_type=source_type_id, media_type=media_type_id
+ )
+ manifest_label = ManifestLabel.create(
+ manifest=manifest_id, label=label, repository=repository
+ )
+
+ # If there exists a mapping to a TagManifest, add the old-style label.
+ # TODO: Remove this code once the TagManifest table is gone.
+ if tag_manifest:
+ tag_manifest_label = TagManifestLabel.create(
+ annotated=tag_manifest, label=label, repository=repository
+ )
+ TagManifestLabelMap.create(
+ manifest_label=manifest_label,
+ tag_manifest_label=tag_manifest_label,
+ label=label,
+ manifest=manifest,
+ tag_manifest=tag_manifest,
+ )
+
+ return label
def delete_manifest_label(label_uuid, manifest):
- """ Deletes the manifest label on the tag manifest with the given ID. Returns the label deleted
+ """ Deletes the manifest label on the tag manifest with the given ID. Returns the label deleted
or None if none.
"""
- # Find the label itself.
- label = get_manifest_label(label_uuid, manifest)
- if label is None:
- return None
+ # Find the label itself.
+ label = get_manifest_label(label_uuid, manifest)
+ if label is None:
+ return None
- if not label.source_type.mutable:
- raise DataModelException('Cannot delete immutable label')
+ if not label.source_type.mutable:
+ raise DataModelException("Cannot delete immutable label")
- # Delete the mapping records and label.
- # TODO: Remove this code once the TagManifest table is gone.
- with db_transaction():
- (TagManifestLabelMap
- .delete()
- .where(TagManifestLabelMap.label == label)
- .execute())
+ # Delete the mapping records and label.
+ # TODO: Remove this code once the TagManifest table is gone.
+ with db_transaction():
+ (
+ TagManifestLabelMap.delete()
+ .where(TagManifestLabelMap.label == label)
+ .execute()
+ )
- deleted_count = TagManifestLabel.delete().where(TagManifestLabel.label == label).execute()
- if deleted_count != 1:
- logger.warning('More than a single label deleted for matching label %s', label_uuid)
+ deleted_count = (
+ TagManifestLabel.delete().where(TagManifestLabel.label == label).execute()
+ )
+ if deleted_count != 1:
+ logger.warning(
+ "More than a single label deleted for matching label %s", label_uuid
+ )
- deleted_count = ManifestLabel.delete().where(ManifestLabel.label == label).execute()
- if deleted_count != 1:
- logger.warning('More than a single label deleted for matching label %s', label_uuid)
+ deleted_count = (
+ ManifestLabel.delete().where(ManifestLabel.label == label).execute()
+ )
+ if deleted_count != 1:
+ logger.warning(
+ "More than a single label deleted for matching label %s", label_uuid
+ )
- label.delete_instance(recursive=False)
- return label
+ label.delete_instance(recursive=False)
+ return label
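
Taken together, the label helpers round-trip roughly as follows, assuming `manifest_row` is an existing Manifest row and an "api" label source type is registered; the key and value below are purely illustrative:

# A JSON-looking value gets an inferred media type of application/json.
label = create_manifest_label(
    manifest_row.id, "com.example.release", '{"channel": "stable"}', "api"
)

# Filter by key prefix, then remove the label again by its UUID.
keys = [found.key for found in list_manifest_labels(manifest_row.id, prefix_filter="com.example.")]
assert "com.example.release" in keys
assert delete_manifest_label(label.uuid, manifest_row.id) is not None
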
diff --git a/data/model/oci/manifest.py b/data/model/oci/manifest.py
index 85b66efc5..806ebba24 100644
--- a/data/model/oci/manifest.py
+++ b/data/model/oci/manifest.py
@@ -4,8 +4,14 @@ from collections import namedtuple
from peewee import IntegrityError
-from data.database import (Tag, Manifest, ManifestBlob, ManifestLegacyImage, ManifestChild,
- db_transaction)
+from data.database import (
+ Tag,
+ Manifest,
+ ManifestBlob,
+ ManifestLegacyImage,
+ ManifestChild,
+ db_transaction,
+)
from data.model import BlobDoesNotExist
from data.model.blob import get_or_create_shared_blob, get_shared_blob
from data.model.oci.tag import filter_to_alive_tags, create_temporary_tag_if_necessary
@@ -19,73 +25,86 @@ from image.docker.schema2.list import MalformedSchema2ManifestList
from util.validation import is_json
-TEMP_TAG_EXPIRATION_SEC = 300 # 5 minutes
+TEMP_TAG_EXPIRATION_SEC = 300 # 5 minutes
logger = logging.getLogger(__name__)
-CreatedManifest = namedtuple('CreatedManifest', ['manifest', 'newly_created', 'labels_to_apply'])
+CreatedManifest = namedtuple(
+ "CreatedManifest", ["manifest", "newly_created", "labels_to_apply"]
+)
class CreateManifestException(Exception):
- """ Exception raised when creating a manifest fails and explicit exception
+ """ Exception raised when creating a manifest fails and explicit exception
raising is requested. """
-def lookup_manifest(repository_id, manifest_digest, allow_dead=False, require_available=False,
- temp_tag_expiration_sec=TEMP_TAG_EXPIRATION_SEC):
- """ Returns the manifest with the specified digest under the specified repository
+def lookup_manifest(
+ repository_id,
+ manifest_digest,
+ allow_dead=False,
+ require_available=False,
+ temp_tag_expiration_sec=TEMP_TAG_EXPIRATION_SEC,
+):
+ """ Returns the manifest with the specified digest under the specified repository
or None if none. If allow_dead is True, then manifests referenced by only
dead tags will also be returned. If require_available is True, the manifest
will be marked with a temporary tag to ensure it remains available.
"""
- if not require_available:
- return _lookup_manifest(repository_id, manifest_digest, allow_dead=allow_dead)
+ if not require_available:
+ return _lookup_manifest(repository_id, manifest_digest, allow_dead=allow_dead)
- with db_transaction():
- found = _lookup_manifest(repository_id, manifest_digest, allow_dead=allow_dead)
- if found is None:
- return None
+ with db_transaction():
+ found = _lookup_manifest(repository_id, manifest_digest, allow_dead=allow_dead)
+ if found is None:
+ return None
- create_temporary_tag_if_necessary(found, temp_tag_expiration_sec)
- return found
+ create_temporary_tag_if_necessary(found, temp_tag_expiration_sec)
+ return found
def _lookup_manifest(repository_id, manifest_digest, allow_dead=False):
- query = (Manifest
- .select()
- .where(Manifest.repository == repository_id)
- .where(Manifest.digest == manifest_digest))
+ query = (
+ Manifest.select()
+ .where(Manifest.repository == repository_id)
+ .where(Manifest.digest == manifest_digest)
+ )
- if allow_dead:
+ if allow_dead:
+ try:
+ return query.get()
+ except Manifest.DoesNotExist:
+ return None
+
+ # Try first to filter to those manifests referenced by an alive tag.
try:
- return query.get()
+ return filter_to_alive_tags(query.join(Tag)).get()
except Manifest.DoesNotExist:
- return None
+ pass
- # Try first to filter to those manifests referenced by an alive tag,
- try:
- return filter_to_alive_tags(query.join(Tag)).get()
- except Manifest.DoesNotExist:
- pass
+ # Try referenced as the child of a manifest that has an alive tag.
+ query = query.join(
+ ManifestChild, on=(ManifestChild.child_manifest == Manifest.id)
+ ).join(Tag, on=(Tag.manifest == ManifestChild.manifest))
- # Try referenced as the child of a manifest that has an alive tag.
- query = (query
- .join(ManifestChild, on=(ManifestChild.child_manifest == Manifest.id))
- .join(Tag, on=(Tag.manifest == ManifestChild.manifest)))
+ query = filter_to_alive_tags(query)
- query = filter_to_alive_tags(query)
-
- try:
- return query.get()
- except Manifest.DoesNotExist:
- return None
+ try:
+ return query.get()
+ except Manifest.DoesNotExist:
+ return None
-def get_or_create_manifest(repository_id, manifest_interface_instance, storage,
- temp_tag_expiration_sec=TEMP_TAG_EXPIRATION_SEC,
- for_tagging=False, raise_on_error=False):
- """ Returns a CreatedManifest for the manifest in the specified repository with the matching
+def get_or_create_manifest(
+ repository_id,
+ manifest_interface_instance,
+ storage,
+ temp_tag_expiration_sec=TEMP_TAG_EXPIRATION_SEC,
+ for_tagging=False,
+ raise_on_error=False,
+):
+ """ Returns a CreatedManifest for the manifest in the specified repository with the matching
digest (if it already exists) or, if not yet created, creates and returns the manifest.
Returns None if there was an error creating the manifest, unless raise_on_error is specified,
@@ -95,227 +114,293 @@ def get_or_create_manifest(repository_id, manifest_interface_instance, storage,
Note that *all* blobs referenced by the manifest must exist already in the repository or this
method will fail with a None.
"""
- existing = lookup_manifest(repository_id, manifest_interface_instance.digest, allow_dead=True,
- require_available=True,
- temp_tag_expiration_sec=temp_tag_expiration_sec)
- if existing is not None:
- return CreatedManifest(manifest=existing, newly_created=False, labels_to_apply=None)
-
- return _create_manifest(repository_id, manifest_interface_instance, storage,
- temp_tag_expiration_sec, for_tagging=for_tagging,
- raise_on_error=raise_on_error)
-
-
-def _create_manifest(repository_id, manifest_interface_instance, storage,
- temp_tag_expiration_sec=TEMP_TAG_EXPIRATION_SEC,
- for_tagging=False, raise_on_error=False):
- # Validate the manifest.
- retriever = RepositoryContentRetriever.for_repository(repository_id, storage)
- try:
- manifest_interface_instance.validate(retriever)
- except (ManifestException, MalformedSchema2ManifestList, BlobDoesNotExist, IOError) as ex:
- logger.exception('Could not validate manifest `%s`', manifest_interface_instance.digest)
- if raise_on_error:
- raise CreateManifestException(ex)
-
- return None
-
- # Load, parse and get/create the child manifests, if any.
- child_manifest_refs = manifest_interface_instance.child_manifests(retriever)
- child_manifest_rows = {}
- child_manifest_label_dicts = []
-
- if child_manifest_refs is not None:
- for child_manifest_ref in child_manifest_refs:
- # Load and parse the child manifest.
- try:
- child_manifest = child_manifest_ref.manifest_obj
- except (ManifestException, MalformedSchema2ManifestList, BlobDoesNotExist, IOError) as ex:
- logger.exception('Could not load manifest list for manifest `%s`',
- manifest_interface_instance.digest)
- if raise_on_error:
- raise CreateManifestException(ex)
-
- return None
-
- # Retrieve its labels.
- labels = child_manifest.get_manifest_labels(retriever)
- if labels is None:
- logger.exception('Could not load manifest labels for child manifest')
- return None
-
- # Get/create the child manifest in the database.
- child_manifest_info = get_or_create_manifest(repository_id, child_manifest, storage,
- raise_on_error=raise_on_error)
- if child_manifest_info is None:
- logger.error('Could not get/create child manifest')
- return None
-
- child_manifest_rows[child_manifest_info.manifest.digest] = child_manifest_info.manifest
- child_manifest_label_dicts.append(labels)
-
- # Ensure all the blobs in the manifest exist.
- digests = set(manifest_interface_instance.local_blob_digests)
- blob_map = {}
-
- # If the special empty layer is required, simply load it directly. This is much faster
- # than trying to load it on a per repository basis, and that is unnecessary anyway since
- # this layer is predefined.
- if EMPTY_LAYER_BLOB_DIGEST in digests:
- digests.remove(EMPTY_LAYER_BLOB_DIGEST)
- blob_map[EMPTY_LAYER_BLOB_DIGEST] = get_shared_blob(EMPTY_LAYER_BLOB_DIGEST)
- if not blob_map[EMPTY_LAYER_BLOB_DIGEST]:
- logger.warning('Could not find the special empty blob in storage')
- return None
-
- if digests:
- query = lookup_repo_storages_by_content_checksum(repository_id, digests)
- blob_map.update({s.content_checksum: s for s in query})
- for digest_str in digests:
- if digest_str not in blob_map:
- logger.warning('Unknown blob `%s` under manifest `%s` for repository `%s`', digest_str,
- manifest_interface_instance.digest, repository_id)
-
- if raise_on_error:
- raise CreateManifestException('Unknown blob `%s`' % digest_str)
-
- return None
-
- # Special check: If the empty layer blob is needed for this manifest, add it to the
- # blob map. This is necessary because Docker decided to elide sending of this special
- # empty layer in schema version 2, but we need to have it referenced for GC and schema version 1.
- if EMPTY_LAYER_BLOB_DIGEST not in blob_map:
- if manifest_interface_instance.get_requires_empty_layer_blob(retriever):
- shared_blob = get_or_create_shared_blob(EMPTY_LAYER_BLOB_DIGEST, EMPTY_LAYER_BYTES, storage)
- assert not shared_blob.uploading
- assert shared_blob.content_checksum == EMPTY_LAYER_BLOB_DIGEST
- blob_map[EMPTY_LAYER_BLOB_DIGEST] = shared_blob
-
- # Determine and populate the legacy image if necessary. Manifest lists will not have a legacy
- # image.
- legacy_image = None
- if manifest_interface_instance.has_legacy_image:
- legacy_image_id = _populate_legacy_image(repository_id, manifest_interface_instance, blob_map,
- retriever)
- if legacy_image_id is None:
- return None
-
- legacy_image = get_image(repository_id, legacy_image_id)
- if legacy_image is None:
- return None
-
- # Create the manifest and its blobs.
- media_type = Manifest.media_type.get_id(manifest_interface_instance.media_type)
- storage_ids = {storage.id for storage in blob_map.values()}
-
- with db_transaction():
- # Check for the manifest. This is necessary because Postgres doesn't handle IntegrityErrors
- # well under transactions.
- try:
- manifest = Manifest.get(repository=repository_id, digest=manifest_interface_instance.digest)
- return CreatedManifest(manifest=manifest, newly_created=False, labels_to_apply=None)
- except Manifest.DoesNotExist:
- pass
-
- # Create the manifest.
- try:
- manifest = Manifest.create(repository=repository_id,
- digest=manifest_interface_instance.digest,
- media_type=media_type,
- manifest_bytes=manifest_interface_instance.bytes.as_encoded_str())
- except IntegrityError:
- manifest = Manifest.get(repository=repository_id, digest=manifest_interface_instance.digest)
- return CreatedManifest(manifest=manifest, newly_created=False, labels_to_apply=None)
-
- # Insert the blobs.
- blobs_to_insert = [dict(manifest=manifest, repository=repository_id,
- blob=storage_id) for storage_id in storage_ids]
- if blobs_to_insert:
- ManifestBlob.insert_many(blobs_to_insert).execute()
-
- # Set the legacy image (if applicable).
- if legacy_image is not None:
- ManifestLegacyImage.create(repository=repository_id, image=legacy_image, manifest=manifest)
-
- # Insert the manifest child rows (if applicable).
- if child_manifest_rows:
- children_to_insert = [dict(manifest=manifest, child_manifest=child_manifest,
- repository=repository_id)
- for child_manifest in child_manifest_rows.values()]
- ManifestChild.insert_many(children_to_insert).execute()
-
- # If this manifest is being created not for immediate tagging, add a temporary tag to the
- # manifest to ensure it isn't being GCed. If the manifest *is* for tagging, then since we're
- # creating a new one here, it cannot be GCed (since it isn't referenced by anything yet), so
- # its safe to elide the temp tag operation. If we ever change GC code to collect *all* manifests
- # in a repository for GC, then we will have to reevaluate this optimization at that time.
- if not for_tagging:
- create_temporary_tag_if_necessary(manifest, temp_tag_expiration_sec)
-
- # Define the labels for the manifest (if any).
- labels = manifest_interface_instance.get_manifest_labels(retriever)
- if labels:
- for key, value in labels.iteritems():
- media_type = 'application/json' if is_json(value) else 'text/plain'
- create_manifest_label(manifest, key, value, 'manifest', media_type)
-
- # Return the dictionary of labels to apply (i.e. those labels that cause an action to be taken
- # on the manifest or its resulting tags). We only return those labels either defined on
- # the manifest or shared amongst all the child manifests. We intersect amongst all child manifests
- # to ensure that any action performed is defined in all manifests.
- labels_to_apply = labels or {}
- if child_manifest_label_dicts:
- labels_to_apply = child_manifest_label_dicts[0].viewitems()
- for child_manifest_label_dict in child_manifest_label_dicts[1:]:
- # Intersect the key+values of the labels to ensure we get the exact same result
- # for all the child manifests.
- labels_to_apply = labels_to_apply & child_manifest_label_dict.viewitems()
-
- labels_to_apply = dict(labels_to_apply)
-
- return CreatedManifest(manifest=manifest, newly_created=True, labels_to_apply=labels_to_apply)
-
-
-def _populate_legacy_image(repository_id, manifest_interface_instance, blob_map, retriever):
- # Lookup all the images and their parent images (if any) inside the manifest.
- # This will let us know which v1 images we need to synthesize and which ones are invalid.
- docker_image_ids = list(manifest_interface_instance.get_legacy_image_ids(retriever))
- images_query = lookup_repository_images(repository_id, docker_image_ids)
- image_storage_map = {i.docker_image_id: i.storage for i in images_query}
-
- # Rewrite any v1 image IDs that do not match the checksum in the database.
- try:
- rewritten_images = manifest_interface_instance.generate_legacy_layers(image_storage_map,
- retriever)
- rewritten_images = list(rewritten_images)
- parent_image_map = {}
-
- for rewritten_image in rewritten_images:
- if not rewritten_image.image_id in image_storage_map:
- parent_image = None
- if rewritten_image.parent_image_id:
- parent_image = parent_image_map.get(rewritten_image.parent_image_id)
- if parent_image is None:
- parent_image = get_image(repository_id, rewritten_image.parent_image_id)
- if parent_image is None:
- return None
-
- storage_reference = blob_map[rewritten_image.content_checksum]
- synthesized = synthesize_v1_image(
- repository_id,
- storage_reference.id,
- storage_reference.image_size,
- rewritten_image.image_id,
- rewritten_image.created,
- rewritten_image.comment,
- rewritten_image.command,
- rewritten_image.compat_json,
- parent_image,
+ existing = lookup_manifest(
+ repository_id,
+ manifest_interface_instance.digest,
+ allow_dead=True,
+ require_available=True,
+ temp_tag_expiration_sec=temp_tag_expiration_sec,
+ )
+ if existing is not None:
+ return CreatedManifest(
+ manifest=existing, newly_created=False, labels_to_apply=None
)
- parent_image_map[rewritten_image.image_id] = synthesized
- except ManifestException:
- logger.exception("exception when rewriting v1 metadata")
- return None
+ return _create_manifest(
+ repository_id,
+ manifest_interface_instance,
+ storage,
+ temp_tag_expiration_sec,
+ for_tagging=for_tagging,
+ raise_on_error=raise_on_error,
+ )
- return rewritten_images[-1].image_id
+
+def _create_manifest(
+ repository_id,
+ manifest_interface_instance,
+ storage,
+ temp_tag_expiration_sec=TEMP_TAG_EXPIRATION_SEC,
+ for_tagging=False,
+ raise_on_error=False,
+):
+ # Validate the manifest.
+ retriever = RepositoryContentRetriever.for_repository(repository_id, storage)
+ try:
+ manifest_interface_instance.validate(retriever)
+ except (
+ ManifestException,
+ MalformedSchema2ManifestList,
+ BlobDoesNotExist,
+ IOError,
+ ) as ex:
+ logger.exception(
+ "Could not validate manifest `%s`", manifest_interface_instance.digest
+ )
+ if raise_on_error:
+ raise CreateManifestException(ex)
+
+ return None
+
+ # Load, parse and get/create the child manifests, if any.
+ child_manifest_refs = manifest_interface_instance.child_manifests(retriever)
+ child_manifest_rows = {}
+ child_manifest_label_dicts = []
+
+ if child_manifest_refs is not None:
+ for child_manifest_ref in child_manifest_refs:
+ # Load and parse the child manifest.
+ try:
+ child_manifest = child_manifest_ref.manifest_obj
+ except (
+ ManifestException,
+ MalformedSchema2ManifestList,
+ BlobDoesNotExist,
+ IOError,
+ ) as ex:
+ logger.exception(
+ "Could not load manifest list for manifest `%s`",
+ manifest_interface_instance.digest,
+ )
+ if raise_on_error:
+ raise CreateManifestException(ex)
+
+ return None
+
+ # Retrieve its labels.
+ labels = child_manifest.get_manifest_labels(retriever)
+ if labels is None:
+ logger.exception("Could not load manifest labels for child manifest")
+ return None
+
+ # Get/create the child manifest in the database.
+ child_manifest_info = get_or_create_manifest(
+ repository_id, child_manifest, storage, raise_on_error=raise_on_error
+ )
+ if child_manifest_info is None:
+ logger.error("Could not get/create child manifest")
+ return None
+
+ child_manifest_rows[
+ child_manifest_info.manifest.digest
+ ] = child_manifest_info.manifest
+ child_manifest_label_dicts.append(labels)
+
+ # Ensure all the blobs in the manifest exist.
+ digests = set(manifest_interface_instance.local_blob_digests)
+ blob_map = {}
+
+ # If the special empty layer is required, simply load it directly. This is much faster
+ # than trying to load it on a per repository basis, and that is unnecessary anyway since
+ # this layer is predefined.
+ if EMPTY_LAYER_BLOB_DIGEST in digests:
+ digests.remove(EMPTY_LAYER_BLOB_DIGEST)
+ blob_map[EMPTY_LAYER_BLOB_DIGEST] = get_shared_blob(EMPTY_LAYER_BLOB_DIGEST)
+ if not blob_map[EMPTY_LAYER_BLOB_DIGEST]:
+ logger.warning("Could not find the special empty blob in storage")
+ return None
+
+ if digests:
+ query = lookup_repo_storages_by_content_checksum(repository_id, digests)
+ blob_map.update({s.content_checksum: s for s in query})
+ for digest_str in digests:
+ if digest_str not in blob_map:
+ logger.warning(
+ "Unknown blob `%s` under manifest `%s` for repository `%s`",
+ digest_str,
+ manifest_interface_instance.digest,
+ repository_id,
+ )
+
+ if raise_on_error:
+ raise CreateManifestException("Unknown blob `%s`" % digest_str)
+
+ return None
+
+ # Special check: If the empty layer blob is needed for this manifest, add it to the
+ # blob map. This is necessary because Docker decided to elide sending of this special
+ # empty layer in schema version 2, but we need to have it referenced for GC and schema version 1.
+ if EMPTY_LAYER_BLOB_DIGEST not in blob_map:
+ if manifest_interface_instance.get_requires_empty_layer_blob(retriever):
+ shared_blob = get_or_create_shared_blob(
+ EMPTY_LAYER_BLOB_DIGEST, EMPTY_LAYER_BYTES, storage
+ )
+ assert not shared_blob.uploading
+ assert shared_blob.content_checksum == EMPTY_LAYER_BLOB_DIGEST
+ blob_map[EMPTY_LAYER_BLOB_DIGEST] = shared_blob
+
+ # Determine and populate the legacy image if necessary. Manifest lists will not have a legacy
+ # image.
+ legacy_image = None
+ if manifest_interface_instance.has_legacy_image:
+ legacy_image_id = _populate_legacy_image(
+ repository_id, manifest_interface_instance, blob_map, retriever
+ )
+ if legacy_image_id is None:
+ return None
+
+ legacy_image = get_image(repository_id, legacy_image_id)
+ if legacy_image is None:
+ return None
+
+ # Create the manifest and its blobs.
+ media_type = Manifest.media_type.get_id(manifest_interface_instance.media_type)
+ storage_ids = {storage.id for storage in blob_map.values()}
+
+ with db_transaction():
+ # Check for the manifest. This is necessary because Postgres doesn't handle IntegrityErrors
+ # well under transactions.
+ try:
+ manifest = Manifest.get(
+ repository=repository_id, digest=manifest_interface_instance.digest
+ )
+ return CreatedManifest(
+ manifest=manifest, newly_created=False, labels_to_apply=None
+ )
+ except Manifest.DoesNotExist:
+ pass
+
+ # Create the manifest.
+ try:
+ manifest = Manifest.create(
+ repository=repository_id,
+ digest=manifest_interface_instance.digest,
+ media_type=media_type,
+ manifest_bytes=manifest_interface_instance.bytes.as_encoded_str(),
+ )
+ except IntegrityError:
+ manifest = Manifest.get(
+ repository=repository_id, digest=manifest_interface_instance.digest
+ )
+ return CreatedManifest(
+ manifest=manifest, newly_created=False, labels_to_apply=None
+ )
+
+ # Insert the blobs.
+ blobs_to_insert = [
+ dict(manifest=manifest, repository=repository_id, blob=storage_id)
+ for storage_id in storage_ids
+ ]
+ if blobs_to_insert:
+ ManifestBlob.insert_many(blobs_to_insert).execute()
+
+ # Set the legacy image (if applicable).
+ if legacy_image is not None:
+ ManifestLegacyImage.create(
+ repository=repository_id, image=legacy_image, manifest=manifest
+ )
+
+ # Insert the manifest child rows (if applicable).
+ if child_manifest_rows:
+ children_to_insert = [
+ dict(
+ manifest=manifest,
+ child_manifest=child_manifest,
+ repository=repository_id,
+ )
+ for child_manifest in child_manifest_rows.values()
+ ]
+ ManifestChild.insert_many(children_to_insert).execute()
+
+ # If this manifest is being created not for immediate tagging, add a temporary tag to the
+ # manifest to ensure it isn't being GCed. If the manifest *is* for tagging, then since we're
+ # creating a new one here, it cannot be GCed (since it isn't referenced by anything yet), so
+ # it's safe to elide the temp tag operation. If we ever change GC code to collect *all* manifests
+ # in a repository for GC, then we will have to reevaluate this optimization at that time.
+ if not for_tagging:
+ create_temporary_tag_if_necessary(manifest, temp_tag_expiration_sec)
+
+ # Define the labels for the manifest (if any).
+ labels = manifest_interface_instance.get_manifest_labels(retriever)
+ if labels:
+ for key, value in labels.iteritems():
+ media_type = "application/json" if is_json(value) else "text/plain"
+ create_manifest_label(manifest, key, value, "manifest", media_type)
+
+ # Return the dictionary of labels to apply (i.e. those labels that cause an action to be taken
+ # on the manifest or its resulting tags). We only return those labels either defined on
+ # the manifest or shared amongst all the child manifests. We intersect amongst all child manifests
+ # to ensure that any action performed is defined in all manifests.
+ labels_to_apply = labels or {}
+ if child_manifest_label_dicts:
+ labels_to_apply = child_manifest_label_dicts[0].viewitems()
+ for child_manifest_label_dict in child_manifest_label_dicts[1:]:
+ # Intersect the key+values of the labels to ensure we get the exact same result
+ # for all the child manifests.
+ labels_to_apply = labels_to_apply & child_manifest_label_dict.viewitems()
+
+ labels_to_apply = dict(labels_to_apply)
+
+ return CreatedManifest(
+ manifest=manifest, newly_created=True, labels_to_apply=labels_to_apply
+ )
+
+
+def _populate_legacy_image(
+ repository_id, manifest_interface_instance, blob_map, retriever
+):
+ # Lookup all the images and their parent images (if any) inside the manifest.
+ # This will let us know which v1 images we need to synthesize and which ones are invalid.
+ docker_image_ids = list(manifest_interface_instance.get_legacy_image_ids(retriever))
+ images_query = lookup_repository_images(repository_id, docker_image_ids)
+ image_storage_map = {i.docker_image_id: i.storage for i in images_query}
+
+ # Rewrite any v1 image IDs that do not match the checksum in the database.
+ try:
+ rewritten_images = manifest_interface_instance.generate_legacy_layers(
+ image_storage_map, retriever
+ )
+ rewritten_images = list(rewritten_images)
+ parent_image_map = {}
+
+ for rewritten_image in rewritten_images:
+ if rewritten_image.image_id not in image_storage_map:
+ parent_image = None
+ if rewritten_image.parent_image_id:
+ parent_image = parent_image_map.get(rewritten_image.parent_image_id)
+ if parent_image is None:
+ parent_image = get_image(
+ repository_id, rewritten_image.parent_image_id
+ )
+ if parent_image is None:
+ return None
+
+ storage_reference = blob_map[rewritten_image.content_checksum]
+ synthesized = synthesize_v1_image(
+ repository_id,
+ storage_reference.id,
+ storage_reference.image_size,
+ rewritten_image.image_id,
+ rewritten_image.created,
+ rewritten_image.comment,
+ rewritten_image.command,
+ rewritten_image.compat_json,
+ parent_image,
+ )
+
+ parent_image_map[rewritten_image.image_id] = synthesized
+ except ManifestException:
+ logger.exception("exception when rewriting v1 metadata")
+ return None
+
+ return rewritten_images[-1].image_id
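
The child-manifest label handling keeps only the key/value pairs common to every child. In isolation, the intersection over Python 2 dict views behaves like this sketch (example label dicts are made up):

child_manifest_label_dicts = [
    {"arch": "amd64", "maintainer": "eng@example.com", "tier": "prod"},
    {"arch": "arm64", "maintainer": "eng@example.com", "tier": "prod"},
]

labels_to_apply = child_manifest_label_dicts[0].viewitems()
for child_labels in child_manifest_label_dicts[1:]:
    labels_to_apply = labels_to_apply & child_labels.viewitems()

# Only pairs identical in every child survive the intersection.
assert dict(labels_to_apply) == {"maintainer": "eng@example.com", "tier": "prod"}
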
diff --git a/data/model/oci/retriever.py b/data/model/oci/retriever.py
index b6e9633e0..b4f563058 100644
--- a/data/model/oci/retriever.py
+++ b/data/model/oci/retriever.py
@@ -3,35 +3,38 @@ from data.database import Manifest
from data.model.oci.blob import get_repository_blob_by_digest
from data.model.storage import get_layer_path
+
class RepositoryContentRetriever(ContentRetriever):
- """ Implementation of the ContentRetriever interface for manifests that retrieves
+ """ Implementation of the ContentRetriever interface for manifests that retrieves
config blobs and child manifests for the specified repository.
"""
- def __init__(self, repository_id, storage):
- self.repository_id = repository_id
- self.storage = storage
- @classmethod
- def for_repository(cls, repository_id, storage):
- return RepositoryContentRetriever(repository_id, storage)
+ def __init__(self, repository_id, storage):
+ self.repository_id = repository_id
+ self.storage = storage
- def get_manifest_bytes_with_digest(self, digest):
- """ Returns the bytes of the manifest with the given digest or None if none found. """
- query = (Manifest
- .select()
- .where(Manifest.repository == self.repository_id)
- .where(Manifest.digest == digest))
+ @classmethod
+ def for_repository(cls, repository_id, storage):
+ return RepositoryContentRetriever(repository_id, storage)
- try:
- return query.get().manifest_bytes
- except Manifest.DoesNotExist:
- return None
+ def get_manifest_bytes_with_digest(self, digest):
+ """ Returns the bytes of the manifest with the given digest or None if none found. """
+ query = (
+ Manifest.select()
+ .where(Manifest.repository == self.repository_id)
+ .where(Manifest.digest == digest)
+ )
- def get_blob_bytes_with_digest(self, digest):
- """ Returns the bytes of the blob with the given digest or None if none found. """
- blob = get_repository_blob_by_digest(self.repository_id, digest)
- if blob is None:
- return None
+ try:
+ return query.get().manifest_bytes
+ except Manifest.DoesNotExist:
+ return None
- assert blob.locations is not None
- return self.storage.get_content(blob.locations, get_layer_path(blob))
+ def get_blob_bytes_with_digest(self, digest):
+ """ Returns the bytes of the blob with the given digest or None if none found. """
+ blob = get_repository_blob_by_digest(self.repository_id, digest)
+ if blob is None:
+ return None
+
+ assert blob.locations is not None
+ return self.storage.get_content(blob.locations, get_layer_path(blob))
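
A rough sketch of how the retriever is typically wired up, assuming a `repo_id`, a configured `storage` engine, and hypothetical `config_digest`/`child_digest` values:

retriever = RepositoryContentRetriever.for_repository(repo_id, storage)

# Both accessors return raw bytes, or None when the content cannot be found;
# callers treat None as "not found" rather than an error.
config_bytes = retriever.get_blob_bytes_with_digest(config_digest)
child_manifest_bytes = retriever.get_manifest_bytes_with_digest(child_digest)
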
diff --git a/data/model/oci/shared.py b/data/model/oci/shared.py
index 887eda383..6bf59c18b 100644
--- a/data/model/oci/shared.py
+++ b/data/model/oci/shared.py
@@ -1,24 +1,27 @@
from data.database import Manifest, ManifestLegacyImage, Image
+
def get_legacy_image_for_manifest(manifest_id):
- """ Returns the legacy image associated with the given manifest, if any, or None if none. """
- try:
- query = (ManifestLegacyImage
- .select(ManifestLegacyImage, Image)
- .join(Image)
- .where(ManifestLegacyImage.manifest == manifest_id))
- return query.get().image
- except ManifestLegacyImage.DoesNotExist:
- return None
+ """ Returns the legacy image associated with the given manifest, if any, or None if none. """
+ try:
+ query = (
+ ManifestLegacyImage.select(ManifestLegacyImage, Image)
+ .join(Image)
+ .where(ManifestLegacyImage.manifest == manifest_id)
+ )
+ return query.get().image
+ except ManifestLegacyImage.DoesNotExist:
+ return None
def get_manifest_for_legacy_image(image_id):
- """ Returns a manifest that is associated with the given image, if any, or None if none. """
- try:
- query = (ManifestLegacyImage
- .select(ManifestLegacyImage, Manifest)
- .join(Manifest)
- .where(ManifestLegacyImage.image == image_id))
- return query.get().manifest
- except ManifestLegacyImage.DoesNotExist:
- return None
+ """ Returns a manifest that is associated with the given image, if any, or None if none. """
+ try:
+ query = (
+ ManifestLegacyImage.select(ManifestLegacyImage, Manifest)
+ .join(Manifest)
+ .where(ManifestLegacyImage.image == image_id)
+ )
+ return query.get().manifest
+ except ManifestLegacyImage.DoesNotExist:
+ return None
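
The two helpers invert the ManifestLegacyImage mapping; a small sketch, assuming `manifest_row` has a linked legacy image:

image = get_legacy_image_for_manifest(manifest_row.id)
if image is not None:
    # Going back through the mapping yields a manifest tied to that image.
    assert get_manifest_for_legacy_image(image.id) is not None
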
diff --git a/data/model/oci/tag.py b/data/model/oci/tag.py
index 4ad1b8c18..45e2c8004 100644
--- a/data/model/oci/tag.py
+++ b/data/model/oci/tag.py
@@ -4,15 +4,31 @@ import logging
from calendar import timegm
from peewee import fn
-from data.database import (Tag, Manifest, ManifestLegacyImage, Image, ImageStorage,
- MediaType, RepositoryTag, TagManifest, TagManifestToManifest,
- get_epoch_timestamp_ms, db_transaction, Repository,
- TagToRepositoryTag, Namespace, RepositoryNotification,
- ExternalNotificationEvent)
+from data.database import (
+ Tag,
+ Manifest,
+ ManifestLegacyImage,
+ Image,
+ ImageStorage,
+ MediaType,
+ RepositoryTag,
+ TagManifest,
+ TagManifestToManifest,
+ get_epoch_timestamp_ms,
+ db_transaction,
+ Repository,
+ TagToRepositoryTag,
+ Namespace,
+ RepositoryNotification,
+ ExternalNotificationEvent,
+)
from data.model.oci.shared import get_legacy_image_for_manifest
from data.model import config
-from image.docker.schema1 import (DOCKER_SCHEMA1_CONTENT_TYPES, DockerSchema1Manifest,
- MalformedSchema1Manifest)
+from image.docker.schema1 import (
+ DOCKER_SCHEMA1_CONTENT_TYPES,
+ DockerSchema1Manifest,
+ MalformedSchema1Manifest,
+)
from util.bytes import Bytes
from util.timedeltastring import convert_to_timedelta
@@ -20,486 +36,558 @@ logger = logging.getLogger(__name__)
def get_tag_by_id(tag_id):
- """ Returns the tag with the given ID, joined with its manifest or None if none. """
- try:
- return Tag.select(Tag, Manifest).join(Manifest).where(Tag.id == tag_id).get()
- except Tag.DoesNotExist:
- return None
+ """ Returns the tag with the given ID, joined with its manifest or None if none. """
+ try:
+ return Tag.select(Tag, Manifest).join(Manifest).where(Tag.id == tag_id).get()
+ except Tag.DoesNotExist:
+ return None
def get_tag(repository_id, tag_name):
- """ Returns the alive, non-hidden tag with the given name under the specified repository or
+ """ Returns the alive, non-hidden tag with the given name under the specified repository or
None if none. The tag is returned joined with its manifest.
"""
- query = (Tag
- .select(Tag, Manifest)
- .join(Manifest)
- .where(Tag.repository == repository_id)
- .where(Tag.name == tag_name))
+ query = (
+ Tag.select(Tag, Manifest)
+ .join(Manifest)
+ .where(Tag.repository == repository_id)
+ .where(Tag.name == tag_name)
+ )
- query = filter_to_alive_tags(query)
+ query = filter_to_alive_tags(query)
- try:
- found = query.get()
- assert not found.hidden
- return found
- except Tag.DoesNotExist:
- return None
+ try:
+ found = query.get()
+ assert not found.hidden
+ return found
+ except Tag.DoesNotExist:
+ return None
def lookup_alive_tags_shallow(repository_id, start_pagination_id=None, limit=None):
- """ Returns a list of the tags alive in the specified repository. Note that the tags returned
+ """ Returns a list of the tags alive in the specified repository. Note that the tags returned
*only* contain their ID and name. Also note that the Tags are returned ordered by ID.
"""
- query = (Tag
- .select(Tag.id, Tag.name)
- .where(Tag.repository == repository_id)
- .order_by(Tag.id))
+ query = (
+ Tag.select(Tag.id, Tag.name)
+ .where(Tag.repository == repository_id)
+ .order_by(Tag.id)
+ )
- if start_pagination_id is not None:
- query = query.where(Tag.id >= start_pagination_id)
+ if start_pagination_id is not None:
+ query = query.where(Tag.id >= start_pagination_id)
- if limit is not None:
- query = query.limit(limit)
+ if limit is not None:
+ query = query.limit(limit)
- return filter_to_alive_tags(query)
+ return filter_to_alive_tags(query)
def list_alive_tags(repository_id):
- """ Returns a list of all the tags alive in the specified repository.
+ """ Returns a list of all the tags alive in the specified repository.
Tags returned are joined with their manifest.
"""
- query = (Tag
- .select(Tag, Manifest)
- .join(Manifest)
- .where(Tag.repository == repository_id))
+ query = (
+ Tag.select(Tag, Manifest).join(Manifest).where(Tag.repository == repository_id)
+ )
- return filter_to_alive_tags(query)
+ return filter_to_alive_tags(query)
-def list_repository_tag_history(repository_id, page, page_size, specific_tag_name=None,
- active_tags_only=False, since_time_ms=None):
- """ Returns a tuple of the full set of tags found in the specified repository, including those
+def list_repository_tag_history(
+ repository_id,
+ page,
+ page_size,
+ specific_tag_name=None,
+ active_tags_only=False,
+ since_time_ms=None,
+):
+ """ Returns a tuple of the full set of tags found in the specified repository, including those
that are no longer alive (unless active_tags_only is True), and whether additional tags exist.
If specific_tag_name is given, the tags are further filtered by name. If since_time_ms is given, tags
are further filtered to those newer than that timestamp.
Note that the returned Manifest will not contain the manifest contents.
"""
- query = (Tag
- .select(Tag, Manifest.id, Manifest.digest, Manifest.media_type)
- .join(Manifest)
- .where(Tag.repository == repository_id)
- .order_by(Tag.lifetime_start_ms.desc(), Tag.name)
- .limit(page_size + 1)
- .offset(page_size * (page - 1)))
+ query = (
+ Tag.select(Tag, Manifest.id, Manifest.digest, Manifest.media_type)
+ .join(Manifest)
+ .where(Tag.repository == repository_id)
+ .order_by(Tag.lifetime_start_ms.desc(), Tag.name)
+ .limit(page_size + 1)
+ .offset(page_size * (page - 1))
+ )
- if specific_tag_name is not None:
- query = query.where(Tag.name == specific_tag_name)
+ if specific_tag_name is not None:
+ query = query.where(Tag.name == specific_tag_name)
- if since_time_ms is not None:
- query = query.where((Tag.lifetime_start_ms > since_time_ms) | (Tag.lifetime_end_ms > since_time_ms))
+ if since_time_ms is not None:
+ query = query.where(
+ (Tag.lifetime_start_ms > since_time_ms)
+ | (Tag.lifetime_end_ms > since_time_ms)
+ )
- if active_tags_only:
- query = filter_to_alive_tags(query)
+ if active_tags_only:
+ query = filter_to_alive_tags(query)
- query = filter_to_visible_tags(query)
- results = list(query)
+ query = filter_to_visible_tags(query)
+ results = list(query)
- return results[0:page_size], len(results) > page_size
+ return results[0:page_size], len(results) > page_size
def get_legacy_images_for_tags(tags):
- """ Returns a map from tag ID to the legacy image for the tag. """
- if not tags:
- return {}
+ """ Returns a map from tag ID to the legacy image for the tag. """
+ if not tags:
+ return {}
- query = (ManifestLegacyImage
- .select(ManifestLegacyImage, Image, ImageStorage)
- .join(Image)
- .join(ImageStorage)
- .where(ManifestLegacyImage.manifest << [tag.manifest_id for tag in tags]))
+ query = (
+ ManifestLegacyImage.select(ManifestLegacyImage, Image, ImageStorage)
+ .join(Image)
+ .join(ImageStorage)
+ .where(ManifestLegacyImage.manifest << [tag.manifest_id for tag in tags])
+ )
- by_manifest = {mli.manifest_id: mli.image for mli in query}
- return {tag.id: by_manifest[tag.manifest_id] for tag in tags if tag.manifest_id in by_manifest}
+ by_manifest = {mli.manifest_id: mli.image for mli in query}
+ return {
+ tag.id: by_manifest[tag.manifest_id]
+ for tag in tags
+ if tag.manifest_id in by_manifest
+ }
def find_matching_tag(repository_id, tag_names, tag_kinds=None):
- """ Finds an alive tag in the specified repository with one of the specified tag names and
+ """ Finds an alive tag in the specified repository with one of the specified tag names and
returns it or None if none. Tags returned are joined with their manifest.
"""
- assert repository_id
- assert tag_names
+ assert repository_id
+ assert tag_names
- query = (Tag
- .select(Tag, Manifest)
- .join(Manifest)
- .where(Tag.repository == repository_id)
- .where(Tag.name << tag_names))
+ query = (
+ Tag.select(Tag, Manifest)
+ .join(Manifest)
+ .where(Tag.repository == repository_id)
+ .where(Tag.name << tag_names)
+ )
- if tag_kinds:
- query = query.where(Tag.tag_kind << tag_kinds)
+ if tag_kinds:
+ query = query.where(Tag.tag_kind << tag_kinds)
- try:
- found = filter_to_alive_tags(query).get()
- assert not found.hidden
- return found
- except Tag.DoesNotExist:
- return None
+ try:
+ found = filter_to_alive_tags(query).get()
+ assert not found.hidden
+ return found
+ except Tag.DoesNotExist:
+ return None
def get_most_recent_tag_lifetime_start(repository_ids):
- """ Returns a map from repo ID to the timestamp of the most recently pushed alive tag
+ """ Returns a map from repo ID to the timestamp of the most recently pushed alive tag
for each specified repository or None if none.
"""
- assert len(repository_ids) > 0 and None not in repository_ids
+ assert len(repository_ids) > 0 and None not in repository_ids
- query = (Tag.select(Tag.repository, fn.Max(Tag.lifetime_start_ms))
- .where(Tag.repository << [repo_id for repo_id in repository_ids])
- .group_by(Tag.repository))
- tuples = filter_to_alive_tags(query).tuples()
+ query = (
+ Tag.select(Tag.repository, fn.Max(Tag.lifetime_start_ms))
+ .where(Tag.repository << [repo_id for repo_id in repository_ids])
+ .group_by(Tag.repository)
+ )
+ tuples = filter_to_alive_tags(query).tuples()
- return {repo_id: timestamp for repo_id, timestamp in tuples}
+ return {repo_id: timestamp for repo_id, timestamp in tuples}
def get_most_recent_tag(repository_id):
- """ Returns the most recently pushed alive tag in the specified repository or None if none.
+ """ Returns the most recently pushed alive tag in the specified repository or None if none.
The Tag returned is joined with its manifest.
"""
- assert repository_id
+ assert repository_id
- query = (Tag
- .select(Tag, Manifest)
- .join(Manifest)
- .where(Tag.repository == repository_id)
- .order_by(Tag.lifetime_start_ms.desc()))
+ query = (
+ Tag.select(Tag, Manifest)
+ .join(Manifest)
+ .where(Tag.repository == repository_id)
+ .order_by(Tag.lifetime_start_ms.desc())
+ )
- try:
- found = filter_to_alive_tags(query).get()
- assert not found.hidden
- return found
- except Tag.DoesNotExist:
- return None
+ try:
+ found = filter_to_alive_tags(query).get()
+ assert not found.hidden
+ return found
+ except Tag.DoesNotExist:
+ return None
def get_expired_tag(repository_id, tag_name):
- """ Returns a tag with the given name that is expired in the repository or None if none.
+ """ Returns a tag with the given name that is expired in the repository or None if none.
"""
- try:
- return (Tag
- .select()
+ try:
+ return (
+ Tag.select()
.where(Tag.name == tag_name, Tag.repository == repository_id)
.where(~(Tag.lifetime_end_ms >> None))
.where(Tag.lifetime_end_ms <= get_epoch_timestamp_ms())
- .get())
- except Tag.DoesNotExist:
- return None
+ .get()
+ )
+ except Tag.DoesNotExist:
+ return None
def create_temporary_tag_if_necessary(manifest, expiration_sec):
- """ Creates a temporary tag pointing to the given manifest, with the given expiration in seconds,
+ """ Creates a temporary tag pointing to the given manifest, with the given expiration in seconds,
unless there is an existing tag that will keep the manifest around.
"""
- tag_name = '$temp-%s' % str(uuid.uuid4())
- now_ms = get_epoch_timestamp_ms()
- end_ms = now_ms + (expiration_sec * 1000)
+ tag_name = "$temp-%s" % str(uuid.uuid4())
+ now_ms = get_epoch_timestamp_ms()
+ end_ms = now_ms + (expiration_sec * 1000)
- # Check if there is an existing tag on the manifest that won't expire within the
- # timeframe. If so, no need for a temporary tag.
- with db_transaction():
- try:
- (Tag
- .select()
- .where(Tag.manifest == manifest,
- (Tag.lifetime_end_ms >> None) | (Tag.lifetime_end_ms >= end_ms))
- .get())
- return None
- except Tag.DoesNotExist:
- pass
+ # Check if there is an existing tag on the manifest that won't expire within the
+ # timeframe. If so, no need for a temporary tag.
+ with db_transaction():
+ try:
+ (
+ Tag.select()
+ .where(
+ Tag.manifest == manifest,
+ (Tag.lifetime_end_ms >> None) | (Tag.lifetime_end_ms >= end_ms),
+ )
+ .get()
+ )
+ return None
+ except Tag.DoesNotExist:
+ pass
- return Tag.create(name=tag_name,
- repository=manifest.repository_id,
- lifetime_start_ms=now_ms,
- lifetime_end_ms=end_ms,
- reversion=False,
- hidden=True,
- manifest=manifest,
- tag_kind=Tag.tag_kind.get_id('tag'))
+ return Tag.create(
+ name=tag_name,
+ repository=manifest.repository_id,
+ lifetime_start_ms=now_ms,
+ lifetime_end_ms=end_ms,
+ reversion=False,
+ hidden=True,
+ manifest=manifest,
+ tag_kind=Tag.tag_kind.get_id("tag"),
+ )
-def retarget_tag(tag_name, manifest_id, is_reversion=False, now_ms=None, adjust_old_model=True):
- """ Creates or updates a tag with the specified name to point to the given manifest under
+def retarget_tag(
+ tag_name, manifest_id, is_reversion=False, now_ms=None, adjust_old_model=True
+):
+ """ Creates or updates a tag with the specified name to point to the given manifest under
its repository. If this action is a reversion to a previous manifest, is_reversion
should be set to True. Returns the newly created tag row or None on error.
"""
- try:
- manifest = (Manifest
- .select(Manifest, MediaType)
- .join(MediaType)
- .where(Manifest.id == manifest_id)
- .get())
- except Manifest.DoesNotExist:
- return None
-
- # CHECK: Make sure that we are not mistargeting a schema 1 manifest to a tag with a different
- # name.
- if manifest.media_type.name in DOCKER_SCHEMA1_CONTENT_TYPES:
try:
- parsed = DockerSchema1Manifest(Bytes.for_string_or_unicode(manifest.manifest_bytes),
- validate=False)
- if parsed.tag != tag_name:
- logger.error('Tried to re-target schema1 manifest with tag `%s` to tag `%s', parsed.tag,
- tag_name)
- return None
- except MalformedSchema1Manifest:
- logger.exception('Could not parse schema1 manifest')
- return None
-
- legacy_image = get_legacy_image_for_manifest(manifest)
- now_ms = now_ms or get_epoch_timestamp_ms()
- now_ts = int(now_ms / 1000)
-
- with db_transaction():
- # Lookup an existing tag in the repository with the same name and, if present, mark it
- # as expired.
- existing_tag = get_tag(manifest.repository_id, tag_name)
- if existing_tag is not None:
- _, okay = set_tag_end_ms(existing_tag, now_ms)
-
- # TODO: should we retry here and/or use a for-update?
- if not okay:
+ manifest = (
+ Manifest.select(Manifest, MediaType)
+ .join(MediaType)
+ .where(Manifest.id == manifest_id)
+ .get()
+ )
+ except Manifest.DoesNotExist:
return None
- # Create a new tag pointing to the manifest with a lifetime start of now.
- created = Tag.create(name=tag_name, repository=manifest.repository_id, lifetime_start_ms=now_ms,
- reversion=is_reversion, manifest=manifest,
- tag_kind=Tag.tag_kind.get_id('tag'))
+ # CHECK: Make sure that we are not mistargeting a schema 1 manifest to a tag with a different
+ # name.
+ if manifest.media_type.name in DOCKER_SCHEMA1_CONTENT_TYPES:
+ try:
+ parsed = DockerSchema1Manifest(
+ Bytes.for_string_or_unicode(manifest.manifest_bytes), validate=False
+ )
+ if parsed.tag != tag_name:
+ logger.error(
+ "Tried to re-target schema1 manifest with tag `%s` to tag `%s`",
+ parsed.tag,
+ tag_name,
+ )
+ return None
+ except MalformedSchema1Manifest:
+ logger.exception("Could not parse schema1 manifest")
+ return None
- # TODO: Remove the linkage code once RepositoryTag is gone.
- # If this is a schema 1 manifest, then add a TagManifest linkage to it. Otherwise, it will only
- # be pullable via the new OCI model.
- if adjust_old_model:
- if manifest.media_type.name in DOCKER_SCHEMA1_CONTENT_TYPES and legacy_image is not None:
- old_style_tag = RepositoryTag.create(repository=manifest.repository_id, image=legacy_image,
- name=tag_name, lifetime_start_ts=now_ts,
- reversion=is_reversion)
- TagToRepositoryTag.create(tag=created, repository_tag=old_style_tag,
- repository=manifest.repository_id)
+ legacy_image = get_legacy_image_for_manifest(manifest)
+ now_ms = now_ms or get_epoch_timestamp_ms()
+ now_ts = int(now_ms / 1000)
- tag_manifest = TagManifest.create(tag=old_style_tag, digest=manifest.digest,
- json_data=manifest.manifest_bytes)
- TagManifestToManifest.create(tag_manifest=tag_manifest, manifest=manifest,
- repository=manifest.repository_id)
+ with db_transaction():
+ # Lookup an existing tag in the repository with the same name and, if present, mark it
+ # as expired.
+ existing_tag = get_tag(manifest.repository_id, tag_name)
+ if existing_tag is not None:
+ _, okay = set_tag_end_ms(existing_tag, now_ms)
- return created
+ # TODO: should we retry here and/or use a for-update?
+ if not okay:
+ return None
+
+ # Create a new tag pointing to the manifest with a lifetime start of now.
+ created = Tag.create(
+ name=tag_name,
+ repository=manifest.repository_id,
+ lifetime_start_ms=now_ms,
+ reversion=is_reversion,
+ manifest=manifest,
+ tag_kind=Tag.tag_kind.get_id("tag"),
+ )
+
+ # TODO: Remove the linkage code once RepositoryTag is gone.
+ # If this is a schema 1 manifest, then add a TagManifest linkage to it. Otherwise, it will only
+ # be pullable via the new OCI model.
+ if adjust_old_model:
+ if (
+ manifest.media_type.name in DOCKER_SCHEMA1_CONTENT_TYPES
+ and legacy_image is not None
+ ):
+ old_style_tag = RepositoryTag.create(
+ repository=manifest.repository_id,
+ image=legacy_image,
+ name=tag_name,
+ lifetime_start_ts=now_ts,
+ reversion=is_reversion,
+ )
+ TagToRepositoryTag.create(
+ tag=created,
+ repository_tag=old_style_tag,
+ repository=manifest.repository_id,
+ )
+
+ tag_manifest = TagManifest.create(
+ tag=old_style_tag,
+ digest=manifest.digest,
+ json_data=manifest.manifest_bytes,
+ )
+ TagManifestToManifest.create(
+ tag_manifest=tag_manifest,
+ manifest=manifest,
+ repository=manifest.repository_id,
+ )
+
+ return created
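
A compressed sketch of the tag lifecycle these helpers implement, assuming `repo_id` owns `manifest_row`:

# Point the tag at the manifest (expiring any previous tag of the same name),
# read it back while alive, then soft-delete it by setting lifetime_end_ms.
tag = retarget_tag("latest", manifest_row.id)
assert tag is not None
assert get_tag(repo_id, "latest") is not None
assert delete_tag(repo_id, "latest") is not None
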
def delete_tag(repository_id, tag_name):
- """ Deletes the alive tag with the given name in the specified repository and returns the deleted
+ """ Deletes the alive tag with the given name in the specified repository and returns the deleted
tag. If the tag did not exist, returns None.
"""
- tag = get_tag(repository_id, tag_name)
- if tag is None:
- return None
+ tag = get_tag(repository_id, tag_name)
+ if tag is None:
+ return None
- return _delete_tag(tag, get_epoch_timestamp_ms())
+ return _delete_tag(tag, get_epoch_timestamp_ms())
def _delete_tag(tag, now_ms):
- """ Deletes the given tag by marking it as expired. """
- now_ts = int(now_ms / 1000)
+ """ Deletes the given tag by marking it as expired. """
+ now_ts = int(now_ms / 1000)
- with db_transaction():
- updated = (Tag
- .update(lifetime_end_ms=now_ms)
- .where(Tag.id == tag.id, Tag.lifetime_end_ms == tag.lifetime_end_ms)
- .execute())
- if updated != 1:
- return None
+ with db_transaction():
+ updated = (
+ Tag.update(lifetime_end_ms=now_ms)
+ .where(Tag.id == tag.id, Tag.lifetime_end_ms == tag.lifetime_end_ms)
+ .execute()
+ )
+ if updated != 1:
+ return None
- # TODO: Remove the linkage code once RepositoryTag is gone.
- try:
- old_style_tag = (TagToRepositoryTag
- .select(TagToRepositoryTag, RepositoryTag)
- .join(RepositoryTag)
- .where(TagToRepositoryTag.tag == tag)
- .get()).repository_tag
+ # TODO: Remove the linkage code once RepositoryTag is gone.
+ try:
+ old_style_tag = (
+ TagToRepositoryTag.select(TagToRepositoryTag, RepositoryTag)
+ .join(RepositoryTag)
+ .where(TagToRepositoryTag.tag == tag)
+ .get()
+ ).repository_tag
- old_style_tag.lifetime_end_ts = now_ts
- old_style_tag.save()
- except TagToRepositoryTag.DoesNotExist:
- pass
+ old_style_tag.lifetime_end_ts = now_ts
+ old_style_tag.save()
+ except TagToRepositoryTag.DoesNotExist:
+ pass
- return tag
+ return tag
def delete_tags_for_manifest(manifest):
- """ Deletes all tags pointing to the given manifest. Returns the list of tags
+ """ Deletes all tags pointing to the given manifest. Returns the list of tags
deleted.
"""
- query = Tag.select().where(Tag.manifest == manifest)
- query = filter_to_alive_tags(query)
- query = filter_to_visible_tags(query)
+ query = Tag.select().where(Tag.manifest == manifest)
+ query = filter_to_alive_tags(query)
+ query = filter_to_visible_tags(query)
- tags = list(query)
- now_ms = get_epoch_timestamp_ms()
+ tags = list(query)
+ now_ms = get_epoch_timestamp_ms()
- with db_transaction():
- for tag in tags:
- _delete_tag(tag, now_ms)
+ with db_transaction():
+ for tag in tags:
+ _delete_tag(tag, now_ms)
- return tags
+ return tags
def filter_to_visible_tags(query):
- """ Adjusts the specified Tag query to only return those tags that are visible.
+ """ Adjusts the specified Tag query to only return those tags that are visible.
"""
- return query.where(Tag.hidden == False)
+ return query.where(Tag.hidden == False)
def filter_to_alive_tags(query, now_ms=None, model=Tag):
- """ Adjusts the specified Tag query to only return those tags alive. If now_ms is specified,
+ """ Adjusts the specified Tag query to only return those tags alive. If now_ms is specified,
      the given timestamp (in MS) is used in place of the current timestamp for determining whether
a tag is alive.
"""
- if now_ms is None:
- now_ms = get_epoch_timestamp_ms()
+ if now_ms is None:
+ now_ms = get_epoch_timestamp_ms()
- return (query.where((model.lifetime_end_ms >> None) | (model.lifetime_end_ms > now_ms))
- .where(model.hidden == False))
+ return query.where(
+ (model.lifetime_end_ms >> None) | (model.lifetime_end_ms > now_ms)
+ ).where(model.hidden == False)
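# Note on the peewee expression above: `>> None` compiles to an IS NULL check, so a
# tag counts as alive when it has no end timestamp or its end timestamp lies in the
# future. A plain-Python mirror of the same predicate (illustrative only;
# _is_tag_alive is a hypothetical helper, not part of this module):
def _is_tag_alive(tag, now_ms):
    if tag.hidden:
        return False
    return tag.lifetime_end_ms is None or tag.lifetime_end_ms > now_ms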
def set_tag_expiration_sec_for_manifest(manifest_id, expiration_seconds):
- """ Sets the tag expiration for any tags that point to the given manifest ID. """
- query = Tag.select().where(Tag.manifest == manifest_id)
- query = filter_to_alive_tags(query)
- tags = list(query)
- for tag in tags:
- assert not tag.hidden
- set_tag_end_ms(tag, tag.lifetime_start_ms + (expiration_seconds * 1000))
+ """ Sets the tag expiration for any tags that point to the given manifest ID. """
+ query = Tag.select().where(Tag.manifest == manifest_id)
+ query = filter_to_alive_tags(query)
+ tags = list(query)
+ for tag in tags:
+ assert not tag.hidden
+ set_tag_end_ms(tag, tag.lifetime_start_ms + (expiration_seconds * 1000))
- return tags
+ return tags
def set_tag_expiration_for_manifest(manifest_id, expiration_datetime):
- """ Sets the tag expiration for any tags that point to the given manifest ID. """
- query = Tag.select().where(Tag.manifest == manifest_id)
- query = filter_to_alive_tags(query)
- tags = list(query)
- for tag in tags:
- assert not tag.hidden
- change_tag_expiration(tag, expiration_datetime)
+ """ Sets the tag expiration for any tags that point to the given manifest ID. """
+ query = Tag.select().where(Tag.manifest == manifest_id)
+ query = filter_to_alive_tags(query)
+ tags = list(query)
+ for tag in tags:
+ assert not tag.hidden
+ change_tag_expiration(tag, expiration_datetime)
- return tags
+ return tags
def change_tag_expiration(tag_id, expiration_datetime):
- """ Changes the expiration of the specified tag to the given expiration datetime. If
+ """ Changes the expiration of the specified tag to the given expiration datetime. If
the expiration datetime is None, then the tag is marked as not expiring. Returns
a tuple of the previous expiration timestamp in seconds (if any), and whether the
operation succeeded.
"""
- try:
- tag = Tag.get(id=tag_id)
- except Tag.DoesNotExist:
- return (None, False)
+ try:
+ tag = Tag.get(id=tag_id)
+ except Tag.DoesNotExist:
+ return (None, False)
- new_end_ms = None
- min_expire_sec = convert_to_timedelta(config.app_config.get('LABELED_EXPIRATION_MINIMUM', '1h'))
- max_expire_sec = convert_to_timedelta(config.app_config.get('LABELED_EXPIRATION_MAXIMUM', '104w'))
+ new_end_ms = None
+ min_expire_sec = convert_to_timedelta(
+ config.app_config.get("LABELED_EXPIRATION_MINIMUM", "1h")
+ )
+ max_expire_sec = convert_to_timedelta(
+ config.app_config.get("LABELED_EXPIRATION_MAXIMUM", "104w")
+ )
- if expiration_datetime is not None:
- lifetime_start_ts = int(tag.lifetime_start_ms / 1000)
+ if expiration_datetime is not None:
+ lifetime_start_ts = int(tag.lifetime_start_ms / 1000)
- offset = timegm(expiration_datetime.utctimetuple()) - lifetime_start_ts
- offset = min(max(offset, min_expire_sec.total_seconds()), max_expire_sec.total_seconds())
- new_end_ms = tag.lifetime_start_ms + (offset * 1000)
+ offset = timegm(expiration_datetime.utctimetuple()) - lifetime_start_ts
+ offset = min(
+ max(offset, min_expire_sec.total_seconds()), max_expire_sec.total_seconds()
+ )
+ new_end_ms = tag.lifetime_start_ms + (offset * 1000)
- if new_end_ms == tag.lifetime_end_ms:
- return (None, True)
+ if new_end_ms == tag.lifetime_end_ms:
+ return (None, True)
- return set_tag_end_ms(tag, new_end_ms)
+ return set_tag_end_ms(tag, new_end_ms)
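# The clamping above bounds the requested expiration offset to the configured window
# (the fallbacks are 1h and 104w). A standalone sketch of the same arithmetic, with
# the defaults expanded to seconds (_clamped_end_ms is a hypothetical helper, not
# part of this module):
def _clamped_end_ms(lifetime_start_ms, expiration_ts, min_sec=3600, max_sec=104 * 7 * 24 * 3600):
    offset = expiration_ts - int(lifetime_start_ms / 1000)
    offset = min(max(offset, min_sec), max_sec)
    return lifetime_start_ms + offset * 1000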
def lookup_unrecoverable_tags(repo):
- """ Returns the tags in a repository that are expired and past their time machine recovery
+ """ Returns the tags in a repository that are expired and past their time machine recovery
period. """
- expired_clause = get_epoch_timestamp_ms() - (Namespace.removed_tag_expiration_s * 1000)
- return (Tag
- .select()
- .join(Repository)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .where(Tag.repository == repo)
- .where(~(Tag.lifetime_end_ms >> None), Tag.lifetime_end_ms <= expired_clause))
+ expired_clause = get_epoch_timestamp_ms() - (
+ Namespace.removed_tag_expiration_s * 1000
+ )
+ return (
+ Tag.select()
+ .join(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .where(Tag.repository == repo)
+ .where(~(Tag.lifetime_end_ms >> None), Tag.lifetime_end_ms <= expired_clause)
+ )
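# Put differently: a tag is unrecoverable once its end timestamp is older than "now
# minus the namespace's time-machine window". The equivalent plain-Python check
# (illustrative only; _is_unrecoverable is a hypothetical helper, not part of this
# module):
def _is_unrecoverable(tag, removed_tag_expiration_s, now_ms):
    return (
        tag.lifetime_end_ms is not None
        and tag.lifetime_end_ms <= now_ms - removed_tag_expiration_s * 1000
    )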
def set_tag_end_ms(tag, end_ms):
- """ Sets the end timestamp for a tag. Should only be called by change_tag_expiration
+ """ Sets the end timestamp for a tag. Should only be called by change_tag_expiration
or tests.
"""
- with db_transaction():
- updated = (Tag
- .update(lifetime_end_ms=end_ms)
- .where(Tag.id == tag)
- .where(Tag.lifetime_end_ms == tag.lifetime_end_ms)
- .execute())
- if updated != 1:
- return (None, False)
+ with db_transaction():
+ updated = (
+ Tag.update(lifetime_end_ms=end_ms)
+ .where(Tag.id == tag)
+ .where(Tag.lifetime_end_ms == tag.lifetime_end_ms)
+ .execute()
+ )
+ if updated != 1:
+ return (None, False)
- # TODO: Remove the linkage code once RepositoryTag is gone.
- try:
- old_style_tag = (TagToRepositoryTag
- .select(TagToRepositoryTag, RepositoryTag)
- .join(RepositoryTag)
- .where(TagToRepositoryTag.tag == tag)
- .get()).repository_tag
+ # TODO: Remove the linkage code once RepositoryTag is gone.
+ try:
+ old_style_tag = (
+ TagToRepositoryTag.select(TagToRepositoryTag, RepositoryTag)
+ .join(RepositoryTag)
+ .where(TagToRepositoryTag.tag == tag)
+ .get()
+ ).repository_tag
- old_style_tag.lifetime_end_ts = end_ms / 1000 if end_ms is not None else None
- old_style_tag.save()
- except TagToRepositoryTag.DoesNotExist:
- pass
+ old_style_tag.lifetime_end_ts = (
+ end_ms / 1000 if end_ms is not None else None
+ )
+ old_style_tag.save()
+ except TagToRepositoryTag.DoesNotExist:
+ pass
- return (tag.lifetime_end_ms, True)
+ return (tag.lifetime_end_ms, True)
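# The extra WHERE clause on the previous lifetime_end_ms turns the UPDATE above into
# an optimistic compare-and-swap: it succeeds only if no other writer changed the row
# since it was read. A caller that wants to retry therefore has to re-read the tag
# between attempts, roughly as sketched here (_expire_tag_with_retry and tag_lookup
# are hypothetical, not part of this module):
def _expire_tag_with_retry(tag_lookup, new_end_ms, attempts=3):
    for _ in range(attempts):
        tag = tag_lookup()  # re-read so the compare-and-swap sees a fresh value
        _, okay = set_tag_end_ms(tag, new_end_ms)
        if okay:
            return tag
    return None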
def tags_containing_legacy_image(image):
- """ Yields all alive Tags containing the given image as a legacy image, somewhere in its
+ """ Yields all alive Tags containing the given image as a legacy image, somewhere in its
legacy image hierarchy.
"""
- ancestors_str = '%s%s/%%' % (image.ancestors, image.id)
- tags = (Tag
- .select()
- .join(Repository)
- .switch(Tag)
- .join(Manifest)
- .join(ManifestLegacyImage)
- .join(Image)
- .where(Tag.repository == image.repository_id)
- .where(Image.repository == image.repository_id)
- .where((Image.id == image.id) |
- (Image.ancestors ** ancestors_str)))
- return filter_to_alive_tags(tags)
+ ancestors_str = "%s%s/%%" % (image.ancestors, image.id)
+ tags = (
+ Tag.select()
+ .join(Repository)
+ .switch(Tag)
+ .join(Manifest)
+ .join(ManifestLegacyImage)
+ .join(Image)
+ .where(Tag.repository == image.repository_id)
+ .where(Image.repository == image.repository_id)
+ .where((Image.id == image.id) | (Image.ancestors ** ancestors_str))
+ )
+ return filter_to_alive_tags(tags)
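# For context on the pattern above: Image.ancestors is a '/'-delimited id path such
# as '/1/2/3/', and peewee's `**` operator is a case-insensitive LIKE, so the query
# matches the image itself plus every descendant, i.e. any image whose ancestry path
# passes through it. For example (values illustrative):
#
#     "%s%s/%%" % ("/1/2/3/", 4)  ==  "/1/2/3/4/%"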
def lookup_notifiable_tags_for_legacy_image(docker_image_id, storage_uuid, event_name):
- """ Yields any alive Tags found in repositories with an event with the given name registered
+ """ Yields any alive Tags found in repositories with an event with the given name registered
and whose legacy Image has the given docker image ID and storage UUID.
"""
- event = ExternalNotificationEvent.get(name=event_name)
- images = (Image
- .select()
- .join(ImageStorage)
- .where(Image.docker_image_id == docker_image_id,
- ImageStorage.uuid == storage_uuid))
+ event = ExternalNotificationEvent.get(name=event_name)
+ images = (
+ Image.select()
+ .join(ImageStorage)
+ .where(
+ Image.docker_image_id == docker_image_id, ImageStorage.uuid == storage_uuid
+ )
+ )
- for image in list(images):
- # Ensure the image is under a repository that supports the event.
- try:
- RepositoryNotification.get(repository=image.repository_id, event=event)
- except RepositoryNotification.DoesNotExist:
- continue
+ for image in list(images):
+ # Ensure the image is under a repository that supports the event.
+ try:
+ RepositoryNotification.get(repository=image.repository_id, event=event)
+ except RepositoryNotification.DoesNotExist:
+ continue
- # If found in a repository with the valid event, yield the tag(s) that contains the image.
- for tag in tags_containing_legacy_image(image):
- yield tag
+        # If found in a repository with the event registered, yield the tag(s) that contain the image.
+ for tag in tags_containing_legacy_image(image):
+ yield tag
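# Consumption sketch for the generator above; the event name and notify() handler
# are illustrative placeholders, not part of this module:
#
#     for tag in lookup_notifiable_tags_for_legacy_image(
#         docker_image_id, storage_uuid, "vulnerability_found"
#     ):
#         notify(tag)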
diff --git a/data/model/oci/test/test_oci_label.py b/data/model/oci/test/test_oci_label.py
index 2ba04521b..b849303bc 100644
--- a/data/model/oci/test/test_oci_label.py
+++ b/data/model/oci/test/test_oci_label.py
@@ -3,85 +3,108 @@ import pytest
from playhouse.test_utils import assert_query_count
from data.database import Manifest, ManifestLabel
-from data.model.oci.label import (create_manifest_label, list_manifest_labels, get_manifest_label,
- delete_manifest_label, DataModelException)
+from data.model.oci.label import (
+ create_manifest_label,
+ list_manifest_labels,
+ get_manifest_label,
+ delete_manifest_label,
+ DataModelException,
+)
from test.fixtures import *
-@pytest.mark.parametrize('key, value, source_type, expected_error', [
- ('foo', 'bar', 'manifest', None),
-
- pytest.param('..foo', 'bar', 'manifest', None, id='invalid key on manifest'),
- pytest.param('..foo', 'bar', 'api', 'is invalid', id='invalid key on api'),
-])
+@pytest.mark.parametrize(
+ "key, value, source_type, expected_error",
+ [
+ ("foo", "bar", "manifest", None),
+ pytest.param("..foo", "bar", "manifest", None, id="invalid key on manifest"),
+ pytest.param("..foo", "bar", "api", "is invalid", id="invalid key on api"),
+ ],
+)
def test_create_manifest_label(key, value, source_type, expected_error, initialized_db):
- manifest = Manifest.get()
+ manifest = Manifest.get()
- if expected_error:
- with pytest.raises(DataModelException) as ex:
- create_manifest_label(manifest, key, value, source_type)
+ if expected_error:
+ with pytest.raises(DataModelException) as ex:
+ create_manifest_label(manifest, key, value, source_type)
- assert ex.match(expected_error)
- return
+ assert ex.match(expected_error)
+ return
- label = create_manifest_label(manifest, key, value, source_type)
- labels = [ml.label_id for ml in ManifestLabel.select().where(ManifestLabel.manifest == manifest)]
- assert label.id in labels
+ label = create_manifest_label(manifest, key, value, source_type)
+ labels = [
+ ml.label_id
+ for ml in ManifestLabel.select().where(ManifestLabel.manifest == manifest)
+ ]
+ assert label.id in labels
- with assert_query_count(1):
- assert label in list_manifest_labels(manifest)
+ with assert_query_count(1):
+ assert label in list_manifest_labels(manifest)
- assert label not in list_manifest_labels(manifest, 'someprefix')
- assert label in list_manifest_labels(manifest, key[0:2])
+ assert label not in list_manifest_labels(manifest, "someprefix")
+ assert label in list_manifest_labels(manifest, key[0:2])
- with assert_query_count(1):
- assert get_manifest_label(label.uuid, manifest) == label
+ with assert_query_count(1):
+ assert get_manifest_label(label.uuid, manifest) == label
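# assert_query_count (from playhouse.test_utils), used above, fails the test if the
# wrapped block issues a different number of SQL queries than expected, which is how
# these tests guard against accidental N+1 regressions in the lookup helpers.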
def test_list_manifest_labels(initialized_db):
- manifest = Manifest.get()
+ manifest = Manifest.get()
- label1 = create_manifest_label(manifest, 'foo', '1', 'manifest')
- label2 = create_manifest_label(manifest, 'bar', '2', 'api')
- label3 = create_manifest_label(manifest, 'baz', '3', 'internal')
+ label1 = create_manifest_label(manifest, "foo", "1", "manifest")
+ label2 = create_manifest_label(manifest, "bar", "2", "api")
+ label3 = create_manifest_label(manifest, "baz", "3", "internal")
- assert label1 in list_manifest_labels(manifest)
- assert label2 in list_manifest_labels(manifest)
- assert label3 in list_manifest_labels(manifest)
+ assert label1 in list_manifest_labels(manifest)
+ assert label2 in list_manifest_labels(manifest)
+ assert label3 in list_manifest_labels(manifest)
- other_manifest = Manifest.select().where(Manifest.id != manifest.id).get()
- assert label1 not in list_manifest_labels(other_manifest)
- assert label2 not in list_manifest_labels(other_manifest)
- assert label3 not in list_manifest_labels(other_manifest)
+ other_manifest = Manifest.select().where(Manifest.id != manifest.id).get()
+ assert label1 not in list_manifest_labels(other_manifest)
+ assert label2 not in list_manifest_labels(other_manifest)
+ assert label3 not in list_manifest_labels(other_manifest)
def test_get_manifest_label(initialized_db):
- found = False
- for manifest_label in ManifestLabel.select():
- assert (get_manifest_label(manifest_label.label.uuid, manifest_label.manifest) ==
- manifest_label.label)
- assert manifest_label.label in list_manifest_labels(manifest_label.manifest)
- found = True
+ found = False
+ for manifest_label in ManifestLabel.select():
+ assert (
+ get_manifest_label(manifest_label.label.uuid, manifest_label.manifest)
+ == manifest_label.label
+ )
+ assert manifest_label.label in list_manifest_labels(manifest_label.manifest)
+ found = True
- assert found
+ assert found
def test_delete_manifest_label(initialized_db):
- found = False
- for manifest_label in list(ManifestLabel.select()):
- assert (get_manifest_label(manifest_label.label.uuid, manifest_label.manifest) ==
- manifest_label.label)
- assert manifest_label.label in list_manifest_labels(manifest_label.manifest)
+ found = False
+ for manifest_label in list(ManifestLabel.select()):
+ assert (
+ get_manifest_label(manifest_label.label.uuid, manifest_label.manifest)
+ == manifest_label.label
+ )
+ assert manifest_label.label in list_manifest_labels(manifest_label.manifest)
- if manifest_label.label.source_type.mutable:
- assert delete_manifest_label(manifest_label.label.uuid, manifest_label.manifest)
- assert manifest_label.label not in list_manifest_labels(manifest_label.manifest)
- assert get_manifest_label(manifest_label.label.uuid, manifest_label.manifest) is None
- else:
- with pytest.raises(DataModelException):
- delete_manifest_label(manifest_label.label.uuid, manifest_label.manifest)
+ if manifest_label.label.source_type.mutable:
+ assert delete_manifest_label(
+ manifest_label.label.uuid, manifest_label.manifest
+ )
+ assert manifest_label.label not in list_manifest_labels(
+ manifest_label.manifest
+ )
+ assert (
+ get_manifest_label(manifest_label.label.uuid, manifest_label.manifest)
+ is None
+ )
+ else:
+ with pytest.raises(DataModelException):
+ delete_manifest_label(
+ manifest_label.label.uuid, manifest_label.manifest
+ )
- found = True
+ found = True
- assert found
+ assert found
diff --git a/data/model/oci/test/test_oci_manifest.py b/data/model/oci/test/test_oci_manifest.py
index 4c5d6ed3b..5dc4278f3 100644
--- a/data/model/oci/test/test_oci_manifest.py
+++ b/data/model/oci/test/test_oci_manifest.py
@@ -5,8 +5,16 @@ from playhouse.test_utils import assert_query_count
from app import docker_v2_signing_key, storage
from digest.digest_tools import sha256_digest
-from data.database import (Tag, ManifestBlob, ImageStorageLocation, ManifestChild,
- ImageStorage, Image, RepositoryTag, get_epoch_timestamp_ms)
+from data.database import (
+ Tag,
+ ManifestBlob,
+ ImageStorageLocation,
+ ManifestChild,
+ ImageStorage,
+ Image,
+ RepositoryTag,
+ get_epoch_timestamp_ms,
+)
from data.model.oci.manifest import lookup_manifest, get_or_create_manifest
from data.model.oci.tag import filter_to_alive_tags, get_tag
from data.model.oci.shared import get_legacy_image_for_manifest
@@ -23,538 +31,542 @@ from util.bytes import Bytes
from test.fixtures import *
+
def test_lookup_manifest(initialized_db):
- found = False
- for tag in filter_to_alive_tags(Tag.select()):
- found = True
- repo = tag.repository
- digest = tag.manifest.digest
- with assert_query_count(1):
- assert lookup_manifest(repo, digest) == tag.manifest
+ found = False
+ for tag in filter_to_alive_tags(Tag.select()):
+ found = True
+ repo = tag.repository
+ digest = tag.manifest.digest
+ with assert_query_count(1):
+ assert lookup_manifest(repo, digest) == tag.manifest
- assert found
+ assert found
- for tag in Tag.select():
- repo = tag.repository
- digest = tag.manifest.digest
- with assert_query_count(1):
- assert lookup_manifest(repo, digest, allow_dead=True) == tag.manifest
+ for tag in Tag.select():
+ repo = tag.repository
+ digest = tag.manifest.digest
+ with assert_query_count(1):
+ assert lookup_manifest(repo, digest, allow_dead=True) == tag.manifest
def test_lookup_manifest_dead_tag(initialized_db):
- dead_tag = Tag.select().where(Tag.lifetime_end_ms <= get_epoch_timestamp_ms()).get()
- assert dead_tag.lifetime_end_ms <= get_epoch_timestamp_ms()
+ dead_tag = Tag.select().where(Tag.lifetime_end_ms <= get_epoch_timestamp_ms()).get()
+ assert dead_tag.lifetime_end_ms <= get_epoch_timestamp_ms()
- assert lookup_manifest(dead_tag.repository, dead_tag.manifest.digest) is None
- assert (lookup_manifest(dead_tag.repository, dead_tag.manifest.digest, allow_dead=True) ==
- dead_tag.manifest)
+ assert lookup_manifest(dead_tag.repository, dead_tag.manifest.digest) is None
+ assert (
+ lookup_manifest(dead_tag.repository, dead_tag.manifest.digest, allow_dead=True)
+ == dead_tag.manifest
+ )
-def create_manifest_for_testing(repository, differentiation_field='1'):
- # Populate a manifest.
- layer_json = json.dumps({
- 'config': {},
- "rootfs": {
- "type": "layers",
- "diff_ids": []
- },
- "history": [],
- })
+def create_manifest_for_testing(repository, differentiation_field="1"):
+ # Populate a manifest.
+ layer_json = json.dumps(
+ {"config": {}, "rootfs": {"type": "layers", "diff_ids": []}, "history": []}
+ )
- # Add a blob containing the config.
- _, config_digest = _populate_blob(layer_json)
+ # Add a blob containing the config.
+ _, config_digest = _populate_blob(layer_json)
- remote_digest = sha256_digest('something')
- builder = DockerSchema2ManifestBuilder()
- builder.set_config_digest(config_digest, len(layer_json))
- builder.add_layer(remote_digest, 1234, urls=['http://hello/world' + differentiation_field])
- manifest = builder.build()
+ remote_digest = sha256_digest("something")
+ builder = DockerSchema2ManifestBuilder()
+ builder.set_config_digest(config_digest, len(layer_json))
+ builder.add_layer(
+ remote_digest, 1234, urls=["http://hello/world" + differentiation_field]
+ )
+ manifest = builder.build()
- created = get_or_create_manifest(repository, manifest, storage)
- assert created
- return created.manifest, manifest
+ created = get_or_create_manifest(repository, manifest, storage)
+ assert created
+ return created.manifest, manifest
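# Note: differentiation_field is appended to the remote layer URL purely so that
# successive calls produce manifests with distinct digests.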
def test_lookup_manifest_child_tag(initialized_db):
- repository = create_repository('devtable', 'newrepo', None)
- manifest, manifest_impl = create_manifest_for_testing(repository)
+ repository = create_repository("devtable", "newrepo", None)
+ manifest, manifest_impl = create_manifest_for_testing(repository)
- # Mark the hidden tag as dead.
- hidden_tag = Tag.get(manifest=manifest, hidden=True)
- hidden_tag.lifetime_end_ms = hidden_tag.lifetime_start_ms
- hidden_tag.save()
+ # Mark the hidden tag as dead.
+ hidden_tag = Tag.get(manifest=manifest, hidden=True)
+ hidden_tag.lifetime_end_ms = hidden_tag.lifetime_start_ms
+ hidden_tag.save()
- # Ensure the manifest cannot currently be looked up, as it is not pointed to by an alive tag.
- assert lookup_manifest(repository, manifest.digest) is None
- assert lookup_manifest(repository, manifest.digest, allow_dead=True) is not None
+ # Ensure the manifest cannot currently be looked up, as it is not pointed to by an alive tag.
+ assert lookup_manifest(repository, manifest.digest) is None
+ assert lookup_manifest(repository, manifest.digest, allow_dead=True) is not None
- # Populate a manifest list.
- list_builder = DockerSchema2ManifestListBuilder()
- list_builder.add_manifest(manifest_impl, 'amd64', 'linux')
- manifest_list = list_builder.build()
+ # Populate a manifest list.
+ list_builder = DockerSchema2ManifestListBuilder()
+ list_builder.add_manifest(manifest_impl, "amd64", "linux")
+ manifest_list = list_builder.build()
- # Write the manifest list, which should also write the manifests themselves.
- created_tuple = get_or_create_manifest(repository, manifest_list, storage)
- assert created_tuple is not None
+ # Write the manifest list, which should also write the manifests themselves.
+ created_tuple = get_or_create_manifest(repository, manifest_list, storage)
+ assert created_tuple is not None
- # Since the manifests are not yet referenced by a tag, they cannot be found.
- assert lookup_manifest(repository, manifest.digest) is None
- assert lookup_manifest(repository, manifest_list.digest) is None
+ # Since the manifests are not yet referenced by a tag, they cannot be found.
+ assert lookup_manifest(repository, manifest.digest) is None
+ assert lookup_manifest(repository, manifest_list.digest) is None
- # Unless we ask for "dead" manifests.
- assert lookup_manifest(repository, manifest.digest, allow_dead=True) is not None
- assert lookup_manifest(repository, manifest_list.digest, allow_dead=True) is not None
+ # Unless we ask for "dead" manifests.
+ assert lookup_manifest(repository, manifest.digest, allow_dead=True) is not None
+ assert (
+ lookup_manifest(repository, manifest_list.digest, allow_dead=True) is not None
+ )
def _populate_blob(content):
- digest = str(sha256_digest(content))
- location = ImageStorageLocation.get(name='local_us')
- blob = store_blob_record_and_temp_link('devtable', 'newrepo', digest, location,
- len(content), 120)
- storage.put_content(['local_us'], get_layer_path(blob), content)
- return blob, digest
+ digest = str(sha256_digest(content))
+ location = ImageStorageLocation.get(name="local_us")
+ blob = store_blob_record_and_temp_link(
+ "devtable", "newrepo", digest, location, len(content), 120
+ )
+ storage.put_content(["local_us"], get_layer_path(blob), content)
+ return blob, digest
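# The trailing 120 above is read here as the expiration, in seconds, of the temporary
# link that keeps the uploaded blob alive until a manifest references it (hedged
# reading of store_blob_record_and_temp_link's positional arguments).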
-@pytest.mark.parametrize('schema_version', [
- 1,
- 2,
-])
+@pytest.mark.parametrize("schema_version", [1, 2])
def test_get_or_create_manifest(schema_version, initialized_db):
- repository = create_repository('devtable', 'newrepo', None)
+ repository = create_repository("devtable", "newrepo", None)
- expected_labels = {
- 'Foo': 'Bar',
- 'Baz': 'Meh',
- }
+ expected_labels = {"Foo": "Bar", "Baz": "Meh"}
- layer_json = json.dumps({
- 'id': 'somelegacyid',
- 'config': {
- 'Labels': expected_labels,
- },
- "rootfs": {
- "type": "layers",
- "diff_ids": []
- },
- "history": [
- {
- "created": "2018-04-03T18:37:09.284840891Z",
- "created_by": "do something",
- },
- ],
- })
+ layer_json = json.dumps(
+ {
+ "id": "somelegacyid",
+ "config": {"Labels": expected_labels},
+ "rootfs": {"type": "layers", "diff_ids": []},
+ "history": [
+ {
+ "created": "2018-04-03T18:37:09.284840891Z",
+ "created_by": "do something",
+ }
+ ],
+ }
+ )
- # Create a legacy image.
- find_create_or_link_image('somelegacyid', repository, 'devtable', {}, 'local_us')
+ # Create a legacy image.
+ find_create_or_link_image("somelegacyid", repository, "devtable", {}, "local_us")
- # Add a blob containing the config.
- _, config_digest = _populate_blob(layer_json)
+ # Add a blob containing the config.
+ _, config_digest = _populate_blob(layer_json)
- # Add a blob of random data.
- random_data = 'hello world'
- _, random_digest = _populate_blob(random_data)
+ # Add a blob of random data.
+ random_data = "hello world"
+ _, random_digest = _populate_blob(random_data)
- # Build the manifest.
- if schema_version == 1:
- builder = DockerSchema1ManifestBuilder('devtable', 'simple', 'anothertag')
- builder.add_layer(random_digest, layer_json)
- sample_manifest_instance = builder.build(docker_v2_signing_key)
- elif schema_version == 2:
- builder = DockerSchema2ManifestBuilder()
- builder.set_config_digest(config_digest, len(layer_json))
- builder.add_layer(random_digest, len(random_data))
- sample_manifest_instance = builder.build()
+ # Build the manifest.
+ if schema_version == 1:
+ builder = DockerSchema1ManifestBuilder("devtable", "simple", "anothertag")
+ builder.add_layer(random_digest, layer_json)
+ sample_manifest_instance = builder.build(docker_v2_signing_key)
+ elif schema_version == 2:
+ builder = DockerSchema2ManifestBuilder()
+ builder.set_config_digest(config_digest, len(layer_json))
+ builder.add_layer(random_digest, len(random_data))
+ sample_manifest_instance = builder.build()
- # Create a new manifest.
- created_manifest = get_or_create_manifest(repository, sample_manifest_instance, storage)
- created = created_manifest.manifest
- newly_created = created_manifest.newly_created
+ # Create a new manifest.
+ created_manifest = get_or_create_manifest(
+ repository, sample_manifest_instance, storage
+ )
+ created = created_manifest.manifest
+ newly_created = created_manifest.newly_created
- assert newly_created
- assert created is not None
- assert created.media_type.name == sample_manifest_instance.media_type
- assert created.digest == sample_manifest_instance.digest
- assert created.manifest_bytes == sample_manifest_instance.bytes.as_encoded_str()
- assert created_manifest.labels_to_apply == expected_labels
+ assert newly_created
+ assert created is not None
+ assert created.media_type.name == sample_manifest_instance.media_type
+ assert created.digest == sample_manifest_instance.digest
+ assert created.manifest_bytes == sample_manifest_instance.bytes.as_encoded_str()
+ assert created_manifest.labels_to_apply == expected_labels
- # Verify it has a temporary tag pointing to it.
- assert Tag.get(manifest=created, hidden=True).lifetime_end_ms
+ # Verify it has a temporary tag pointing to it.
+ assert Tag.get(manifest=created, hidden=True).lifetime_end_ms
- # Verify the legacy image.
- legacy_image = get_legacy_image_for_manifest(created)
- assert legacy_image is not None
- assert legacy_image.storage.content_checksum == random_digest
+ # Verify the legacy image.
+ legacy_image = get_legacy_image_for_manifest(created)
+ assert legacy_image is not None
+ assert legacy_image.storage.content_checksum == random_digest
- # Verify the linked blobs.
- blob_digests = [mb.blob.content_checksum for mb
- in ManifestBlob.select().where(ManifestBlob.manifest == created)]
+ # Verify the linked blobs.
+ blob_digests = [
+ mb.blob.content_checksum
+ for mb in ManifestBlob.select().where(ManifestBlob.manifest == created)
+ ]
- assert random_digest in blob_digests
- if schema_version == 2:
- assert config_digest in blob_digests
+ assert random_digest in blob_digests
+ if schema_version == 2:
+ assert config_digest in blob_digests
- # Retrieve it again and ensure it is the same manifest.
- created_manifest2 = get_or_create_manifest(repository, sample_manifest_instance, storage)
- created2 = created_manifest2.manifest
- newly_created2 = created_manifest2.newly_created
+ # Retrieve it again and ensure it is the same manifest.
+ created_manifest2 = get_or_create_manifest(
+ repository, sample_manifest_instance, storage
+ )
+ created2 = created_manifest2.manifest
+ newly_created2 = created_manifest2.newly_created
- assert not newly_created2
- assert created2 == created
+ assert not newly_created2
+ assert created2 == created
- # Ensure it again has a temporary tag.
- assert Tag.get(manifest=created2, hidden=True).lifetime_end_ms
+ # Ensure it again has a temporary tag.
+ assert Tag.get(manifest=created2, hidden=True).lifetime_end_ms
- # Ensure the labels were added.
- labels = list(list_manifest_labels(created))
- assert len(labels) == 2
+ # Ensure the labels were added.
+ labels = list(list_manifest_labels(created))
+ assert len(labels) == 2
- labels_dict = {label.key: label.value for label in labels}
- assert labels_dict == expected_labels
+ labels_dict = {label.key: label.value for label in labels}
+ assert labels_dict == expected_labels
def test_get_or_create_manifest_invalid_image(initialized_db):
- repository = get_repository('devtable', 'simple')
+ repository = get_repository("devtable", "simple")
- latest_tag = get_tag(repository, 'latest')
- parsed = DockerSchema1Manifest(Bytes.for_string_or_unicode(latest_tag.manifest.manifest_bytes),
- validate=False)
+ latest_tag = get_tag(repository, "latest")
+ parsed = DockerSchema1Manifest(
+ Bytes.for_string_or_unicode(latest_tag.manifest.manifest_bytes), validate=False
+ )
- builder = DockerSchema1ManifestBuilder('devtable', 'simple', 'anothertag')
- builder.add_layer(parsed.blob_digests[0], '{"id": "foo", "parent": "someinvalidimageid"}')
- sample_manifest_instance = builder.build(docker_v2_signing_key)
+ builder = DockerSchema1ManifestBuilder("devtable", "simple", "anothertag")
+ builder.add_layer(
+ parsed.blob_digests[0], '{"id": "foo", "parent": "someinvalidimageid"}'
+ )
+ sample_manifest_instance = builder.build(docker_v2_signing_key)
- created_manifest = get_or_create_manifest(repository, sample_manifest_instance, storage)
- assert created_manifest is None
+ created_manifest = get_or_create_manifest(
+ repository, sample_manifest_instance, storage
+ )
+ assert created_manifest is None
def test_get_or_create_manifest_list(initialized_db):
- repository = create_repository('devtable', 'newrepo', None)
+ repository = create_repository("devtable", "newrepo", None)
- expected_labels = {
- 'Foo': 'Bar',
- 'Baz': 'Meh',
- }
+ expected_labels = {"Foo": "Bar", "Baz": "Meh"}
- layer_json = json.dumps({
- 'id': 'somelegacyid',
- 'config': {
- 'Labels': expected_labels,
- },
- "rootfs": {
- "type": "layers",
- "diff_ids": []
- },
- "history": [
- {
- "created": "2018-04-03T18:37:09.284840891Z",
- "created_by": "do something",
- },
- ],
- })
+ layer_json = json.dumps(
+ {
+ "id": "somelegacyid",
+ "config": {"Labels": expected_labels},
+ "rootfs": {"type": "layers", "diff_ids": []},
+ "history": [
+ {
+ "created": "2018-04-03T18:37:09.284840891Z",
+ "created_by": "do something",
+ }
+ ],
+ }
+ )
- # Create a legacy image.
- find_create_or_link_image('somelegacyid', repository, 'devtable', {}, 'local_us')
+ # Create a legacy image.
+ find_create_or_link_image("somelegacyid", repository, "devtable", {}, "local_us")
- # Add a blob containing the config.
- _, config_digest = _populate_blob(layer_json)
+ # Add a blob containing the config.
+ _, config_digest = _populate_blob(layer_json)
- # Add a blob of random data.
- random_data = 'hello world'
- _, random_digest = _populate_blob(random_data)
+ # Add a blob of random data.
+ random_data = "hello world"
+ _, random_digest = _populate_blob(random_data)
- # Build the manifests.
- v1_builder = DockerSchema1ManifestBuilder('devtable', 'simple', 'anothertag')
- v1_builder.add_layer(random_digest, layer_json)
- v1_manifest = v1_builder.build(docker_v2_signing_key).unsigned()
+ # Build the manifests.
+ v1_builder = DockerSchema1ManifestBuilder("devtable", "simple", "anothertag")
+ v1_builder.add_layer(random_digest, layer_json)
+ v1_manifest = v1_builder.build(docker_v2_signing_key).unsigned()
- v2_builder = DockerSchema2ManifestBuilder()
- v2_builder.set_config_digest(config_digest, len(layer_json))
- v2_builder.add_layer(random_digest, len(random_data))
- v2_manifest = v2_builder.build()
+ v2_builder = DockerSchema2ManifestBuilder()
+ v2_builder.set_config_digest(config_digest, len(layer_json))
+ v2_builder.add_layer(random_digest, len(random_data))
+ v2_manifest = v2_builder.build()
- # Write the manifests.
- v1_created = get_or_create_manifest(repository, v1_manifest, storage)
- assert v1_created
- assert v1_created.manifest.digest == v1_manifest.digest
+ # Write the manifests.
+ v1_created = get_or_create_manifest(repository, v1_manifest, storage)
+ assert v1_created
+ assert v1_created.manifest.digest == v1_manifest.digest
- v2_created = get_or_create_manifest(repository, v2_manifest, storage)
- assert v2_created
- assert v2_created.manifest.digest == v2_manifest.digest
+ v2_created = get_or_create_manifest(repository, v2_manifest, storage)
+ assert v2_created
+ assert v2_created.manifest.digest == v2_manifest.digest
- # Build the manifest list.
- list_builder = DockerSchema2ManifestListBuilder()
- list_builder.add_manifest(v1_manifest, 'amd64', 'linux')
- list_builder.add_manifest(v2_manifest, 'amd32', 'linux')
- manifest_list = list_builder.build()
+ # Build the manifest list.
+ list_builder = DockerSchema2ManifestListBuilder()
+ list_builder.add_manifest(v1_manifest, "amd64", "linux")
+ list_builder.add_manifest(v2_manifest, "amd32", "linux")
+ manifest_list = list_builder.build()
- # Write the manifest list, which should also write the manifests themselves.
- created_tuple = get_or_create_manifest(repository, manifest_list, storage)
- assert created_tuple is not None
+ # Write the manifest list, which should also write the manifests themselves.
+ created_tuple = get_or_create_manifest(repository, manifest_list, storage)
+ assert created_tuple is not None
- created_list = created_tuple.manifest
- assert created_list
- assert created_list.media_type.name == manifest_list.media_type
- assert created_list.digest == manifest_list.digest
+ created_list = created_tuple.manifest
+ assert created_list
+ assert created_list.media_type.name == manifest_list.media_type
+ assert created_list.digest == manifest_list.digest
- # Ensure the child manifest links exist.
- child_manifests = {cm.child_manifest.digest: cm.child_manifest
- for cm in ManifestChild.select().where(ManifestChild.manifest == created_list)}
- assert len(child_manifests) == 2
- assert v1_manifest.digest in child_manifests
- assert v2_manifest.digest in child_manifests
+ # Ensure the child manifest links exist.
+ child_manifests = {
+ cm.child_manifest.digest: cm.child_manifest
+ for cm in ManifestChild.select().where(ManifestChild.manifest == created_list)
+ }
+ assert len(child_manifests) == 2
+ assert v1_manifest.digest in child_manifests
+ assert v2_manifest.digest in child_manifests
- assert child_manifests[v1_manifest.digest].media_type.name == v1_manifest.media_type
- assert child_manifests[v2_manifest.digest].media_type.name == v2_manifest.media_type
+ assert child_manifests[v1_manifest.digest].media_type.name == v1_manifest.media_type
+ assert child_manifests[v2_manifest.digest].media_type.name == v2_manifest.media_type
def test_get_or_create_manifest_list_duplicate_child_manifest(initialized_db):
- repository = create_repository('devtable', 'newrepo', None)
+ repository = create_repository("devtable", "newrepo", None)
- expected_labels = {
- 'Foo': 'Bar',
- 'Baz': 'Meh',
- }
+ expected_labels = {"Foo": "Bar", "Baz": "Meh"}
- layer_json = json.dumps({
- 'id': 'somelegacyid',
- 'config': {
- 'Labels': expected_labels,
- },
- "rootfs": {
- "type": "layers",
- "diff_ids": []
- },
- "history": [
- {
- "created": "2018-04-03T18:37:09.284840891Z",
- "created_by": "do something",
- },
- ],
- })
+ layer_json = json.dumps(
+ {
+ "id": "somelegacyid",
+ "config": {"Labels": expected_labels},
+ "rootfs": {"type": "layers", "diff_ids": []},
+ "history": [
+ {
+ "created": "2018-04-03T18:37:09.284840891Z",
+ "created_by": "do something",
+ }
+ ],
+ }
+ )
- # Create a legacy image.
- find_create_or_link_image('somelegacyid', repository, 'devtable', {}, 'local_us')
+ # Create a legacy image.
+ find_create_or_link_image("somelegacyid", repository, "devtable", {}, "local_us")
- # Add a blob containing the config.
- _, config_digest = _populate_blob(layer_json)
+ # Add a blob containing the config.
+ _, config_digest = _populate_blob(layer_json)
- # Add a blob of random data.
- random_data = 'hello world'
- _, random_digest = _populate_blob(random_data)
+ # Add a blob of random data.
+ random_data = "hello world"
+ _, random_digest = _populate_blob(random_data)
- # Build the manifest.
- v2_builder = DockerSchema2ManifestBuilder()
- v2_builder.set_config_digest(config_digest, len(layer_json))
- v2_builder.add_layer(random_digest, len(random_data))
- v2_manifest = v2_builder.build()
+ # Build the manifest.
+ v2_builder = DockerSchema2ManifestBuilder()
+ v2_builder.set_config_digest(config_digest, len(layer_json))
+ v2_builder.add_layer(random_digest, len(random_data))
+ v2_manifest = v2_builder.build()
- # Write the manifest.
- v2_created = get_or_create_manifest(repository, v2_manifest, storage)
- assert v2_created
- assert v2_created.manifest.digest == v2_manifest.digest
+ # Write the manifest.
+ v2_created = get_or_create_manifest(repository, v2_manifest, storage)
+ assert v2_created
+ assert v2_created.manifest.digest == v2_manifest.digest
- # Build the manifest list, with the child manifest repeated.
- list_builder = DockerSchema2ManifestListBuilder()
- list_builder.add_manifest(v2_manifest, 'amd64', 'linux')
- list_builder.add_manifest(v2_manifest, 'amd32', 'linux')
- manifest_list = list_builder.build()
+ # Build the manifest list, with the child manifest repeated.
+ list_builder = DockerSchema2ManifestListBuilder()
+ list_builder.add_manifest(v2_manifest, "amd64", "linux")
+ list_builder.add_manifest(v2_manifest, "amd32", "linux")
+ manifest_list = list_builder.build()
- # Write the manifest list, which should also write the manifests themselves.
- created_tuple = get_or_create_manifest(repository, manifest_list, storage)
- assert created_tuple is not None
+ # Write the manifest list, which should also write the manifests themselves.
+ created_tuple = get_or_create_manifest(repository, manifest_list, storage)
+ assert created_tuple is not None
- created_list = created_tuple.manifest
- assert created_list
- assert created_list.media_type.name == manifest_list.media_type
- assert created_list.digest == manifest_list.digest
+ created_list = created_tuple.manifest
+ assert created_list
+ assert created_list.media_type.name == manifest_list.media_type
+ assert created_list.digest == manifest_list.digest
- # Ensure the child manifest links exist.
- child_manifests = {cm.child_manifest.digest: cm.child_manifest
- for cm in ManifestChild.select().where(ManifestChild.manifest == created_list)}
- assert len(child_manifests) == 1
- assert v2_manifest.digest in child_manifests
- assert child_manifests[v2_manifest.digest].media_type.name == v2_manifest.media_type
+ # Ensure the child manifest links exist.
+ child_manifests = {
+ cm.child_manifest.digest: cm.child_manifest
+ for cm in ManifestChild.select().where(ManifestChild.manifest == created_list)
+ }
+ assert len(child_manifests) == 1
+ assert v2_manifest.digest in child_manifests
+ assert child_manifests[v2_manifest.digest].media_type.name == v2_manifest.media_type
- # Try to create again and ensure we get back the same manifest list.
- created2_tuple = get_or_create_manifest(repository, manifest_list, storage)
- assert created2_tuple is not None
- assert created2_tuple.manifest == created_list
+ # Try to create again and ensure we get back the same manifest list.
+ created2_tuple = get_or_create_manifest(repository, manifest_list, storage)
+ assert created2_tuple is not None
+ assert created2_tuple.manifest == created_list
def test_get_or_create_manifest_with_remote_layers(initialized_db):
- repository = create_repository('devtable', 'newrepo', None)
+ repository = create_repository("devtable", "newrepo", None)
- layer_json = json.dumps({
- 'config': {},
- "rootfs": {
- "type": "layers",
- "diff_ids": []
- },
- "history": [
- {
- "created": "2018-04-03T18:37:09.284840891Z",
- "created_by": "do something",
- },
- {
- "created": "2018-04-03T18:37:09.284840891Z",
- "created_by": "do something",
- },
- ],
- })
+ layer_json = json.dumps(
+ {
+ "config": {},
+ "rootfs": {"type": "layers", "diff_ids": []},
+ "history": [
+ {
+ "created": "2018-04-03T18:37:09.284840891Z",
+ "created_by": "do something",
+ },
+ {
+ "created": "2018-04-03T18:37:09.284840891Z",
+ "created_by": "do something",
+ },
+ ],
+ }
+ )
- # Add a blob containing the config.
- _, config_digest = _populate_blob(layer_json)
+ # Add a blob containing the config.
+ _, config_digest = _populate_blob(layer_json)
- # Add a blob of random data.
- random_data = 'hello world'
- _, random_digest = _populate_blob(random_data)
+ # Add a blob of random data.
+ random_data = "hello world"
+ _, random_digest = _populate_blob(random_data)
- remote_digest = sha256_digest('something')
+ remote_digest = sha256_digest("something")
- builder = DockerSchema2ManifestBuilder()
- builder.set_config_digest(config_digest, len(layer_json))
- builder.add_layer(remote_digest, 1234, urls=['http://hello/world'])
- builder.add_layer(random_digest, len(random_data))
- manifest = builder.build()
+ builder = DockerSchema2ManifestBuilder()
+ builder.set_config_digest(config_digest, len(layer_json))
+ builder.add_layer(remote_digest, 1234, urls=["http://hello/world"])
+ builder.add_layer(random_digest, len(random_data))
+ manifest = builder.build()
- assert remote_digest in manifest.blob_digests
- assert remote_digest not in manifest.local_blob_digests
+ assert remote_digest in manifest.blob_digests
+ assert remote_digest not in manifest.local_blob_digests
- assert manifest.has_remote_layer
- assert not manifest.has_legacy_image
- assert manifest.get_schema1_manifest('foo', 'bar', 'baz', None) is None
+ assert manifest.has_remote_layer
+ assert not manifest.has_legacy_image
+ assert manifest.get_schema1_manifest("foo", "bar", "baz", None) is None
- # Write the manifest.
- created_tuple = get_or_create_manifest(repository, manifest, storage)
- assert created_tuple is not None
+ # Write the manifest.
+ created_tuple = get_or_create_manifest(repository, manifest, storage)
+ assert created_tuple is not None
- created_manifest = created_tuple.manifest
- assert created_manifest
- assert created_manifest.media_type.name == manifest.media_type
- assert created_manifest.digest == manifest.digest
+ created_manifest = created_tuple.manifest
+ assert created_manifest
+ assert created_manifest.media_type.name == manifest.media_type
+ assert created_manifest.digest == manifest.digest
- # Verify the legacy image.
- legacy_image = get_legacy_image_for_manifest(created_manifest)
- assert legacy_image is None
+ # Verify the legacy image.
+ legacy_image = get_legacy_image_for_manifest(created_manifest)
+ assert legacy_image is None
- # Verify the linked blobs.
- blob_digests = {mb.blob.content_checksum for mb
- in ManifestBlob.select().where(ManifestBlob.manifest == created_manifest)}
+ # Verify the linked blobs.
+ blob_digests = {
+ mb.blob.content_checksum
+ for mb in ManifestBlob.select().where(ManifestBlob.manifest == created_manifest)
+ }
- assert random_digest in blob_digests
- assert config_digest in blob_digests
- assert remote_digest not in blob_digests
+ assert random_digest in blob_digests
+ assert config_digest in blob_digests
+ assert remote_digest not in blob_digests
-def create_manifest_for_testing(repository, differentiation_field='1', include_shared_blob=False):
- # Populate a manifest.
- layer_json = json.dumps({
- 'config': {},
- "rootfs": {
- "type": "layers",
- "diff_ids": []
- },
- "history": [],
- })
+def create_manifest_for_testing(
+ repository, differentiation_field="1", include_shared_blob=False
+):
+ # Populate a manifest.
+ layer_json = json.dumps(
+ {"config": {}, "rootfs": {"type": "layers", "diff_ids": []}, "history": []}
+ )
- # Add a blob containing the config.
- _, config_digest = _populate_blob(layer_json)
+ # Add a blob containing the config.
+ _, config_digest = _populate_blob(layer_json)
- remote_digest = sha256_digest('something')
- builder = DockerSchema2ManifestBuilder()
- builder.set_config_digest(config_digest, len(layer_json))
- builder.add_layer(remote_digest, 1234, urls=['http://hello/world' + differentiation_field])
+ remote_digest = sha256_digest("something")
+ builder = DockerSchema2ManifestBuilder()
+ builder.set_config_digest(config_digest, len(layer_json))
+ builder.add_layer(
+ remote_digest, 1234, urls=["http://hello/world" + differentiation_field]
+ )
- if include_shared_blob:
- _, blob_digest = _populate_blob('some data here')
- builder.add_layer(blob_digest, 4567)
+ if include_shared_blob:
+ _, blob_digest = _populate_blob("some data here")
+ builder.add_layer(blob_digest, 4567)
- manifest = builder.build()
+ manifest = builder.build()
- created = get_or_create_manifest(repository, manifest, storage)
- assert created
- return created.manifest, manifest
+ created = get_or_create_manifest(repository, manifest, storage)
+ assert created
+ return created.manifest, manifest
def test_retriever(initialized_db):
- repository = create_repository('devtable', 'newrepo', None)
+ repository = create_repository("devtable", "newrepo", None)
- layer_json = json.dumps({
- 'config': {},
- "rootfs": {
- "type": "layers",
- "diff_ids": []
- },
- "history": [
- {
- "created": "2018-04-03T18:37:09.284840891Z",
- "created_by": "do something",
- },
- {
- "created": "2018-04-03T18:37:09.284840891Z",
- "created_by": "do something",
- },
- ],
- })
+ layer_json = json.dumps(
+ {
+ "config": {},
+ "rootfs": {"type": "layers", "diff_ids": []},
+ "history": [
+ {
+ "created": "2018-04-03T18:37:09.284840891Z",
+ "created_by": "do something",
+ },
+ {
+ "created": "2018-04-03T18:37:09.284840891Z",
+ "created_by": "do something",
+ },
+ ],
+ }
+ )
- # Add a blob containing the config.
- _, config_digest = _populate_blob(layer_json)
+ # Add a blob containing the config.
+ _, config_digest = _populate_blob(layer_json)
- # Add a blob of random data.
- random_data = 'hello world'
- _, random_digest = _populate_blob(random_data)
+ # Add a blob of random data.
+ random_data = "hello world"
+ _, random_digest = _populate_blob(random_data)
- # Add another blob of random data.
- other_random_data = 'hi place'
- _, other_random_digest = _populate_blob(other_random_data)
+ # Add another blob of random data.
+ other_random_data = "hi place"
+ _, other_random_digest = _populate_blob(other_random_data)
- remote_digest = sha256_digest('something')
+ remote_digest = sha256_digest("something")
- builder = DockerSchema2ManifestBuilder()
- builder.set_config_digest(config_digest, len(layer_json))
- builder.add_layer(other_random_digest, len(other_random_data))
- builder.add_layer(random_digest, len(random_data))
- manifest = builder.build()
+ builder = DockerSchema2ManifestBuilder()
+ builder.set_config_digest(config_digest, len(layer_json))
+ builder.add_layer(other_random_digest, len(other_random_data))
+ builder.add_layer(random_digest, len(random_data))
+ manifest = builder.build()
- assert config_digest in manifest.blob_digests
- assert random_digest in manifest.blob_digests
- assert other_random_digest in manifest.blob_digests
+ assert config_digest in manifest.blob_digests
+ assert random_digest in manifest.blob_digests
+ assert other_random_digest in manifest.blob_digests
- assert config_digest in manifest.local_blob_digests
- assert random_digest in manifest.local_blob_digests
- assert other_random_digest in manifest.local_blob_digests
+ assert config_digest in manifest.local_blob_digests
+ assert random_digest in manifest.local_blob_digests
+ assert other_random_digest in manifest.local_blob_digests
- # Write the manifest.
- created_tuple = get_or_create_manifest(repository, manifest, storage)
- assert created_tuple is not None
+ # Write the manifest.
+ created_tuple = get_or_create_manifest(repository, manifest, storage)
+ assert created_tuple is not None
- created_manifest = created_tuple.manifest
- assert created_manifest
- assert created_manifest.media_type.name == manifest.media_type
- assert created_manifest.digest == manifest.digest
+ created_manifest = created_tuple.manifest
+ assert created_manifest
+ assert created_manifest.media_type.name == manifest.media_type
+ assert created_manifest.digest == manifest.digest
- # Verify the linked blobs.
- blob_digests = {mb.blob.content_checksum for mb
- in ManifestBlob.select().where(ManifestBlob.manifest == created_manifest)}
+ # Verify the linked blobs.
+ blob_digests = {
+ mb.blob.content_checksum
+ for mb in ManifestBlob.select().where(ManifestBlob.manifest == created_manifest)
+ }
- assert random_digest in blob_digests
- assert other_random_digest in blob_digests
- assert config_digest in blob_digests
+ assert random_digest in blob_digests
+ assert other_random_digest in blob_digests
+ assert config_digest in blob_digests
- # Delete any Image rows linking to the blobs from temp tags.
- for blob_digest in blob_digests:
- storage_row = ImageStorage.get(content_checksum=blob_digest)
- for image in list(Image.select().where(Image.storage == storage_row)):
- all_temp = all([rt.hidden for rt
- in RepositoryTag.select().where(RepositoryTag.image == image)])
- if all_temp:
- RepositoryTag.delete().where(RepositoryTag.image == image).execute()
- image.delete_instance(recursive=True)
+ # Delete any Image rows linking to the blobs from temp tags.
+ for blob_digest in blob_digests:
+ storage_row = ImageStorage.get(content_checksum=blob_digest)
+ for image in list(Image.select().where(Image.storage == storage_row)):
+ all_temp = all(
+ [
+ rt.hidden
+ for rt in RepositoryTag.select().where(RepositoryTag.image == image)
+ ]
+ )
+ if all_temp:
+ RepositoryTag.delete().where(RepositoryTag.image == image).execute()
+ image.delete_instance(recursive=True)
- # Verify the blobs in the retriever.
- retriever = RepositoryContentRetriever(repository, storage)
- assert (retriever.get_manifest_bytes_with_digest(created_manifest.digest) ==
- manifest.bytes.as_encoded_str())
+ # Verify the blobs in the retriever.
+ retriever = RepositoryContentRetriever(repository, storage)
+ assert (
+ retriever.get_manifest_bytes_with_digest(created_manifest.digest)
+ == manifest.bytes.as_encoded_str()
+ )
- for blob_digest in blob_digests:
- assert retriever.get_blob_bytes_with_digest(blob_digest) is not None
+ for blob_digest in blob_digests:
+ assert retriever.get_blob_bytes_with_digest(blob_digest) is not None
diff --git a/data/model/oci/test/test_oci_tag.py b/data/model/oci/test/test_oci_tag.py
index d37828cf7..b61d44c04 100644
--- a/data/model/oci/test/test_oci_tag.py
+++ b/data/model/oci/test/test_oci_tag.py
@@ -3,376 +3,417 @@ from datetime import timedelta, datetime
from playhouse.test_utils import assert_query_count
-from data.database import (Tag, ManifestLegacyImage, TagToRepositoryTag, TagManifestToManifest,
- TagManifest, Manifest, Repository)
+from data.database import (
+ Tag,
+ ManifestLegacyImage,
+ TagToRepositoryTag,
+ TagManifestToManifest,
+ TagManifest,
+ Manifest,
+ Repository,
+)
from data.model.oci.test.test_oci_manifest import create_manifest_for_testing
-from data.model.oci.tag import (find_matching_tag, get_most_recent_tag,
- get_most_recent_tag_lifetime_start, list_alive_tags,
- get_legacy_images_for_tags, filter_to_alive_tags,
- filter_to_visible_tags, list_repository_tag_history,
- get_expired_tag, get_tag, delete_tag,
- delete_tags_for_manifest, change_tag_expiration,
- set_tag_expiration_for_manifest, retarget_tag,
- create_temporary_tag_if_necessary,
- lookup_alive_tags_shallow,
- lookup_unrecoverable_tags,
- get_epoch_timestamp_ms)
+from data.model.oci.tag import (
+ find_matching_tag,
+ get_most_recent_tag,
+ get_most_recent_tag_lifetime_start,
+ list_alive_tags,
+ get_legacy_images_for_tags,
+ filter_to_alive_tags,
+ filter_to_visible_tags,
+ list_repository_tag_history,
+ get_expired_tag,
+ get_tag,
+ delete_tag,
+ delete_tags_for_manifest,
+ change_tag_expiration,
+ set_tag_expiration_for_manifest,
+ retarget_tag,
+ create_temporary_tag_if_necessary,
+ lookup_alive_tags_shallow,
+ lookup_unrecoverable_tags,
+ get_epoch_timestamp_ms,
+)
from data.model.repository import get_repository, create_repository
from test.fixtures import *
-@pytest.mark.parametrize('namespace_name, repo_name, tag_names, expected', [
- ('devtable', 'simple', ['latest'], 'latest'),
- ('devtable', 'simple', ['unknown', 'latest'], 'latest'),
- ('devtable', 'simple', ['unknown'], None),
-])
-def test_find_matching_tag(namespace_name, repo_name, tag_names, expected, initialized_db):
- repo = get_repository(namespace_name, repo_name)
- if expected is not None:
- with assert_query_count(1):
- found = find_matching_tag(repo, tag_names)
- assert found is not None
- assert found.name == expected
- assert not found.lifetime_end_ms
- else:
- with assert_query_count(1):
- assert find_matching_tag(repo, tag_names) is None
+@pytest.mark.parametrize(
+ "namespace_name, repo_name, tag_names, expected",
+ [
+ ("devtable", "simple", ["latest"], "latest"),
+ ("devtable", "simple", ["unknown", "latest"], "latest"),
+ ("devtable", "simple", ["unknown"], None),
+ ],
+)
+def test_find_matching_tag(
+ namespace_name, repo_name, tag_names, expected, initialized_db
+):
+ repo = get_repository(namespace_name, repo_name)
+ if expected is not None:
+ with assert_query_count(1):
+ found = find_matching_tag(repo, tag_names)
+
+ assert found is not None
+ assert found.name == expected
+ assert not found.lifetime_end_ms
+ else:
+ with assert_query_count(1):
+ assert find_matching_tag(repo, tag_names) is None
def test_get_most_recent_tag_lifetime_start(initialized_db):
- repo = get_repository('devtable', 'simple')
- tag = get_most_recent_tag(repo)
+ repo = get_repository("devtable", "simple")
+ tag = get_most_recent_tag(repo)
- with assert_query_count(1):
- tags = get_most_recent_tag_lifetime_start([repo])
- assert tags[repo.id] == tag.lifetime_start_ms
+ with assert_query_count(1):
+ tags = get_most_recent_tag_lifetime_start([repo])
+ assert tags[repo.id] == tag.lifetime_start_ms
def test_get_most_recent_tag(initialized_db):
- repo = get_repository('outsideorg', 'coolrepo')
+ repo = get_repository("outsideorg", "coolrepo")
- with assert_query_count(1):
- assert get_most_recent_tag(repo).name == 'latest'
+ with assert_query_count(1):
+ assert get_most_recent_tag(repo).name == "latest"
def test_get_most_recent_tag_empty_repo(initialized_db):
- empty_repo = create_repository('devtable', 'empty', None)
+ empty_repo = create_repository("devtable", "empty", None)
- with assert_query_count(1):
- assert get_most_recent_tag(empty_repo) is None
+ with assert_query_count(1):
+ assert get_most_recent_tag(empty_repo) is None
def test_list_alive_tags(initialized_db):
- found = False
- for tag in filter_to_visible_tags(filter_to_alive_tags(Tag.select())):
+ found = False
+ for tag in filter_to_visible_tags(filter_to_alive_tags(Tag.select())):
+ tags = list_alive_tags(tag.repository)
+ assert tag in tags
+
+ with assert_query_count(1):
+ legacy_images = get_legacy_images_for_tags(tags)
+
+ for tag in tags:
+ assert (
+ ManifestLegacyImage.get(manifest=tag.manifest).image
+ == legacy_images[tag.id]
+ )
+
+ found = True
+
+ assert found
+
+ # Ensure hidden tags cannot be listed.
+ tag = Tag.get()
+ tag.hidden = True
+ tag.save()
+
tags = list_alive_tags(tag.repository)
- assert tag in tags
-
- with assert_query_count(1):
- legacy_images = get_legacy_images_for_tags(tags)
-
- for tag in tags:
- assert ManifestLegacyImage.get(manifest=tag.manifest).image == legacy_images[tag.id]
-
- found = True
-
- assert found
-
- # Ensure hidden tags cannot be listed.
- tag = Tag.get()
- tag.hidden = True
- tag.save()
-
- tags = list_alive_tags(tag.repository)
- assert tag not in tags
+ assert tag not in tags
def test_lookup_alive_tags_shallow(initialized_db):
- found = False
- for tag in filter_to_visible_tags(filter_to_alive_tags(Tag.select())):
+ found = False
+ for tag in filter_to_visible_tags(filter_to_alive_tags(Tag.select())):
+ tags = lookup_alive_tags_shallow(tag.repository)
+ found = True
+ assert tag in tags
+
+ assert found
+
+ # Ensure hidden tags cannot be listed.
+ tag = Tag.get()
+ tag.hidden = True
+ tag.save()
+
tags = lookup_alive_tags_shallow(tag.repository)
- found = True
- assert tag in tags
-
- assert found
-
- # Ensure hidden tags cannot be listed.
- tag = Tag.get()
- tag.hidden = True
- tag.save()
-
- tags = lookup_alive_tags_shallow(tag.repository)
- assert tag not in tags
+ assert tag not in tags
def test_get_tag(initialized_db):
- found = False
- for tag in filter_to_visible_tags(filter_to_alive_tags(Tag.select())):
- repo = tag.repository
+ found = False
+ for tag in filter_to_visible_tags(filter_to_alive_tags(Tag.select())):
+ repo = tag.repository
+
+ with assert_query_count(1):
+ assert get_tag(repo, tag.name) == tag
+ found = True
+
+ assert found
+
+
+@pytest.mark.parametrize(
+ "namespace_name, repo_name", [("devtable", "simple"), ("devtable", "complex")]
+)
+def test_list_repository_tag_history(namespace_name, repo_name, initialized_db):
+ repo = get_repository(namespace_name, repo_name)
with assert_query_count(1):
- assert get_tag(repo, tag.name) == tag
- found = True
+ results, has_more = list_repository_tag_history(repo, 1, 100)
- assert found
-
-
-@pytest.mark.parametrize('namespace_name, repo_name', [
- ('devtable', 'simple'),
- ('devtable', 'complex'),
-])
-def test_list_repository_tag_history(namespace_name, repo_name, initialized_db):
- repo = get_repository(namespace_name, repo_name)
-
- with assert_query_count(1):
- results, has_more = list_repository_tag_history(repo, 1, 100)
-
- assert results
- assert not has_more
+ assert results
+ assert not has_more
def test_list_repository_tag_history_with_history(initialized_db):
- repo = get_repository('devtable', 'history')
+ repo = get_repository("devtable", "history")
- with assert_query_count(1):
- results, _ = list_repository_tag_history(repo, 1, 100)
+ with assert_query_count(1):
+ results, _ = list_repository_tag_history(repo, 1, 100)
- assert len(results) == 2
- assert results[0].lifetime_end_ms is None
- assert results[1].lifetime_end_ms is not None
+ assert len(results) == 2
+ assert results[0].lifetime_end_ms is None
+ assert results[1].lifetime_end_ms is not None
- with assert_query_count(1):
- results, _ = list_repository_tag_history(repo, 1, 100, specific_tag_name='latest')
+ with assert_query_count(1):
+ results, _ = list_repository_tag_history(
+ repo, 1, 100, specific_tag_name="latest"
+ )
- assert len(results) == 2
- assert results[0].lifetime_end_ms is None
- assert results[1].lifetime_end_ms is not None
+ assert len(results) == 2
+ assert results[0].lifetime_end_ms is None
+ assert results[1].lifetime_end_ms is not None
- with assert_query_count(1):
- results, _ = list_repository_tag_history(repo, 1, 100, specific_tag_name='foobar')
+ with assert_query_count(1):
+ results, _ = list_repository_tag_history(
+ repo, 1, 100, specific_tag_name="foobar"
+ )
- assert len(results) == 0
+ assert len(results) == 0
def test_list_repository_tag_history_all_tags(initialized_db):
- for tag in Tag.select():
- repo = tag.repository
- with assert_query_count(1):
- results, _ = list_repository_tag_history(repo, 1, 1000)
+ for tag in Tag.select():
+ repo = tag.repository
+ with assert_query_count(1):
+ results, _ = list_repository_tag_history(repo, 1, 1000)
- assert (tag in results) == (not tag.hidden)
+ assert (tag in results) == (not tag.hidden)
-@pytest.mark.parametrize('namespace_name, repo_name, tag_name, expected', [
- ('devtable', 'simple', 'latest', False),
- ('devtable', 'simple', 'unknown', False),
- ('devtable', 'complex', 'latest', False),
-
- ('devtable', 'history', 'latest', True),
-])
+@pytest.mark.parametrize(
+ "namespace_name, repo_name, tag_name, expected",
+ [
+ ("devtable", "simple", "latest", False),
+ ("devtable", "simple", "unknown", False),
+ ("devtable", "complex", "latest", False),
+ ("devtable", "history", "latest", True),
+ ],
+)
def test_get_expired_tag(namespace_name, repo_name, tag_name, expected, initialized_db):
- repo = get_repository(namespace_name, repo_name)
+ repo = get_repository(namespace_name, repo_name)
- with assert_query_count(1):
- assert bool(get_expired_tag(repo, tag_name)) == expected
+ with assert_query_count(1):
+ assert bool(get_expired_tag(repo, tag_name)) == expected
def test_delete_tag(initialized_db):
- found = False
- for tag in list(filter_to_visible_tags(filter_to_alive_tags(Tag.select()))):
- repo = tag.repository
+ found = False
+ for tag in list(filter_to_visible_tags(filter_to_alive_tags(Tag.select()))):
+ repo = tag.repository
- assert get_tag(repo, tag.name) == tag
- assert tag.lifetime_end_ms is None
+ assert get_tag(repo, tag.name) == tag
+ assert tag.lifetime_end_ms is None
- with assert_query_count(4):
- assert delete_tag(repo, tag.name) == tag
+ with assert_query_count(4):
+ assert delete_tag(repo, tag.name) == tag
- assert get_tag(repo, tag.name) is None
- found = True
+ assert get_tag(repo, tag.name) is None
+ found = True
- assert found
+ assert found
def test_delete_tags_for_manifest(initialized_db):
- for tag in list(filter_to_visible_tags(filter_to_alive_tags(Tag.select()))):
- repo = tag.repository
- assert get_tag(repo, tag.name) == tag
+ for tag in list(filter_to_visible_tags(filter_to_alive_tags(Tag.select()))):
+ repo = tag.repository
+ assert get_tag(repo, tag.name) == tag
- with assert_query_count(5):
- assert delete_tags_for_manifest(tag.manifest) == [tag]
+ with assert_query_count(5):
+ assert delete_tags_for_manifest(tag.manifest) == [tag]
- assert get_tag(repo, tag.name) is None
+ assert get_tag(repo, tag.name) is None
def test_delete_tags_for_manifest_same_manifest(initialized_db):
- new_repo = model.repository.create_repository('devtable', 'newrepo', None)
- manifest_1, _ = create_manifest_for_testing(new_repo, '1')
- manifest_2, _ = create_manifest_for_testing(new_repo, '2')
+ new_repo = model.repository.create_repository("devtable", "newrepo", None)
+ manifest_1, _ = create_manifest_for_testing(new_repo, "1")
+ manifest_2, _ = create_manifest_for_testing(new_repo, "2")
- assert manifest_1.digest != manifest_2.digest
+ assert manifest_1.digest != manifest_2.digest
- # Add some tag history, moving a tag back and forth between two manifests.
- retarget_tag('latest', manifest_1)
- retarget_tag('latest', manifest_2)
- retarget_tag('latest', manifest_1)
- retarget_tag('latest', manifest_2)
+ # Add some tag history, moving a tag back and forth between two manifests.
+ retarget_tag("latest", manifest_1)
+ retarget_tag("latest", manifest_2)
+ retarget_tag("latest", manifest_1)
+ retarget_tag("latest", manifest_2)
- retarget_tag('another1', manifest_1)
- retarget_tag('another2', manifest_2)
+ retarget_tag("another1", manifest_1)
+ retarget_tag("another2", manifest_2)
- # Delete all tags pointing to the first manifest.
- delete_tags_for_manifest(manifest_1)
+ # Delete all tags pointing to the first manifest.
+ delete_tags_for_manifest(manifest_1)
- assert get_tag(new_repo, 'latest').manifest == manifest_2
- assert get_tag(new_repo, 'another1') is None
- assert get_tag(new_repo, 'another2').manifest == manifest_2
+ assert get_tag(new_repo, "latest").manifest == manifest_2
+ assert get_tag(new_repo, "another1") is None
+ assert get_tag(new_repo, "another2").manifest == manifest_2
- # Delete all tags pointing to the second manifest, which should actually delete the `latest`
- # tag now.
- delete_tags_for_manifest(manifest_2)
- assert get_tag(new_repo, 'latest') is None
- assert get_tag(new_repo, 'another1') is None
- assert get_tag(new_repo, 'another2') is None
+ # Delete all tags pointing to the second manifest, which should actually delete the `latest`
+ # tag now.
+ delete_tags_for_manifest(manifest_2)
+ assert get_tag(new_repo, "latest") is None
+ assert get_tag(new_repo, "another1") is None
+ assert get_tag(new_repo, "another2") is None
-@pytest.mark.parametrize('timedelta, expected_timedelta', [
- pytest.param(timedelta(seconds=1), timedelta(hours=1), id='less than minimum'),
- pytest.param(timedelta(weeks=300), timedelta(weeks=104), id='more than maxium'),
- pytest.param(timedelta(weeks=1), timedelta(weeks=1), id='within range'),
-])
+@pytest.mark.parametrize(
+ "timedelta, expected_timedelta",
+ [
+ pytest.param(timedelta(seconds=1), timedelta(hours=1), id="less than minimum"),
+        pytest.param(timedelta(weeks=300), timedelta(weeks=104), id="more than maximum"),
+ pytest.param(timedelta(weeks=1), timedelta(weeks=1), id="within range"),
+ ],
+)
def test_change_tag_expiration(timedelta, expected_timedelta, initialized_db):
- now = datetime.utcnow()
- now_ms = timegm(now.utctimetuple()) * 1000
+ now = datetime.utcnow()
+ now_ms = timegm(now.utctimetuple()) * 1000
- tag = Tag.get()
- tag.lifetime_start_ms = now_ms
- tag.save()
+ tag = Tag.get()
+ tag.lifetime_start_ms = now_ms
+ tag.save()
- original_end_ms, okay = change_tag_expiration(tag, now + timedelta)
- assert okay
- assert original_end_ms == tag.lifetime_end_ms
+ original_end_ms, okay = change_tag_expiration(tag, now + timedelta)
+ assert okay
+ assert original_end_ms == tag.lifetime_end_ms
- updated_tag = Tag.get(id=tag.id)
- offset = expected_timedelta.total_seconds() * 1000
- expected_ms = (updated_tag.lifetime_start_ms + offset)
- assert updated_tag.lifetime_end_ms == expected_ms
+ updated_tag = Tag.get(id=tag.id)
+ offset = expected_timedelta.total_seconds() * 1000
+ expected_ms = updated_tag.lifetime_start_ms + offset
+ assert updated_tag.lifetime_end_ms == expected_ms
- original_end_ms, okay = change_tag_expiration(tag, None)
- assert okay
- assert original_end_ms == expected_ms
+ original_end_ms, okay = change_tag_expiration(tag, None)
+ assert okay
+ assert original_end_ms == expected_ms
- updated_tag = Tag.get(id=tag.id)
- assert updated_tag.lifetime_end_ms is None
+ updated_tag = Tag.get(id=tag.id)
+ assert updated_tag.lifetime_end_ms is None
def test_set_tag_expiration_for_manifest(initialized_db):
- tag = Tag.get()
- manifest = tag.manifest
- assert manifest is not None
+ tag = Tag.get()
+ manifest = tag.manifest
+ assert manifest is not None
- set_tag_expiration_for_manifest(manifest, datetime.utcnow() + timedelta(weeks=1))
+ set_tag_expiration_for_manifest(manifest, datetime.utcnow() + timedelta(weeks=1))
- updated_tag = Tag.get(id=tag.id)
- assert updated_tag.lifetime_end_ms is not None
+ updated_tag = Tag.get(id=tag.id)
+ assert updated_tag.lifetime_end_ms is not None
def test_create_temporary_tag_if_necessary(initialized_db):
- tag = Tag.get()
- manifest = tag.manifest
- assert manifest is not None
+ tag = Tag.get()
+ manifest = tag.manifest
+ assert manifest is not None
- # Ensure no tag is created, since an existing one is present.
- created = create_temporary_tag_if_necessary(manifest, 60)
- assert created is None
+ # Ensure no tag is created, since an existing one is present.
+ created = create_temporary_tag_if_necessary(manifest, 60)
+ assert created is None
- # Mark the tag as deleted.
- tag.lifetime_end_ms = 1
- tag.save()
+ # Mark the tag as deleted.
+ tag.lifetime_end_ms = 1
+ tag.save()
- # Now create a temp tag.
- created = create_temporary_tag_if_necessary(manifest, 60)
- assert created is not None
- assert created.hidden
- assert created.name.startswith('$temp-')
- assert created.manifest == manifest
- assert created.lifetime_end_ms is not None
- assert created.lifetime_end_ms == (created.lifetime_start_ms + 60000)
+ # Now create a temp tag.
+ created = create_temporary_tag_if_necessary(manifest, 60)
+ assert created is not None
+ assert created.hidden
+ assert created.name.startswith("$temp-")
+ assert created.manifest == manifest
+ assert created.lifetime_end_ms is not None
+ assert created.lifetime_end_ms == (created.lifetime_start_ms + 60000)
- # Try again and ensure it is not created.
- created = create_temporary_tag_if_necessary(manifest, 30)
- assert created is None
+ # Try again and ensure it is not created.
+ created = create_temporary_tag_if_necessary(manifest, 30)
+ assert created is None
def test_retarget_tag(initialized_db):
- repo = get_repository('devtable', 'history')
- results, _ = list_repository_tag_history(repo, 1, 100, specific_tag_name='latest')
+ repo = get_repository("devtable", "history")
+ results, _ = list_repository_tag_history(repo, 1, 100, specific_tag_name="latest")
- assert len(results) == 2
- assert results[0].lifetime_end_ms is None
- assert results[1].lifetime_end_ms is not None
+ assert len(results) == 2
+ assert results[0].lifetime_end_ms is None
+ assert results[1].lifetime_end_ms is not None
- # Revert back to the original manifest.
- created = retarget_tag('latest', results[0].manifest, is_reversion=True,
- now_ms=results[1].lifetime_end_ms + 10000)
- assert created.lifetime_end_ms is None
- assert created.reversion
- assert created.name == 'latest'
- assert created.manifest == results[0].manifest
+ # Revert back to the original manifest.
+ created = retarget_tag(
+ "latest",
+ results[0].manifest,
+ is_reversion=True,
+ now_ms=results[1].lifetime_end_ms + 10000,
+ )
+ assert created.lifetime_end_ms is None
+ assert created.reversion
+ assert created.name == "latest"
+ assert created.manifest == results[0].manifest
- # Verify in the history.
- results, _ = list_repository_tag_history(repo, 1, 100, specific_tag_name='latest')
+ # Verify in the history.
+ results, _ = list_repository_tag_history(repo, 1, 100, specific_tag_name="latest")
- assert len(results) == 3
- assert results[0].lifetime_end_ms is None
- assert results[1].lifetime_end_ms is not None
- assert results[2].lifetime_end_ms is not None
+ assert len(results) == 3
+ assert results[0].lifetime_end_ms is None
+ assert results[1].lifetime_end_ms is not None
+ assert results[2].lifetime_end_ms is not None
- assert results[0] == created
+ assert results[0] == created
- # Verify old-style tables.
- repository_tag = TagToRepositoryTag.get(tag=created).repository_tag
- assert repository_tag.lifetime_start_ts == int(created.lifetime_start_ms / 1000)
+ # Verify old-style tables.
+ repository_tag = TagToRepositoryTag.get(tag=created).repository_tag
+ assert repository_tag.lifetime_start_ts == int(created.lifetime_start_ms / 1000)
- tag_manifest = TagManifest.get(tag=repository_tag)
- assert TagManifestToManifest.get(tag_manifest=tag_manifest).manifest == created.manifest
+ tag_manifest = TagManifest.get(tag=repository_tag)
+ assert (
+ TagManifestToManifest.get(tag_manifest=tag_manifest).manifest
+ == created.manifest
+ )
def test_retarget_tag_wrong_name(initialized_db):
- repo = get_repository('devtable', 'history')
- results, _ = list_repository_tag_history(repo, 1, 100, specific_tag_name='latest')
- assert len(results) == 2
+ repo = get_repository("devtable", "history")
+ results, _ = list_repository_tag_history(repo, 1, 100, specific_tag_name="latest")
+ assert len(results) == 2
- created = retarget_tag('someothername', results[1].manifest, is_reversion=True)
- assert created is None
+ created = retarget_tag("someothername", results[1].manifest, is_reversion=True)
+ assert created is None
- results, _ = list_repository_tag_history(repo, 1, 100, specific_tag_name='latest')
- assert len(results) == 2
+ results, _ = list_repository_tag_history(repo, 1, 100, specific_tag_name="latest")
+ assert len(results) == 2
def test_lookup_unrecoverable_tags(initialized_db):
- # Ensure no existing tags are found.
- for repo in Repository.select():
- assert not list(lookup_unrecoverable_tags(repo))
+ # Ensure no existing tags are found.
+ for repo in Repository.select():
+ assert not list(lookup_unrecoverable_tags(repo))
- # Mark a tag as outside the expiration window and ensure it is found.
- repo = get_repository('devtable', 'history')
- results, _ = list_repository_tag_history(repo, 1, 100, specific_tag_name='latest')
- assert len(results) == 2
+ # Mark a tag as outside the expiration window and ensure it is found.
+ repo = get_repository("devtable", "history")
+ results, _ = list_repository_tag_history(repo, 1, 100, specific_tag_name="latest")
+ assert len(results) == 2
- results[1].lifetime_end_ms = 1
- results[1].save()
+ results[1].lifetime_end_ms = 1
+ results[1].save()
- # Ensure the tag is now found.
- found = list(lookup_unrecoverable_tags(repo))
- assert found
- assert len(found) == 1
- assert found[0] == results[1]
+ # Ensure the tag is now found.
+ found = list(lookup_unrecoverable_tags(repo))
+ assert found
+ assert len(found) == 1
+ assert found[0] == results[1]
- # Mark the tag as expiring in the future and ensure it is no longer found.
- results[1].lifetime_end_ms = get_epoch_timestamp_ms() + 1000000
- results[1].save()
+ # Mark the tag as expiring in the future and ensure it is no longer found.
+ results[1].lifetime_end_ms = get_epoch_timestamp_ms() + 1000000
+ results[1].save()
- found = list(lookup_unrecoverable_tags(repo))
- assert not found
+ found = list(lookup_unrecoverable_tags(repo))
+ assert not found
diff --git a/data/model/organization.py b/data/model/organization.py
index b42f0d454..a3543e459 100644
--- a/data/model/organization.py
+++ b/data/model/organization.py
@@ -1,167 +1,208 @@
-
-from data.database import (User, FederatedLogin, TeamMember, Team, TeamRole, RepositoryPermission,
- Repository, Namespace, DeletedNamespace)
-from data.model import (user, team, DataModelException, InvalidOrganizationException,
- InvalidUsernameException, db_transaction, _basequery)
+from data.database import (
+ User,
+ FederatedLogin,
+ TeamMember,
+ Team,
+ TeamRole,
+ RepositoryPermission,
+ Repository,
+ Namespace,
+ DeletedNamespace,
+)
+from data.model import (
+ user,
+ team,
+ DataModelException,
+ InvalidOrganizationException,
+ InvalidUsernameException,
+ db_transaction,
+ _basequery,
+)
-def create_organization(name, email, creating_user, email_required=True, is_possible_abuser=False):
- with db_transaction():
- try:
- # Create the org
- new_org = user.create_user_noverify(name, email, email_required=email_required,
- is_possible_abuser=is_possible_abuser)
- new_org.organization = True
- new_org.save()
+def create_organization(
+ name, email, creating_user, email_required=True, is_possible_abuser=False
+):
+ with db_transaction():
+ try:
+ # Create the org
+ new_org = user.create_user_noverify(
+ name,
+ email,
+ email_required=email_required,
+ is_possible_abuser=is_possible_abuser,
+ )
+ new_org.organization = True
+ new_org.save()
- # Create a team for the owners
- owners_team = team.create_team('owners', new_org, 'admin')
+ # Create a team for the owners
+ owners_team = team.create_team("owners", new_org, "admin")
- # Add the user who created the org to the owners team
- team.add_user_to_team(creating_user, owners_team)
+ # Add the user who created the org to the owners team
+ team.add_user_to_team(creating_user, owners_team)
- return new_org
- except InvalidUsernameException as iue:
- raise InvalidOrganizationException(iue.message)
+ return new_org
+ except InvalidUsernameException as iue:
+ raise InvalidOrganizationException(iue.message)
def get_organization(name):
- try:
- return User.get(username=name, organization=True)
- except User.DoesNotExist:
- raise InvalidOrganizationException('Organization does not exist: %s' %
- name)
+ try:
+ return User.get(username=name, organization=True)
+ except User.DoesNotExist:
+ raise InvalidOrganizationException("Organization does not exist: %s" % name)
def convert_user_to_organization(user_obj, admin_user):
- if user_obj.robot:
- raise DataModelException('Cannot convert a robot into an organization')
+ if user_obj.robot:
+ raise DataModelException("Cannot convert a robot into an organization")
- with db_transaction():
- # Change the user to an organization and disable this account for login.
- user_obj.organization = True
- user_obj.password_hash = None
- user_obj.save()
+ with db_transaction():
+ # Change the user to an organization and disable this account for login.
+ user_obj.organization = True
+ user_obj.password_hash = None
+ user_obj.save()
- # Clear any federated auth pointing to this user.
- FederatedLogin.delete().where(FederatedLogin.user == user_obj).execute()
+ # Clear any federated auth pointing to this user.
+ FederatedLogin.delete().where(FederatedLogin.user == user_obj).execute()
- # Delete any user-specific permissions on repositories.
- (RepositoryPermission.delete()
- .where(RepositoryPermission.user == user_obj)
- .execute())
+ # Delete any user-specific permissions on repositories.
+ (
+ RepositoryPermission.delete()
+ .where(RepositoryPermission.user == user_obj)
+ .execute()
+ )
- # Create a team for the owners
- owners_team = team.create_team('owners', user_obj, 'admin')
+ # Create a team for the owners
+ owners_team = team.create_team("owners", user_obj, "admin")
- # Add the user who will admin the org to the owners team
- team.add_user_to_team(admin_user, owners_team)
+ # Add the user who will admin the org to the owners team
+ team.add_user_to_team(admin_user, owners_team)
- return user_obj
+ return user_obj
def get_user_organizations(username):
- return _basequery.get_user_organizations(username)
+ return _basequery.get_user_organizations(username)
+
def get_organization_team_members(teamid):
- joined = User.select().join(TeamMember).join(Team)
- query = joined.where(Team.id == teamid)
- return query
+ joined = User.select().join(TeamMember).join(Team)
+ query = joined.where(Team.id == teamid)
+ return query
def __get_org_admin_users(org):
- return (User
- .select()
- .join(TeamMember)
- .join(Team)
- .join(TeamRole)
- .where(Team.organization == org, TeamRole.name == 'admin', User.robot == False)
- .distinct())
+ return (
+ User.select()
+ .join(TeamMember)
+ .join(Team)
+ .join(TeamRole)
+ .where(Team.organization == org, TeamRole.name == "admin", User.robot == False)
+ .distinct()
+ )
+
def get_admin_users(org):
- """ Returns the owner users for the organization. """
- return __get_org_admin_users(org)
+ """ Returns the owner users for the organization. """
+ return __get_org_admin_users(org)
+
def remove_organization_member(org, user_obj):
- org_admins = [u.username for u in __get_org_admin_users(org)]
- if len(org_admins) == 1 and user_obj.username in org_admins:
- raise DataModelException('Cannot remove user as they are the only organization admin')
+ org_admins = [u.username for u in __get_org_admin_users(org)]
+ if len(org_admins) == 1 and user_obj.username in org_admins:
+ raise DataModelException(
+ "Cannot remove user as they are the only organization admin"
+ )
- with db_transaction():
- # Find and remove the user from any repositories under the org.
- permissions = list(RepositoryPermission
- .select(RepositoryPermission.id)
- .join(Repository)
- .where(Repository.namespace_user == org,
- RepositoryPermission.user == user_obj))
+ with db_transaction():
+ # Find and remove the user from any repositories under the org.
+ permissions = list(
+ RepositoryPermission.select(RepositoryPermission.id)
+ .join(Repository)
+ .where(
+ Repository.namespace_user == org, RepositoryPermission.user == user_obj
+ )
+ )
- if permissions:
- RepositoryPermission.delete().where(RepositoryPermission.id << permissions).execute()
+ if permissions:
+ RepositoryPermission.delete().where(
+ RepositoryPermission.id << permissions
+ ).execute()
- # Find and remove the user from any teams under the org.
- members = list(TeamMember
- .select(TeamMember.id)
- .join(Team)
- .where(Team.organization == org, TeamMember.user == user_obj))
+ # Find and remove the user from any teams under the org.
+ members = list(
+ TeamMember.select(TeamMember.id)
+ .join(Team)
+ .where(Team.organization == org, TeamMember.user == user_obj)
+ )
- if members:
- TeamMember.delete().where(TeamMember.id << members).execute()
+ if members:
+ TeamMember.delete().where(TeamMember.id << members).execute()
def get_organization_member_set(org, include_robots=False, users_filter=None):
- """ Returns the set of all member usernames under the given organization, with optional
+ """ Returns the set of all member usernames under the given organization, with optional
filtering by robots and/or by a specific set of User objects.
"""
- Org = User.alias()
- org_users = (User
- .select(User.username)
- .join(TeamMember)
- .join(Team)
- .where(Team.organization == org)
- .distinct())
+ Org = User.alias()
+ org_users = (
+ User.select(User.username)
+ .join(TeamMember)
+ .join(Team)
+ .where(Team.organization == org)
+ .distinct()
+ )
- if not include_robots:
- org_users = org_users.where(User.robot == False)
+ if not include_robots:
+ org_users = org_users.where(User.robot == False)
- if users_filter is not None:
- ids_list = [u.id for u in users_filter if u is not None]
- if not ids_list:
- return set()
+ if users_filter is not None:
+ ids_list = [u.id for u in users_filter if u is not None]
+ if not ids_list:
+ return set()
- org_users = org_users.where(User.id << ids_list)
+ org_users = org_users.where(User.id << ids_list)
- return {user.username for user in org_users}
+ return {user.username for user in org_users}
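+# Example with illustrative members: for an org whose teams contain the users
+# "alice" and "bob" plus the robot "org+builder":
+#
+#     get_organization_member_set(org)                       # {"alice", "bob"}
+#     get_organization_member_set(org, include_robots=True)  # adds "org+builder"
+#     get_organization_member_set(org, users_filter={bob})   # {"bob"} (bob is a User)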
def get_all_repo_users_transitive_via_teams(namespace_name, repository_name):
- return (User
- .select()
- .distinct()
- .join(TeamMember)
- .join(Team)
- .join(RepositoryPermission)
- .join(Repository)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .where(Namespace.username == namespace_name, Repository.name == repository_name))
+ return (
+ User.select()
+ .distinct()
+ .join(TeamMember)
+ .join(Team)
+ .join(RepositoryPermission)
+ .join(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .where(Namespace.username == namespace_name, Repository.name == repository_name)
+ )
def get_organizations(deleted=False):
- query = User.select().where(User.organization == True, User.robot == False)
+ query = User.select().where(User.organization == True, User.robot == False)
- if not deleted:
- query = query.where(User.id.not_in(DeletedNamespace.select(DeletedNamespace.namespace)))
+ if not deleted:
+ query = query.where(
+ User.id.not_in(DeletedNamespace.select(DeletedNamespace.namespace))
+ )
- return query
+ return query
def get_active_org_count():
- return get_organizations().count()
+ return get_organizations().count()
def add_user_as_admin(user_obj, org_obj):
- try:
- admin_role = TeamRole.get(name='admin')
- admin_team = Team.select().where(Team.role == admin_role, Team.organization == org_obj).get()
- team.add_user_to_team(user_obj, admin_team)
- except team.UserAlreadyInTeam:
- pass
+ try:
+ admin_role = TeamRole.get(name="admin")
+ admin_team = (
+ Team.select()
+ .where(Team.role == admin_role, Team.organization == org_obj)
+ .get()
+ )
+ team.add_user_to_team(user_obj, admin_team)
+ except team.UserAlreadyInTeam:
+ pass
diff --git a/data/model/permission.py b/data/model/permission.py
index e38584561..159110f16 100644
--- a/data/model/permission.py
+++ b/data/model/permission.py
@@ -1,322 +1,387 @@
from peewee import JOIN
-from data.database import (RepositoryPermission, User, Repository, Visibility, Role, TeamMember,
- PermissionPrototype, Team, TeamRole, Namespace)
+from data.database import (
+ RepositoryPermission,
+ User,
+ Repository,
+ Visibility,
+ Role,
+ TeamMember,
+ PermissionPrototype,
+ Team,
+ TeamRole,
+ Namespace,
+)
from data.model import DataModelException, _basequery
from util.names import parse_robot_username
+
def list_team_permissions(team):
- return (RepositoryPermission
- .select(RepositoryPermission)
- .join(Repository)
- .join(Visibility)
- .switch(RepositoryPermission)
- .join(Role)
- .switch(RepositoryPermission)
- .where(RepositoryPermission.team == team))
+ return (
+ RepositoryPermission.select(RepositoryPermission)
+ .join(Repository)
+ .join(Visibility)
+ .switch(RepositoryPermission)
+ .join(Role)
+ .switch(RepositoryPermission)
+ .where(RepositoryPermission.team == team)
+ )
def list_robot_permissions(robot_name):
- return (RepositoryPermission
- .select(RepositoryPermission, User, Repository)
- .join(Repository)
- .join(Visibility)
- .switch(RepositoryPermission)
- .join(Role)
- .switch(RepositoryPermission)
- .join(User)
- .where(User.username == robot_name, User.robot == True))
+ return (
+ RepositoryPermission.select(RepositoryPermission, User, Repository)
+ .join(Repository)
+ .join(Visibility)
+ .switch(RepositoryPermission)
+ .join(Role)
+ .switch(RepositoryPermission)
+ .join(User)
+ .where(User.username == robot_name, User.robot == True)
+ )
def list_organization_member_permissions(organization, limit_to_user=None):
- query = (RepositoryPermission
- .select(RepositoryPermission, Repository, User)
- .join(Repository)
- .switch(RepositoryPermission)
- .join(User)
- .where(Repository.namespace_user == organization))
+ query = (
+ RepositoryPermission.select(RepositoryPermission, Repository, User)
+ .join(Repository)
+ .switch(RepositoryPermission)
+ .join(User)
+ .where(Repository.namespace_user == organization)
+ )
- if limit_to_user is not None:
- query = query.where(RepositoryPermission.user == limit_to_user)
- else:
- query = query.where(User.robot == False)
+ if limit_to_user is not None:
+ query = query.where(RepositoryPermission.user == limit_to_user)
+ else:
+ query = query.where(User.robot == False)
- return query
+ return query
def get_all_user_repository_permissions(user):
- return _get_user_repo_permissions(user)
+ return _get_user_repo_permissions(user)
def get_user_repo_permissions(user, repo):
- return _get_user_repo_permissions(user, limit_to_repository_obj=repo)
+ return _get_user_repo_permissions(user, limit_to_repository_obj=repo)
def get_user_repository_permissions(user, namespace, repo_name):
- return _get_user_repo_permissions(user, limit_namespace=namespace, limit_repo_name=repo_name)
+ return _get_user_repo_permissions(
+ user, limit_namespace=namespace, limit_repo_name=repo_name
+ )
-def _get_user_repo_permissions(user, limit_to_repository_obj=None, limit_namespace=None,
- limit_repo_name=None):
- UserThroughTeam = User.alias()
+def _get_user_repo_permissions(
+ user, limit_to_repository_obj=None, limit_namespace=None, limit_repo_name=None
+):
+ UserThroughTeam = User.alias()
- base_query = (RepositoryPermission
- .select(RepositoryPermission, Role, Repository, Namespace)
- .join(Role)
- .switch(RepositoryPermission)
- .join(Repository)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .switch(RepositoryPermission))
+ base_query = (
+ RepositoryPermission.select(RepositoryPermission, Role, Repository, Namespace)
+ .join(Role)
+ .switch(RepositoryPermission)
+ .join(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .switch(RepositoryPermission)
+ )
- if limit_to_repository_obj is not None:
- base_query = base_query.where(RepositoryPermission.repository == limit_to_repository_obj)
- elif limit_namespace and limit_repo_name:
- base_query = base_query.where(Repository.name == limit_repo_name,
- Namespace.username == limit_namespace)
+ if limit_to_repository_obj is not None:
+ base_query = base_query.where(
+ RepositoryPermission.repository == limit_to_repository_obj
+ )
+ elif limit_namespace and limit_repo_name:
+ base_query = base_query.where(
+ Repository.name == limit_repo_name, Namespace.username == limit_namespace
+ )
- direct = (base_query
- .clone()
- .join(User)
- .where(User.id == user))
+ direct = base_query.clone().join(User).where(User.id == user)
- team = (base_query
- .clone()
- .join(Team)
- .join(TeamMember)
- .join(UserThroughTeam, on=(UserThroughTeam.id == TeamMember.user))
- .where(UserThroughTeam.id == user))
+ team = (
+ base_query.clone()
+ .join(Team)
+ .join(TeamMember)
+ .join(UserThroughTeam, on=(UserThroughTeam.id == TeamMember.user))
+ .where(UserThroughTeam.id == user)
+ )
- return direct | team
+ return direct | team
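+# The union covers both grant paths: a user with no direct RepositoryPermission
+# row still appears in the result, carrying the team's role, whenever one of
+# their teams holds a permission on the repository.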
def delete_prototype_permission(org, uid):
- found = get_prototype_permission(org, uid)
- if not found:
- return None
+ found = get_prototype_permission(org, uid)
+ if not found:
+ return None
- found.delete_instance()
- return found
+ found.delete_instance()
+ return found
def get_prototype_permission(org, uid):
- try:
- return PermissionPrototype.get(PermissionPrototype.org == org,
- PermissionPrototype.uuid == uid)
- except PermissionPrototype.DoesNotExist:
- return None
+ try:
+ return PermissionPrototype.get(
+ PermissionPrototype.org == org, PermissionPrototype.uuid == uid
+ )
+ except PermissionPrototype.DoesNotExist:
+ return None
def get_prototype_permissions(org):
- ActivatingUser = User.alias()
- DelegateUser = User.alias()
- query = (PermissionPrototype
- .select()
- .where(PermissionPrototype.org == org)
- .join(ActivatingUser, JOIN.LEFT_OUTER,
- on=(ActivatingUser.id == PermissionPrototype.activating_user))
- .join(DelegateUser, JOIN.LEFT_OUTER,
- on=(DelegateUser.id == PermissionPrototype.delegate_user))
- .join(Team, JOIN.LEFT_OUTER,
- on=(Team.id == PermissionPrototype.delegate_team))
- .join(Role, JOIN.LEFT_OUTER, on=(Role.id == PermissionPrototype.role)))
- return query
+ ActivatingUser = User.alias()
+ DelegateUser = User.alias()
+ query = (
+ PermissionPrototype.select()
+ .where(PermissionPrototype.org == org)
+ .join(
+ ActivatingUser,
+ JOIN.LEFT_OUTER,
+ on=(ActivatingUser.id == PermissionPrototype.activating_user),
+ )
+ .join(
+ DelegateUser,
+ JOIN.LEFT_OUTER,
+ on=(DelegateUser.id == PermissionPrototype.delegate_user),
+ )
+ .join(Team, JOIN.LEFT_OUTER, on=(Team.id == PermissionPrototype.delegate_team))
+ .join(Role, JOIN.LEFT_OUTER, on=(Role.id == PermissionPrototype.role))
+ )
+ return query
def update_prototype_permission(org, uid, role_name):
- found = get_prototype_permission(org, uid)
- if not found:
- return None
+ found = get_prototype_permission(org, uid)
+ if not found:
+ return None
- new_role = Role.get(Role.name == role_name)
- found.role = new_role
- found.save()
- return found
+ new_role = Role.get(Role.name == role_name)
+ found.role = new_role
+ found.save()
+ return found
-def add_prototype_permission(org, role_name, activating_user,
- delegate_user=None, delegate_team=None):
- new_role = Role.get(Role.name == role_name)
- return PermissionPrototype.create(org=org, role=new_role, activating_user=activating_user,
- delegate_user=delegate_user, delegate_team=delegate_team)
+def add_prototype_permission(
+ org, role_name, activating_user, delegate_user=None, delegate_team=None
+):
+ new_role = Role.get(Role.name == role_name)
+ return PermissionPrototype.create(
+ org=org,
+ role=new_role,
+ activating_user=activating_user,
+ delegate_user=delegate_user,
+ delegate_team=delegate_team,
+ )
def get_org_wide_permissions(user, org_filter=None):
- Org = User.alias()
- team_with_role = Team.select(Team, Org, TeamRole).join(TeamRole)
- with_org = team_with_role.switch(Team).join(Org, on=(Team.organization ==
- Org.id))
- with_user = with_org.switch(Team).join(TeamMember).join(User)
+ Org = User.alias()
+ team_with_role = Team.select(Team, Org, TeamRole).join(TeamRole)
+ with_org = team_with_role.switch(Team).join(Org, on=(Team.organization == Org.id))
+ with_user = with_org.switch(Team).join(TeamMember).join(User)
- if org_filter:
- with_user.where(Org.username == org_filter)
+ if org_filter:
+ with_user.where(Org.username == org_filter)
- return with_user.where(User.id == user, Org.organization == True)
+ return with_user.where(User.id == user, Org.organization == True)
def get_all_repo_teams(namespace_name, repository_name):
- return (RepositoryPermission
- .select(Team.name, Role.name, RepositoryPermission)
- .join(Team)
- .switch(RepositoryPermission)
- .join(Role)
- .switch(RepositoryPermission)
- .join(Repository)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .where(Namespace.username == namespace_name, Repository.name == repository_name))
+ return (
+ RepositoryPermission.select(Team.name, Role.name, RepositoryPermission)
+ .join(Team)
+ .switch(RepositoryPermission)
+ .join(Role)
+ .switch(RepositoryPermission)
+ .join(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .where(Namespace.username == namespace_name, Repository.name == repository_name)
+ )
def apply_default_permissions(repo_obj, creating_user_obj):
- org = repo_obj.namespace_user
- user_clause = ((PermissionPrototype.activating_user == creating_user_obj) |
- (PermissionPrototype.activating_user >> None))
+ org = repo_obj.namespace_user
+ user_clause = (PermissionPrototype.activating_user == creating_user_obj) | (
+ PermissionPrototype.activating_user >> None
+ )
- team_protos = (PermissionPrototype
- .select()
- .where(PermissionPrototype.org == org, user_clause,
- PermissionPrototype.delegate_user >> None))
+ team_protos = PermissionPrototype.select().where(
+ PermissionPrototype.org == org,
+ user_clause,
+ PermissionPrototype.delegate_user >> None,
+ )
- def create_team_permission(team, repo, role):
- RepositoryPermission.create(team=team, repository=repo, role=role)
+ def create_team_permission(team, repo, role):
+ RepositoryPermission.create(team=team, repository=repo, role=role)
- __apply_permission_list(repo_obj, team_protos, 'name', create_team_permission)
+ __apply_permission_list(repo_obj, team_protos, "name", create_team_permission)
- user_protos = (PermissionPrototype
- .select()
- .where(PermissionPrototype.org == org, user_clause,
- PermissionPrototype.delegate_team >> None))
+ user_protos = PermissionPrototype.select().where(
+ PermissionPrototype.org == org,
+ user_clause,
+ PermissionPrototype.delegate_team >> None,
+ )
- def create_user_permission(user, repo, role):
- # The creating user always gets admin anyway
- if user.username == creating_user_obj.username:
- return
+ def create_user_permission(user, repo, role):
+ # The creating user always gets admin anyway
+ if user.username == creating_user_obj.username:
+ return
- RepositoryPermission.create(user=user, repository=repo, role=role)
+ RepositoryPermission.create(user=user, repository=repo, role=role)
- __apply_permission_list(repo_obj, user_protos, 'username', create_user_permission)
+ __apply_permission_list(repo_obj, user_protos, "username", create_user_permission)
def __apply_permission_list(repo, proto_query, name_property, create_permission_func):
- final_protos = {}
- for proto in proto_query:
- applies_to = proto.delegate_team or proto.delegate_user
- name = getattr(applies_to, name_property)
- # We will skip the proto if it is pre-empted by a more important proto
- if name in final_protos and proto.activating_user is None:
- continue
+ final_protos = {}
+ for proto in proto_query:
+ applies_to = proto.delegate_team or proto.delegate_user
+ name = getattr(applies_to, name_property)
+ # We will skip the proto if it is pre-empted by a more important proto
+ if name in final_protos and proto.activating_user is None:
+ continue
- # By this point, it is either a user specific proto, or there is no
- # proto yet, so we can safely assume it applies
- final_protos[name] = (applies_to, proto.role)
+ # By this point, it is either a user specific proto, or there is no
+ # proto yet, so we can safely assume it applies
+ final_protos[name] = (applies_to, proto.role)
- for delegate, role in final_protos.values():
- create_permission_func(delegate, repo, role)
+ for delegate, role in final_protos.values():
+ create_permission_func(delegate, repo, role)
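+# Precedence example (hypothetical prototypes): if an org-wide prototype
+# (activating_user is None) grants team "readers" the "read" role and a
+# prototype tied to the creating user grants the same team "write", the
+# user-specific one wins: org-wide entries are skipped once the name is
+# already present in final_protos, while user-specific entries overwrite it.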
-def __entity_permission_repo_query(entity_id, entity_table, entity_id_property, namespace_name,
- repository_name):
- """ This method works for both users and teams. """
+def __entity_permission_repo_query(
+ entity_id, entity_table, entity_id_property, namespace_name, repository_name
+):
+ """ This method works for both users and teams. """
- return (RepositoryPermission
- .select(entity_table, Repository, Namespace, Role, RepositoryPermission)
- .join(entity_table)
- .switch(RepositoryPermission)
- .join(Role)
- .switch(RepositoryPermission)
- .join(Repository)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .where(Repository.name == repository_name, Namespace.username == namespace_name,
- entity_id_property == entity_id))
+ return (
+ RepositoryPermission.select(
+ entity_table, Repository, Namespace, Role, RepositoryPermission
+ )
+ .join(entity_table)
+ .switch(RepositoryPermission)
+ .join(Role)
+ .switch(RepositoryPermission)
+ .join(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .where(
+ Repository.name == repository_name,
+ Namespace.username == namespace_name,
+ entity_id_property == entity_id,
+ )
+ )
def get_user_reponame_permission(username, namespace_name, repository_name):
- fetched = list(__entity_permission_repo_query(username, User, User.username, namespace_name,
- repository_name))
- if not fetched:
- raise DataModelException('User does not have permission for repo.')
+ fetched = list(
+ __entity_permission_repo_query(
+ username, User, User.username, namespace_name, repository_name
+ )
+ )
+ if not fetched:
+ raise DataModelException("User does not have permission for repo.")
- return fetched[0]
+ return fetched[0]
def get_team_reponame_permission(team_name, namespace_name, repository_name):
- fetched = list(__entity_permission_repo_query(team_name, Team, Team.name, namespace_name,
- repository_name))
- if not fetched:
- raise DataModelException('Team does not have permission for repo.')
+ fetched = list(
+ __entity_permission_repo_query(
+ team_name, Team, Team.name, namespace_name, repository_name
+ )
+ )
+ if not fetched:
+ raise DataModelException("Team does not have permission for repo.")
- return fetched[0]
+ return fetched[0]
def delete_user_permission(username, namespace_name, repository_name):
- if username == namespace_name:
- raise DataModelException('Namespace owner must always be admin.')
+ if username == namespace_name:
+ raise DataModelException("Namespace owner must always be admin.")
- fetched = list(__entity_permission_repo_query(username, User, User.username, namespace_name,
- repository_name))
- if not fetched:
- raise DataModelException('User does not have permission for repo.')
+ fetched = list(
+ __entity_permission_repo_query(
+ username, User, User.username, namespace_name, repository_name
+ )
+ )
+ if not fetched:
+ raise DataModelException("User does not have permission for repo.")
- fetched[0].delete_instance()
+ fetched[0].delete_instance()
def delete_team_permission(team_name, namespace_name, repository_name):
- fetched = list(__entity_permission_repo_query(team_name, Team, Team.name, namespace_name,
- repository_name))
- if not fetched:
- raise DataModelException('Team does not have permission for repo.')
+ fetched = list(
+ __entity_permission_repo_query(
+ team_name, Team, Team.name, namespace_name, repository_name
+ )
+ )
+ if not fetched:
+ raise DataModelException("Team does not have permission for repo.")
- fetched[0].delete_instance()
+ fetched[0].delete_instance()
-def __set_entity_repo_permission(entity, permission_entity_property,
- namespace_name, repository_name, role_name):
- repo = _basequery.get_existing_repository(namespace_name, repository_name)
- new_role = Role.get(Role.name == role_name)
+def __set_entity_repo_permission(
+ entity, permission_entity_property, namespace_name, repository_name, role_name
+):
+ repo = _basequery.get_existing_repository(namespace_name, repository_name)
+ new_role = Role.get(Role.name == role_name)
- # Fetch any existing permission for this entity on the repo
- try:
- entity_attr = getattr(RepositoryPermission, permission_entity_property)
- perm = RepositoryPermission.get(entity_attr == entity, RepositoryPermission.repository == repo)
- perm.role = new_role
- perm.save()
- return perm
- except RepositoryPermission.DoesNotExist:
- set_entity_kwargs = {permission_entity_property: entity}
- new_perm = RepositoryPermission.create(repository=repo, role=new_role, **set_entity_kwargs)
- return new_perm
+ # Fetch any existing permission for this entity on the repo
+ try:
+ entity_attr = getattr(RepositoryPermission, permission_entity_property)
+ perm = RepositoryPermission.get(
+ entity_attr == entity, RepositoryPermission.repository == repo
+ )
+ perm.role = new_role
+ perm.save()
+ return perm
+ except RepositoryPermission.DoesNotExist:
+ set_entity_kwargs = {permission_entity_property: entity}
+ new_perm = RepositoryPermission.create(
+ repository=repo, role=new_role, **set_entity_kwargs
+ )
+ return new_perm
def set_user_repo_permission(username, namespace_name, repository_name, role_name):
- if username == namespace_name:
- raise DataModelException('Namespace owner must always be admin.')
+ if username == namespace_name:
+ raise DataModelException("Namespace owner must always be admin.")
- try:
- user = User.get(User.username == username)
- except User.DoesNotExist:
- raise DataModelException('Invalid username: %s' % username)
+ try:
+ user = User.get(User.username == username)
+ except User.DoesNotExist:
+ raise DataModelException("Invalid username: %s" % username)
- if user.robot:
- parts = parse_robot_username(user.username)
- if not parts:
- raise DataModelException('Invalid robot: %s' % username)
+ if user.robot:
+ parts = parse_robot_username(user.username)
+ if not parts:
+ raise DataModelException("Invalid robot: %s" % username)
- robot_namespace, _ = parts
- if robot_namespace != namespace_name:
- raise DataModelException('Cannot add robot %s under namespace %s' %
- (username, namespace_name))
+ robot_namespace, _ = parts
+ if robot_namespace != namespace_name:
+ raise DataModelException(
+ "Cannot add robot %s under namespace %s" % (username, namespace_name)
+ )
- return __set_entity_repo_permission(user, 'user', namespace_name, repository_name, role_name)
+ return __set_entity_repo_permission(
+ user, "user", namespace_name, repository_name, role_name
+ )
def set_team_repo_permission(team_name, namespace_name, repository_name, role_name):
- try:
- team = (Team
- .select()
+ try:
+ team = (
+ Team.select()
.join(User)
.where(Team.name == team_name, User.username == namespace_name)
- .get())
- except Team.DoesNotExist:
- raise DataModelException('No team %s in organization %s' % (team_name, namespace_name))
-
- return __set_entity_repo_permission(team, 'team', namespace_name, repository_name, role_name)
-
+ .get()
+ )
+ except Team.DoesNotExist:
+ raise DataModelException(
+ "No team %s in organization %s" % (team_name, namespace_name)
+ )
+ return __set_entity_repo_permission(
+ team, "team", namespace_name, repository_name, role_name
+ )
diff --git a/data/model/release.py b/data/model/release.py
index f827eaeb0..79c5ecf5e 100644
--- a/data/model/release.py
+++ b/data/model/release.py
@@ -2,20 +2,22 @@ from data.database import QuayRelease, QuayRegion, QuayService
def set_region_release(service_name, region_name, version):
- service, _ = QuayService.get_or_create(name=service_name)
- region, _ = QuayRegion.get_or_create(name=region_name)
+ service, _ = QuayService.get_or_create(name=service_name)
+ region, _ = QuayRegion.get_or_create(name=region_name)
- return QuayRelease.get_or_create(service=service, version=version, region=region)
+ return QuayRelease.get_or_create(service=service, version=version, region=region)
def get_recent_releases(service_name, region_name):
- return (QuayRelease
- .select(QuayRelease)
- .join(QuayService)
- .switch(QuayRelease)
- .join(QuayRegion)
- .where(QuayService.name == service_name,
- QuayRegion.name == region_name,
- QuayRelease.reverted == False,
- )
- .order_by(QuayRelease.created.desc()))
+ return (
+ QuayRelease.select(QuayRelease)
+ .join(QuayService)
+ .switch(QuayRelease)
+ .join(QuayRegion)
+ .where(
+ QuayService.name == service_name,
+ QuayRegion.name == region_name,
+ QuayRelease.reverted == False,
+ )
+ .order_by(QuayRelease.created.desc())
+ )
diff --git a/data/model/repo_mirror.py b/data/model/repo_mirror.py
index a9824f3ab..151057a22 100644
--- a/data/model/repo_mirror.py
+++ b/data/model/repo_mirror.py
@@ -5,8 +5,16 @@ from datetime import datetime, timedelta
from peewee import IntegrityError, fn
from jsonschema import ValidationError
-from data.database import (RepoMirrorConfig, RepoMirrorRule, RepoMirrorRuleType, RepoMirrorStatus,
- RepositoryState, Repository, uuid_generator, db_transaction)
+from data.database import (
+ RepoMirrorConfig,
+ RepoMirrorRule,
+ RepoMirrorRuleType,
+ RepoMirrorStatus,
+ RepositoryState,
+ Repository,
+ uuid_generator,
+ db_transaction,
+)
from data.fields import DecryptedValue
from data.model import DataModelException
from util.names import parse_robot_username
@@ -14,75 +22,87 @@ from util.names import parse_robot_username
# TODO: Move these to the configuration
MAX_SYNC_RETRIES = 3
-MAX_SYNC_DURATION = 60*60*2 # 2 Hours
+MAX_SYNC_DURATION = 60 * 60 * 2 # 2 Hours
def get_eligible_mirrors():
- """
+ """
Returns the RepoMirrorConfig that are ready to run now. This includes those that are:
1. Not currently syncing but whose start time is in the past
2. Status of "sync now"
3. Currently marked as syncing but whose expiration time is in the past
"""
- now = datetime.utcnow()
- immediate_candidates_filter = ((RepoMirrorConfig.sync_status == RepoMirrorStatus.SYNC_NOW) &
- (RepoMirrorConfig.sync_expiration_date >> None))
+ now = datetime.utcnow()
+ immediate_candidates_filter = (
+ RepoMirrorConfig.sync_status == RepoMirrorStatus.SYNC_NOW
+ ) & (RepoMirrorConfig.sync_expiration_date >> None)
- ready_candidates_filter = ((RepoMirrorConfig.sync_start_date <= now) &
- (RepoMirrorConfig.sync_retries_remaining > 0) &
- (RepoMirrorConfig.sync_status != RepoMirrorStatus.SYNCING) &
- (RepoMirrorConfig.sync_expiration_date >> None) &
- (RepoMirrorConfig.is_enabled == True))
+ ready_candidates_filter = (
+ (RepoMirrorConfig.sync_start_date <= now)
+ & (RepoMirrorConfig.sync_retries_remaining > 0)
+ & (RepoMirrorConfig.sync_status != RepoMirrorStatus.SYNCING)
+ & (RepoMirrorConfig.sync_expiration_date >> None)
+ & (RepoMirrorConfig.is_enabled == True)
+ )
- expired_candidates_filter = ((RepoMirrorConfig.sync_start_date <= now) &
- (RepoMirrorConfig.sync_retries_remaining > 0) &
- (RepoMirrorConfig.sync_status == RepoMirrorStatus.SYNCING) &
- (RepoMirrorConfig.sync_expiration_date <= now) &
- (RepoMirrorConfig.is_enabled == True))
+ expired_candidates_filter = (
+ (RepoMirrorConfig.sync_start_date <= now)
+ & (RepoMirrorConfig.sync_retries_remaining > 0)
+ & (RepoMirrorConfig.sync_status == RepoMirrorStatus.SYNCING)
+ & (RepoMirrorConfig.sync_expiration_date <= now)
+ & (RepoMirrorConfig.is_enabled == True)
+ )
- return (RepoMirrorConfig
- .select()
- .join(Repository)
- .where(Repository.state == RepositoryState.MIRROR)
- .where(immediate_candidates_filter | ready_candidates_filter | expired_candidates_filter)
- .order_by(RepoMirrorConfig.sync_start_date.asc()))
+ return (
+ RepoMirrorConfig.select()
+ .join(Repository)
+ .where(Repository.state == RepositoryState.MIRROR)
+ .where(
+ immediate_candidates_filter
+ | ready_candidates_filter
+ | expired_candidates_filter
+ )
+ .order_by(RepoMirrorConfig.sync_start_date.asc())
+ )
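+# A sketch of how a polling worker might consume this query; the worker loop
+# itself lives outside this module and `do_sync` below is purely illustrative:
+#
+#     for candidate in get_eligible_mirrors():
+#         mirror = claim_mirror(candidate)
+#         if mirror is None:
+#             continue  # another worker claimed this mirror first
+#         status = do_sync(mirror)  # hypothetical; yields a RepoMirrorStatus
+#         release_mirror(mirror, status)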
def get_max_id_for_repo_mirror_config():
- """ Gets the maximum id for repository mirroring """
- return RepoMirrorConfig.select(fn.Max(RepoMirrorConfig.id)).scalar()
+ """ Gets the maximum id for repository mirroring """
+ return RepoMirrorConfig.select(fn.Max(RepoMirrorConfig.id)).scalar()
def get_min_id_for_repo_mirror_config():
- """ Gets the minimum id for a repository mirroring """
- return RepoMirrorConfig.select(fn.Min(RepoMirrorConfig.id)).scalar()
+ """ Gets the minimum id for a repository mirroring """
+ return RepoMirrorConfig.select(fn.Min(RepoMirrorConfig.id)).scalar()
def claim_mirror(mirror):
- """
+ """
Attempt to create an exclusive lock on the RepoMirrorConfig and return it.
If unable to create the lock, `None` will be returned.
"""
- # Attempt to update the RepoMirrorConfig to mark it as "claimed"
- now = datetime.utcnow()
- expiration_date = now + timedelta(seconds=MAX_SYNC_DURATION)
- query = (RepoMirrorConfig
- .update(sync_status=RepoMirrorStatus.SYNCING,
- sync_expiration_date=expiration_date,
- sync_transaction_id=uuid_generator())
- .where(RepoMirrorConfig.id == mirror.id,
- RepoMirrorConfig.sync_transaction_id == mirror.sync_transaction_id))
+ # Attempt to update the RepoMirrorConfig to mark it as "claimed"
+ now = datetime.utcnow()
+ expiration_date = now + timedelta(seconds=MAX_SYNC_DURATION)
+ query = RepoMirrorConfig.update(
+ sync_status=RepoMirrorStatus.SYNCING,
+ sync_expiration_date=expiration_date,
+ sync_transaction_id=uuid_generator(),
+ ).where(
+ RepoMirrorConfig.id == mirror.id,
+ RepoMirrorConfig.sync_transaction_id == mirror.sync_transaction_id,
+ )
- # If the update was successful, then it was claimed. Return the updated instance.
- if query.execute():
- return RepoMirrorConfig.get_by_id(mirror.id)
+ # If the update was successful, then it was claimed. Return the updated instance.
+ if query.execute():
+ return RepoMirrorConfig.get_by_id(mirror.id)
- return None # Another process must have claimed the mirror faster.
+ return None # Another process must have claimed the mirror faster.
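+# The claim acts as an optimistic lock: the UPDATE above only matches while
+# sync_transaction_id still holds the value read into `mirror`, so when two
+# processes race on the same row at most one receives a non-None result.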
def release_mirror(mirror, sync_status):
- """
+ """
Return a mirror to the queue and update its status.
Upon success, move next sync to be at the next interval in the future. Failures remain with
@@ -91,429 +111,467 @@ def release_mirror(mirror, sync_status):
for example, to retry the next day. Without this, users would need to manually run syncs
to clear failure state.
"""
- if sync_status == RepoMirrorStatus.FAIL:
- retries = max(0, mirror.sync_retries_remaining - 1)
+ if sync_status == RepoMirrorStatus.FAIL:
+ retries = max(0, mirror.sync_retries_remaining - 1)
- if sync_status == RepoMirrorStatus.SUCCESS or retries < 1:
- now = datetime.utcnow()
- delta = now - mirror.sync_start_date
- delta_seconds = (delta.days * 24 * 60 * 60) + delta.seconds
- next_start_date = now + timedelta(seconds=mirror.sync_interval - (delta_seconds % mirror.sync_interval))
- retries = MAX_SYNC_RETRIES
- else:
- next_start_date = mirror.sync_start_date
+ if sync_status == RepoMirrorStatus.SUCCESS or retries < 1:
+ now = datetime.utcnow()
+ delta = now - mirror.sync_start_date
+ delta_seconds = (delta.days * 24 * 60 * 60) + delta.seconds
+ next_start_date = now + timedelta(
+ seconds=mirror.sync_interval - (delta_seconds % mirror.sync_interval)
+ )
+ retries = MAX_SYNC_RETRIES
+ else:
+ next_start_date = mirror.sync_start_date
- query = (RepoMirrorConfig
- .update(sync_transaction_id=uuid_generator(),
- sync_status=sync_status,
- sync_start_date=next_start_date,
- sync_expiration_date=None,
- sync_retries_remaining=retries)
- .where(RepoMirrorConfig.id == mirror.id,
- RepoMirrorConfig.sync_transaction_id == mirror.sync_transaction_id))
+ query = RepoMirrorConfig.update(
+ sync_transaction_id=uuid_generator(),
+ sync_status=sync_status,
+ sync_start_date=next_start_date,
+ sync_expiration_date=None,
+ sync_retries_remaining=retries,
+ ).where(
+ RepoMirrorConfig.id == mirror.id,
+ RepoMirrorConfig.sync_transaction_id == mirror.sync_transaction_id,
+ )
- if query.execute():
- return RepoMirrorConfig.get_by_id(mirror.id)
+ if query.execute():
+ return RepoMirrorConfig.get_by_id(mirror.id)
- # Unable to release Mirror. Has it been claimed by another process?
- return None
+ # Unable to release Mirror. Has it been claimed by another process?
+ return None
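+# Rescheduling example with illustrative numbers: for sync_interval = 3600s
+# and a sync that started 5400s ago, delta_seconds % sync_interval == 1800,
+# so next_start_date lands 1800s from now, back on the original hourly grid
+# rather than drifting by however long the sync took.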
def expire_mirror(mirror):
- """
+ """
Set the mirror to synchronize ASAP and reset its failure count.
"""
- # Set the next-sync date to now
- # TODO: Verify the `where` conditions would not expire a currently syncing mirror.
- query = (RepoMirrorConfig
- .update(sync_transaction_id=uuid_generator(),
- sync_expiration_date=datetime.utcnow(),
- sync_retries_remaining=MAX_SYNC_RETRIES)
- .where(RepoMirrorConfig.sync_transaction_id == mirror.sync_transaction_id,
- RepoMirrorConfig.id == mirror.id,
- RepoMirrorConfig.state != RepoMirrorStatus.SYNCING))
+ # Set the next-sync date to now
+ # TODO: Verify the `where` conditions would not expire a currently syncing mirror.
+ query = RepoMirrorConfig.update(
+ sync_transaction_id=uuid_generator(),
+ sync_expiration_date=datetime.utcnow(),
+ sync_retries_remaining=MAX_SYNC_RETRIES,
+ ).where(
+ RepoMirrorConfig.sync_transaction_id == mirror.sync_transaction_id,
+ RepoMirrorConfig.id == mirror.id,
+ RepoMirrorConfig.state != RepoMirrorStatus.SYNCING,
+ )
- # Fetch and return the latest updates
- if query.execute():
- return RepoMirrorConfig.get_by_id(mirror.id)
+ # Fetch and return the latest updates
+ if query.execute():
+ return RepoMirrorConfig.get_by_id(mirror.id)
- # Unable to update expiration date. Perhaps another process has claimed it?
- return None # TODO: Raise some Exception?
+ # Unable to update expiration date. Perhaps another process has claimed it?
+ return None # TODO: Raise some Exception?
-def create_mirroring_rule(repository, rule_value, rule_type=RepoMirrorRuleType.TAG_GLOB_CSV):
- """
+def create_mirroring_rule(
+ repository, rule_value, rule_type=RepoMirrorRuleType.TAG_GLOB_CSV
+):
+ """
Create a RepoMirrorRule for a given Repository.
"""
- if rule_type != RepoMirrorRuleType.TAG_GLOB_CSV:
- raise ValidationError('validation failed: rule_type must be TAG_GLOB_CSV')
+ if rule_type != RepoMirrorRuleType.TAG_GLOB_CSV:
+ raise ValidationError("validation failed: rule_type must be TAG_GLOB_CSV")
- if not isinstance(rule_value, list) or len(rule_value) < 1:
- raise ValidationError('validation failed: rule_value for TAG_GLOB_CSV must be a list with at least one rule')
+ if not isinstance(rule_value, list) or len(rule_value) < 1:
+ raise ValidationError(
+ "validation failed: rule_value for TAG_GLOB_CSV must be a list with at least one rule"
+ )
- rule = RepoMirrorRule.create(repository=repository, rule_type=rule_type, rule_value=rule_value)
- return rule
+ rule = RepoMirrorRule.create(
+ repository=repository, rule_type=rule_type, rule_value=rule_value
+ )
+ return rule
-def enable_mirroring_for_repository(repository,
- root_rule,
- internal_robot,
- external_reference,
- sync_interval,
- external_registry_username=None,
- external_registry_password=None,
- external_registry_config=None,
- is_enabled=True,
- sync_start_date=None):
- """
+def enable_mirroring_for_repository(
+ repository,
+ root_rule,
+ internal_robot,
+ external_reference,
+ sync_interval,
+ external_registry_username=None,
+ external_registry_password=None,
+ external_registry_config=None,
+ is_enabled=True,
+ sync_start_date=None,
+):
+ """
Create a RepoMirrorConfig and set the Repository to the MIRROR state.
"""
- assert internal_robot.robot
+ assert internal_robot.robot
- namespace, _ = parse_robot_username(internal_robot.username)
- if namespace != repository.namespace_user.username:
- raise DataModelException('Cannot use robot for mirroring')
+ namespace, _ = parse_robot_username(internal_robot.username)
+ if namespace != repository.namespace_user.username:
+ raise DataModelException("Cannot use robot for mirroring")
- with db_transaction():
- # Create the RepoMirrorConfig
- try:
- username = DecryptedValue(external_registry_username) if external_registry_username else None
- password = DecryptedValue(external_registry_password) if external_registry_password else None
- mirror = RepoMirrorConfig.create(repository=repository,
- root_rule=root_rule,
- is_enabled=is_enabled,
- internal_robot=internal_robot,
- external_reference=external_reference,
- external_registry_username=username,
- external_registry_password=password,
- external_registry_config=external_registry_config or {},
- sync_interval=sync_interval,
- sync_start_date=sync_start_date or datetime.utcnow())
- except IntegrityError:
- return RepoMirrorConfig.get(repository=repository)
+ with db_transaction():
+ # Create the RepoMirrorConfig
+ try:
+ username = (
+ DecryptedValue(external_registry_username)
+ if external_registry_username
+ else None
+ )
+ password = (
+ DecryptedValue(external_registry_password)
+ if external_registry_password
+ else None
+ )
+ mirror = RepoMirrorConfig.create(
+ repository=repository,
+ root_rule=root_rule,
+ is_enabled=is_enabled,
+ internal_robot=internal_robot,
+ external_reference=external_reference,
+ external_registry_username=username,
+ external_registry_password=password,
+ external_registry_config=external_registry_config or {},
+ sync_interval=sync_interval,
+ sync_start_date=sync_start_date or datetime.utcnow(),
+ )
+ except IntegrityError:
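+            # A mirror config already exists for this repository; return the existing row.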
+ return RepoMirrorConfig.get(repository=repository)
- # Change Repository state to mirroring mode as needed
- if repository.state != RepositoryState.MIRROR:
- query = (Repository
- .update(state=RepositoryState.MIRROR)
- .where(Repository.id == repository.id))
- if not query.execute():
- raise DataModelException('Could not change the state of the repository')
+ # Change Repository state to mirroring mode as needed
+ if repository.state != RepositoryState.MIRROR:
+ query = Repository.update(state=RepositoryState.MIRROR).where(
+ Repository.id == repository.id
+ )
+ if not query.execute():
+ raise DataModelException("Could not change the state of the repository")
- return mirror
+ return mirror
def update_sync_status(mirror, sync_status):
- """
+ """
Update the sync status
"""
- query = (RepoMirrorConfig
- .update(sync_transaction_id=uuid_generator(),
- sync_status=sync_status)
- .where(RepoMirrorConfig.sync_transaction_id == mirror.sync_transaction_id,
- RepoMirrorConfig.id == mirror.id))
- if query.execute():
- return RepoMirrorConfig.get_by_id(mirror.id)
+ query = RepoMirrorConfig.update(
+ sync_transaction_id=uuid_generator(), sync_status=sync_status
+ ).where(
+ RepoMirrorConfig.sync_transaction_id == mirror.sync_transaction_id,
+ RepoMirrorConfig.id == mirror.id,
+ )
+ if query.execute():
+ return RepoMirrorConfig.get_by_id(mirror.id)
- return None
+ return None
def update_sync_status_to_sync_now(mirror):
- """
+ """
This will change the sync status to SYNC_NOW and set the retries remaining to one, if it is
less than one. None will be returned in cases where this is not possible, such as if the
mirror is in the SYNCING state.
"""
- if mirror.sync_status == RepoMirrorStatus.SYNCING:
+ if mirror.sync_status == RepoMirrorStatus.SYNCING:
+ return None
+
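+    # Guarantee at least one retry remains so the requested sync can actually run.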
+ retries = max(mirror.sync_retries_remaining, 1)
+
+ query = RepoMirrorConfig.update(
+ sync_transaction_id=uuid_generator(),
+ sync_status=RepoMirrorStatus.SYNC_NOW,
+ sync_expiration_date=None,
+ sync_retries_remaining=retries,
+ ).where(
+ RepoMirrorConfig.id == mirror.id,
+ RepoMirrorConfig.sync_transaction_id == mirror.sync_transaction_id,
+ )
+
+ if query.execute():
+ return RepoMirrorConfig.get_by_id(mirror.id)
+
return None
- retries = max(mirror.sync_retries_remaining, 1)
-
- query = (RepoMirrorConfig
- .update(sync_transaction_id=uuid_generator(),
- sync_status=RepoMirrorStatus.SYNC_NOW,
- sync_expiration_date=None,
- sync_retries_remaining=retries)
- .where(RepoMirrorConfig.id == mirror.id,
- RepoMirrorConfig.sync_transaction_id == mirror.sync_transaction_id))
-
- if query.execute():
- return RepoMirrorConfig.get_by_id(mirror.id)
-
- return None
-
def update_sync_status_to_cancel(mirror):
- """
+ """
    If the mirror is SYNCING or SYNC_NOW, it will be force-claimed (ignoring the existing transaction
    id) and its state will be set to NEVER_RUN. None will be returned in cases where this is not
    possible, such as if the mirror is in neither the SYNCING nor the SYNC_NOW state.
"""
- if mirror.sync_status != RepoMirrorStatus.SYNCING and mirror.sync_status != RepoMirrorStatus.SYNC_NOW:
+ if (
+ mirror.sync_status != RepoMirrorStatus.SYNCING
+ and mirror.sync_status != RepoMirrorStatus.SYNC_NOW
+ ):
+ return None
+
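+    # Force-claim the mirror: the WHERE clause deliberately omits sync_transaction_id so even a
+    # mirror stuck in the SYNCING state can be reset.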
+ query = RepoMirrorConfig.update(
+ sync_transaction_id=uuid_generator(),
+ sync_status=RepoMirrorStatus.NEVER_RUN,
+ sync_expiration_date=None,
+ ).where(RepoMirrorConfig.id == mirror.id)
+
+ if query.execute():
+ return RepoMirrorConfig.get_by_id(mirror.id)
+
return None
- query = (RepoMirrorConfig
- .update(sync_transaction_id=uuid_generator(),
- sync_status=RepoMirrorStatus.NEVER_RUN,
- sync_expiration_date=None)
- .where(RepoMirrorConfig.id == mirror.id))
-
- if query.execute():
- return RepoMirrorConfig.get_by_id(mirror.id)
-
- return None
-
def update_with_transaction(mirror, **kwargs):
- """
+ """
Helper function which updates a Repository's RepoMirrorConfig while also rolling its
sync_transaction_id for locking purposes.
"""
- # RepoMirrorConfig attributes which can be modified
- mutable_attributes = (
- 'is_enabled',
- 'mirror_type',
- 'external_reference',
- 'external_registry_username',
- 'external_registry_password',
- 'external_registry_config',
- 'sync_interval',
- 'sync_start_date',
- 'sync_expiration_date',
- 'sync_retries_remaining',
- 'sync_status',
- 'sync_transaction_id'
- )
+ # RepoMirrorConfig attributes which can be modified
+ mutable_attributes = (
+ "is_enabled",
+ "mirror_type",
+ "external_reference",
+ "external_registry_username",
+ "external_registry_password",
+ "external_registry_config",
+ "sync_interval",
+ "sync_start_date",
+ "sync_expiration_date",
+ "sync_retries_remaining",
+ "sync_status",
+ "sync_transaction_id",
+ )
- # Key-Value map of changes to make
- filtered_kwargs = {key:kwargs.pop(key) for key in mutable_attributes if key in kwargs}
+ # Key-Value map of changes to make
+ filtered_kwargs = {
+ key: kwargs.pop(key) for key in mutable_attributes if key in kwargs
+ }
- # Roll the sync_transaction_id to a new value
- filtered_kwargs['sync_transaction_id'] = uuid_generator()
+ # Roll the sync_transaction_id to a new value
+ filtered_kwargs["sync_transaction_id"] = uuid_generator()
- # Generate the query to perform the updates
- query = (RepoMirrorConfig
- .update(filtered_kwargs)
- .where(RepoMirrorConfig.sync_transaction_id == mirror.sync_transaction_id,
- RepoMirrorConfig.id == mirror.id))
+ # Generate the query to perform the updates
+ query = RepoMirrorConfig.update(filtered_kwargs).where(
+ RepoMirrorConfig.sync_transaction_id == mirror.sync_transaction_id,
+ RepoMirrorConfig.id == mirror.id,
+ )
- # Apply the change(s) and return the object if successful
- if query.execute():
- return RepoMirrorConfig.get_by_id(mirror.id)
- else:
- return None
+ # Apply the change(s) and return the object if successful
+ if query.execute():
+ return RepoMirrorConfig.get_by_id(mirror.id)
+ else:
+ return None
def get_mirror(repository):
- """
+ """
Return the RepoMirrorConfig associated with the given Repository, or None if it doesn't exist.
"""
- try:
- return RepoMirrorConfig.get(repository=repository)
- except RepoMirrorConfig.DoesNotExist:
- return None
+ try:
+ return RepoMirrorConfig.get(repository=repository)
+ except RepoMirrorConfig.DoesNotExist:
+ return None
def enable_mirror(repository):
- """
+ """
Enables a RepoMirrorConfig.
"""
- mirror = get_mirror(repository)
- return bool(update_with_transaction(mirror, is_enabled=True))
+ mirror = get_mirror(repository)
+ return bool(update_with_transaction(mirror, is_enabled=True))
def disable_mirror(repository):
- """
+ """
Disables a RepoMirrorConfig.
"""
- mirror = get_mirror(repository)
- return bool(update_with_transaction(mirror, is_enabled=False))
+ mirror = get_mirror(repository)
+ return bool(update_with_transaction(mirror, is_enabled=False))
def delete_mirror(repository):
- """
+ """
Delete a Repository Mirroring configuration.
"""
- raise NotImplementedError("TODO: Not Implemented")
+ raise NotImplementedError("TODO: Not Implemented")
def change_remote(repository, remote_repository):
- """
+ """
Update the external repository for Repository Mirroring.
"""
- mirror = get_mirror(repository)
- updates = {
- 'external_reference': remote_repository
- }
- return bool(update_with_transaction(mirror, **updates))
+ mirror = get_mirror(repository)
+ updates = {"external_reference": remote_repository}
+ return bool(update_with_transaction(mirror, **updates))
def change_credentials(repository, username, password):
- """
+ """
Update the credentials used to access the remote repository.
"""
- mirror = get_mirror(repository)
- updates = {
- 'external_registry_username': username,
- 'external_registry_password': password,
- }
- return bool(update_with_transaction(mirror, **updates))
+ mirror = get_mirror(repository)
+ updates = {
+ "external_registry_username": username,
+ "external_registry_password": password,
+ }
+ return bool(update_with_transaction(mirror, **updates))
def change_username(repository, username):
- """
+ """
Update the Username used to access the external repository.
"""
- mirror = get_mirror(repository)
- return bool(update_with_transaction(mirror, external_registry_username=username))
+ mirror = get_mirror(repository)
+ return bool(update_with_transaction(mirror, external_registry_username=username))
def change_sync_interval(repository, interval):
- """
+ """
Update the interval at which a repository will be synchronized.
"""
- mirror = get_mirror(repository)
- return bool(update_with_transaction(mirror, sync_interval=interval))
+ mirror = get_mirror(repository)
+ return bool(update_with_transaction(mirror, sync_interval=interval))
def change_sync_start_date(repository, dt):
- """
+ """
Specify when the repository should be synchronized next.
"""
- mirror = get_mirror(repository)
- return bool(update_with_transaction(mirror, sync_start_date=dt))
+ mirror = get_mirror(repository)
+ return bool(update_with_transaction(mirror, sync_start_date=dt))
def change_root_rule(repository, rule):
- """
+ """
Specify which rule should be used for repository mirroring.
"""
- assert rule.repository == repository
- mirror = get_mirror(repository)
- return bool(update_with_transaction(mirror, root_rule=rule))
+ assert rule.repository == repository
+ mirror = get_mirror(repository)
+ return bool(update_with_transaction(mirror, root_rule=rule))
def change_sync_status(repository, sync_status):
- """
+ """
Change Repository's mirroring status.
"""
- mirror = get_mirror(repository)
- return update_with_transaction(mirror, sync_status=sync_status)
+ mirror = get_mirror(repository)
+ return update_with_transaction(mirror, sync_status=sync_status)
def change_retries_remaining(repository, retries_remaining):
- """
+ """
Change the number of retries remaining for mirroring a repository.
"""
- mirror = get_mirror(repository)
- return update_with_transaction(mirror, sync_retries_remaining=retries_remaining)
+ mirror = get_mirror(repository)
+ return update_with_transaction(mirror, sync_retries_remaining=retries_remaining)
def change_external_registry_config(repository, config_updates):
- """
+ """
Update the 'external_registry_config' with the passed in fields. Config has:
verify_tls: True|False
      proxy: JSON fields 'http_proxy', 'https_proxy', and 'no_proxy'
"""
- mirror = get_mirror(repository)
- external_registry_config = mirror.external_registry_config
+ mirror = get_mirror(repository)
+ external_registry_config = mirror.external_registry_config
- if 'verify_tls' in config_updates:
- external_registry_config['verify_tls'] = config_updates['verify_tls']
+ if "verify_tls" in config_updates:
+ external_registry_config["verify_tls"] = config_updates["verify_tls"]
- if 'proxy' in config_updates:
- proxy_updates = config_updates['proxy']
- for key in ('http_proxy', 'https_proxy', 'no_proxy'):
- if key in config_updates['proxy']:
- if 'proxy' not in external_registry_config:
- external_registry_config['proxy'] = {}
- else:
- external_registry_config['proxy'][key] = proxy_updates[key]
+    if "proxy" in config_updates:
+        proxy_updates = config_updates["proxy"]
+        for key in ("http_proxy", "https_proxy", "no_proxy"):
+            if key in proxy_updates:
+                # Make sure the nested "proxy" dict exists before applying the update.
+                if "proxy" not in external_registry_config:
+                    external_registry_config["proxy"] = {}
+                external_registry_config["proxy"][key] = proxy_updates[key]
- return update_with_transaction(mirror, external_registry_config=external_registry_config)
+ return update_with_transaction(
+ mirror, external_registry_config=external_registry_config
+ )
def get_mirroring_robot(repository):
- """
+ """
Return the robot used for mirroring. Returns None if the repository does not have an associated
RepoMirrorConfig or the robot does not exist.
"""
- mirror = get_mirror(repository)
- if mirror:
- return mirror.internal_robot
+ mirror = get_mirror(repository)
+ if mirror:
+ return mirror.internal_robot
- return None
+ return None
def set_mirroring_robot(repository, robot):
- """
+ """
Sets the mirroring robot for the repository.
"""
- assert robot.robot
- namespace, _ = parse_robot_username(robot.username)
- if namespace != repository.namespace_user.username:
- raise DataModelException('Cannot use robot for mirroring')
+ assert robot.robot
+ namespace, _ = parse_robot_username(robot.username)
+ if namespace != repository.namespace_user.username:
+ raise DataModelException("Cannot use robot for mirroring")
- mirror = get_mirror(repository)
- mirror.internal_robot = robot
- mirror.save()
+ mirror = get_mirror(repository)
+ mirror.internal_robot = robot
+ mirror.save()
# -------------------- Mirroring Rules --------------------------#
-def create_rule(repository, rule_value, rule_type=RepoMirrorRuleType.TAG_GLOB_CSV, left_child=None, right_child=None):
- """
+def create_rule(
+ repository,
+ rule_value,
+ rule_type=RepoMirrorRuleType.TAG_GLOB_CSV,
+ left_child=None,
+ right_child=None,
+):
+ """
Create a new Rule for mirroring a Repository
"""
- if rule_type != RepoMirrorRuleType.TAG_GLOB_CSV:
- raise ValidationError('validation failed: rule_type must be TAG_GLOB_CSV')
+ if rule_type != RepoMirrorRuleType.TAG_GLOB_CSV:
+ raise ValidationError("validation failed: rule_type must be TAG_GLOB_CSV")
- if not isinstance(rule_value, list) or len(rule_value) < 1:
- raise ValidationError('validation failed: rule_value for TAG_GLOB_CSV must be a list with at least one rule')
+ if not isinstance(rule_value, list) or len(rule_value) < 1:
+ raise ValidationError(
+ "validation failed: rule_value for TAG_GLOB_CSV must be a list with at least one rule"
+ )
- rule_kwargs = {
- 'repository': repository,
- 'rule_value': rule_value,
- 'rule_type': rule_type,
- 'left_child': left_child,
- 'right_child': right_child,
- }
- rule = RepoMirrorRule.create(**rule_kwargs)
- return rule
+ rule_kwargs = {
+ "repository": repository,
+ "rule_value": rule_value,
+ "rule_type": rule_type,
+ "left_child": left_child,
+ "right_child": right_child,
+ }
+ rule = RepoMirrorRule.create(**rule_kwargs)
+ return rule
def list_rules(repository):
- """
+ """
Returns all RepoMirrorRules associated with a Repository.
"""
- rules = RepoMirrorRule.select().where(RepoMirrorRule.repository == repository).all()
- return rules
+ rules = RepoMirrorRule.select().where(RepoMirrorRule.repository == repository).all()
+ return rules
def get_root_rule(repository):
- """
+ """
Return the primary mirroring Rule
"""
- mirror = get_mirror(repository)
- try:
- rule = RepoMirrorRule.get(repository=repository)
- return rule
- except RepoMirrorRule.DoesNotExist:
- return None
+ mirror = get_mirror(repository)
+ try:
+ rule = RepoMirrorRule.get(repository=repository)
+ return rule
+ except RepoMirrorRule.DoesNotExist:
+ return None
def change_rule_value(rule, value):
- """
+ """
Update the value of an existing rule.
"""
- query = (RepoMirrorRule
- .update(rule_value=value)
- .where(RepoMirrorRule.id == rule.id))
- return query.execute()
+ query = RepoMirrorRule.update(rule_value=value).where(RepoMirrorRule.id == rule.id)
+ return query.execute()
diff --git a/data/model/repository.py b/data/model/repository.py
index 3400bfde8..fc4bee236 100644
--- a/data/model/repository.py
+++ b/data/model/repository.py
@@ -7,13 +7,40 @@ from peewee import Case, JOIN, fn, SQL, IntegrityError
from cachetools.func import ttl_cache
from data.model import (
- config, DataModelException, tag, db_transaction, storage, permission, _basequery)
+ config,
+ DataModelException,
+ tag,
+ db_transaction,
+ storage,
+ permission,
+ _basequery,
+)
from data.database import (
- Repository, Namespace, RepositoryTag, Star, Image, ImageStorage, User, Visibility,
- RepositoryPermission, RepositoryActionCount, Role, RepositoryAuthorizedEmail,
- DerivedStorageForImage, Label, db_for_update, get_epoch_timestamp,
- db_random_func, db_concat_func, RepositorySearchScore, RepositoryKind, ApprTag,
- ManifestLegacyImage, Manifest, ManifestChild)
+ Repository,
+ Namespace,
+ RepositoryTag,
+ Star,
+ Image,
+ ImageStorage,
+ User,
+ Visibility,
+ RepositoryPermission,
+ RepositoryActionCount,
+ Role,
+ RepositoryAuthorizedEmail,
+ DerivedStorageForImage,
+ Label,
+ db_for_update,
+ get_epoch_timestamp,
+ db_random_func,
+ db_concat_func,
+ RepositorySearchScore,
+ RepositoryKind,
+ ApprTag,
+ ManifestLegacyImage,
+ Manifest,
+ ManifestChild,
+)
from data.text import prefix_search
from util.itertoolrecipes import take
@@ -22,436 +49,572 @@ SEARCH_FIELDS = Enum("SearchFields", ["name", "description"])
class RepoStateConfigException(Exception):
- """ Repository.state value requires further configuration to operate. """
- pass
+ """ Repository.state value requires further configuration to operate. """
+
+ pass
def get_repo_kind_name(repo):
- return Repository.kind.get_name(repo.kind_id)
+ return Repository.kind.get_name(repo.kind_id)
def get_repository_count():
- return Repository.select().count()
+ return Repository.select().count()
def get_public_repo_visibility():
- return _basequery.get_public_repo_visibility()
+ return _basequery.get_public_repo_visibility()
-def create_repository(namespace, name, creating_user, visibility='private', repo_kind='image',
- description=None):
- namespace_user = User.get(username=namespace)
- yesterday = datetime.now() - timedelta(days=1)
+def create_repository(
+ namespace,
+ name,
+ creating_user,
+ visibility="private",
+ repo_kind="image",
+ description=None,
+):
+ namespace_user = User.get(username=namespace)
+ yesterday = datetime.now() - timedelta(days=1)
- with db_transaction():
- repo = Repository.create(name=name, visibility=Repository.visibility.get_id(visibility),
- namespace_user=namespace_user,
- kind=Repository.kind.get_id(repo_kind),
- description=description)
+ with db_transaction():
+ repo = Repository.create(
+ name=name,
+ visibility=Repository.visibility.get_id(visibility),
+ namespace_user=namespace_user,
+ kind=Repository.kind.get_id(repo_kind),
+ description=description,
+ )
- RepositoryActionCount.create(repository=repo, count=0, date=yesterday)
- RepositorySearchScore.create(repository=repo, score=0)
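+        # Seed yesterday's action count and an initial search score row for the new repository.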
+ RepositoryActionCount.create(repository=repo, count=0, date=yesterday)
+ RepositorySearchScore.create(repository=repo, score=0)
- # Note: We put the admin create permission under the transaction to ensure it is created.
- if creating_user and not creating_user.organization:
- admin = Role.get(name='admin')
- RepositoryPermission.create(user=creating_user, repository=repo, role=admin)
+ # Note: We put the admin create permission under the transaction to ensure it is created.
+ if creating_user and not creating_user.organization:
+ admin = Role.get(name="admin")
+ RepositoryPermission.create(user=creating_user, repository=repo, role=admin)
- # Apply default permissions (only occurs for repositories under organizations)
- if creating_user and not creating_user.organization and creating_user.username != namespace:
- permission.apply_default_permissions(repo, creating_user)
+ # Apply default permissions (only occurs for repositories under organizations)
+ if (
+ creating_user
+ and not creating_user.organization
+ and creating_user.username != namespace
+ ):
+ permission.apply_default_permissions(repo, creating_user)
- return repo
+ return repo
def get_repository(namespace_name, repository_name, kind_filter=None):
- try:
- return _basequery.get_existing_repository(namespace_name, repository_name,
- kind_filter=kind_filter)
- except Repository.DoesNotExist:
- return None
+ try:
+ return _basequery.get_existing_repository(
+ namespace_name, repository_name, kind_filter=kind_filter
+ )
+ except Repository.DoesNotExist:
+ return None
-def get_or_create_repository(namespace, name, creating_user, visibility='private',
- repo_kind='image'):
- repo = get_repository(namespace, name, repo_kind)
- if repo is None:
- repo = create_repository(namespace, name, creating_user, visibility, repo_kind)
- return repo
+def get_or_create_repository(
+ namespace, name, creating_user, visibility="private", repo_kind="image"
+):
+ repo = get_repository(namespace, name, repo_kind)
+ if repo is None:
+ repo = create_repository(namespace, name, creating_user, visibility, repo_kind)
+ return repo
@ttl_cache(maxsize=1, ttl=600)
def _get_gc_expiration_policies():
- policy_tuples_query = (
- Namespace.select(Namespace.removed_tag_expiration_s).distinct()
- .limit(100) # This sucks but it's the only way to limit memory
- .tuples())
- return [policy[0] for policy in policy_tuples_query]
+ policy_tuples_query = (
+ Namespace.select(Namespace.removed_tag_expiration_s)
+ .distinct()
+ .limit(100) # This sucks but it's the only way to limit memory
+ .tuples()
+ )
+ return [policy[0] for policy in policy_tuples_query]
def get_random_gc_policy():
- """ Return a single random policy from the database to use when garbage collecting.
+ """ Return a single random policy from the database to use when garbage collecting.
"""
- return random.choice(_get_gc_expiration_policies())
+ return random.choice(_get_gc_expiration_policies())
def find_repository_with_garbage(limit_to_gc_policy_s):
- expiration_timestamp = get_epoch_timestamp() - limit_to_gc_policy_s
+ expiration_timestamp = get_epoch_timestamp() - limit_to_gc_policy_s
- try:
- candidates = (RepositoryTag.select(RepositoryTag.repository).join(Repository)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .where(~(RepositoryTag.lifetime_end_ts >> None),
- (RepositoryTag.lifetime_end_ts <= expiration_timestamp),
- (Namespace.removed_tag_expiration_s == limit_to_gc_policy_s)).limit(500)
- .distinct().alias('candidates'))
+ try:
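+        # Gather up to 500 repositories with tags expired beyond this policy's window,
+        # then pick one at random so GC work is spread across candidate repositories.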
+ candidates = (
+ RepositoryTag.select(RepositoryTag.repository)
+ .join(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .where(
+ ~(RepositoryTag.lifetime_end_ts >> None),
+ (RepositoryTag.lifetime_end_ts <= expiration_timestamp),
+ (Namespace.removed_tag_expiration_s == limit_to_gc_policy_s),
+ )
+ .limit(500)
+ .distinct()
+ .alias("candidates")
+ )
- found = (RepositoryTag.select(candidates.c.repository_id).from_(candidates)
- .order_by(db_random_func()).get())
+ found = (
+ RepositoryTag.select(candidates.c.repository_id)
+ .from_(candidates)
+ .order_by(db_random_func())
+ .get()
+ )
- if found is None:
- return
+ if found is None:
+ return
- return Repository.get(Repository.id == found.repository_id)
- except RepositoryTag.DoesNotExist:
- return None
- except Repository.DoesNotExist:
- return None
+ return Repository.get(Repository.id == found.repository_id)
+ except RepositoryTag.DoesNotExist:
+ return None
+ except Repository.DoesNotExist:
+ return None
def star_repository(user, repository):
- """ Stars a repository. """
- star = Star.create(user=user.id, repository=repository.id)
- star.save()
+ """ Stars a repository. """
+ star = Star.create(user=user.id, repository=repository.id)
+ star.save()
def unstar_repository(user, repository):
- """ Unstars a repository. """
- try:
- (Star.delete().where(Star.repository == repository.id, Star.user == user.id).execute())
- except Star.DoesNotExist:
- raise DataModelException('Star not found.')
+ """ Unstars a repository. """
+ try:
+ (
+ Star.delete()
+ .where(Star.repository == repository.id, Star.user == user.id)
+ .execute()
+ )
+ except Star.DoesNotExist:
+ raise DataModelException("Star not found.")
def set_trust(repo, trust_enabled):
- repo.trust_enabled = trust_enabled
- repo.save()
+ repo.trust_enabled = trust_enabled
+ repo.save()
def set_description(repo, description):
- repo.description = description
- repo.save()
+ repo.description = description
+ repo.save()
-def get_user_starred_repositories(user, kind_filter='image'):
- """ Retrieves all of the repositories a user has starred. """
- try:
- repo_kind = Repository.kind.get_id(kind_filter)
- except RepositoryKind.DoesNotExist:
- raise DataModelException('Unknown kind of repository')
+def get_user_starred_repositories(user, kind_filter="image"):
+ """ Retrieves all of the repositories a user has starred. """
+ try:
+ repo_kind = Repository.kind.get_id(kind_filter)
+ except RepositoryKind.DoesNotExist:
+ raise DataModelException("Unknown kind of repository")
- query = (Repository.select(Repository, User, Visibility, Repository.id.alias('rid')).join(Star)
- .switch(Repository).join(User).switch(Repository).join(Visibility)
- .where(Star.user == user, Repository.kind == repo_kind))
+ query = (
+ Repository.select(Repository, User, Visibility, Repository.id.alias("rid"))
+ .join(Star)
+ .switch(Repository)
+ .join(User)
+ .switch(Repository)
+ .join(Visibility)
+ .where(Star.user == user, Repository.kind == repo_kind)
+ )
- return query
+ return query
def repository_is_starred(user, repository):
- """ Determines whether a user has starred a repository or not. """
- try:
- (Star.select().where(Star.repository == repository.id, Star.user == user.id).get())
- return True
- except Star.DoesNotExist:
- return False
+ """ Determines whether a user has starred a repository or not. """
+ try:
+ (
+ Star.select()
+ .where(Star.repository == repository.id, Star.user == user.id)
+ .get()
+ )
+ return True
+ except Star.DoesNotExist:
+ return False
def get_stars(repository_ids):
- """ Returns a map from repository ID to the number of stars for each repository in the
+ """ Returns a map from repository ID to the number of stars for each repository in the
given repository IDs list.
"""
- if not repository_ids:
- return {}
+ if not repository_ids:
+ return {}
- tuples = (Star.select(Star.repository, fn.Count(Star.id))
- .where(Star.repository << repository_ids).group_by(Star.repository).tuples())
+ tuples = (
+ Star.select(Star.repository, fn.Count(Star.id))
+ .where(Star.repository << repository_ids)
+ .group_by(Star.repository)
+ .tuples()
+ )
- star_map = {}
- for record in tuples:
- star_map[record[0]] = record[1]
+ star_map = {}
+ for record in tuples:
+ star_map[record[0]] = record[1]
- return star_map
+ return star_map
-def get_visible_repositories(username, namespace=None, kind_filter='image', include_public=False,
- start_id=None, limit=None):
- """ Returns the repositories visible to the given user (if any).
+def get_visible_repositories(
+ username,
+ namespace=None,
+ kind_filter="image",
+ include_public=False,
+ start_id=None,
+ limit=None,
+):
+ """ Returns the repositories visible to the given user (if any).
"""
- if not include_public and not username:
- # Short circuit by returning a query that will find no repositories. We need to return a query
- # here, as it will be modified by other queries later on.
- return Repository.select(Repository.id.alias('rid')).where(Repository.id == -1)
+ if not include_public and not username:
+ # Short circuit by returning a query that will find no repositories. We need to return a query
+ # here, as it will be modified by other queries later on.
+ return Repository.select(Repository.id.alias("rid")).where(Repository.id == -1)
- query = (Repository.select(Repository.name,
- Repository.id.alias('rid'), Repository.description,
- Namespace.username, Repository.visibility, Repository.kind)
- .switch(Repository).join(Namespace, on=(Repository.namespace_user == Namespace.id)))
+ query = (
+ Repository.select(
+ Repository.name,
+ Repository.id.alias("rid"),
+ Repository.description,
+ Namespace.username,
+ Repository.visibility,
+ Repository.kind,
+ )
+ .switch(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ )
- user_id = None
- if username:
- # Note: We only need the permissions table if we will filter based on a user's permissions.
- query = query.switch(Repository).distinct().join(RepositoryPermission, JOIN.LEFT_OUTER)
- found_namespace = _get_namespace_user(username)
- if not found_namespace:
- return Repository.select(Repository.id.alias('rid')).where(Repository.id == -1)
+ user_id = None
+ if username:
+ # Note: We only need the permissions table if we will filter based on a user's permissions.
+ query = (
+ query.switch(Repository)
+ .distinct()
+ .join(RepositoryPermission, JOIN.LEFT_OUTER)
+ )
+ found_namespace = _get_namespace_user(username)
+ if not found_namespace:
+ return Repository.select(Repository.id.alias("rid")).where(
+ Repository.id == -1
+ )
- user_id = found_namespace.id
+ user_id = found_namespace.id
- query = _basequery.filter_to_repos_for_user(query, user_id, namespace, kind_filter,
- include_public, start_id=start_id)
+ query = _basequery.filter_to_repos_for_user(
+ query, user_id, namespace, kind_filter, include_public, start_id=start_id
+ )
- if limit is not None:
- query = query.limit(limit).order_by(SQL('rid'))
+ if limit is not None:
+ query = query.limit(limit).order_by(SQL("rid"))
- return query
+ return query
def get_app_repository(namespace_name, repository_name):
- """ Find an application repository. """
- try:
- return _basequery.get_existing_repository(namespace_name, repository_name,
- kind_filter='application')
- except Repository.DoesNotExist:
- return None
+ """ Find an application repository. """
+ try:
+ return _basequery.get_existing_repository(
+ namespace_name, repository_name, kind_filter="application"
+ )
+ except Repository.DoesNotExist:
+ return None
def get_app_search(lookup, search_fields=None, username=None, limit=50):
- if search_fields is None:
- search_fields = set([SEARCH_FIELDS.name.name])
+ if search_fields is None:
+ search_fields = set([SEARCH_FIELDS.name.name])
- return get_filtered_matching_repositories(lookup, filter_username=username,
- search_fields=search_fields, repo_kind='application',
- offset=0, limit=limit)
+ return get_filtered_matching_repositories(
+ lookup,
+ filter_username=username,
+ search_fields=search_fields,
+ repo_kind="application",
+ offset=0,
+ limit=limit,
+ )
def _get_namespace_user(username):
- try:
- return User.get(username=username)
- except User.DoesNotExist:
- return None
+ try:
+ return User.get(username=username)
+ except User.DoesNotExist:
+ return None
-def get_filtered_matching_repositories(lookup_value, filter_username=None, repo_kind='image',
- offset=0, limit=25, search_fields=None):
- """ Returns an iterator of all repositories matching the given lookup value, with optional
+def get_filtered_matching_repositories(
+ lookup_value,
+ filter_username=None,
+ repo_kind="image",
+ offset=0,
+ limit=25,
+ search_fields=None,
+):
+ """ Returns an iterator of all repositories matching the given lookup value, with optional
filtering to a specific user. If the user is unspecified, only public repositories will
be returned.
"""
- if search_fields is None:
- search_fields = set([SEARCH_FIELDS.description.name, SEARCH_FIELDS.name.name])
+ if search_fields is None:
+ search_fields = set([SEARCH_FIELDS.description.name, SEARCH_FIELDS.name.name])
- # Build the unfiltered search query.
- unfiltered_query = _get_sorted_matching_repositories(lookup_value, repo_kind=repo_kind,
- search_fields=search_fields,
- include_private=filter_username is not None,
- ids_only=filter_username is not None)
+ # Build the unfiltered search query.
+ unfiltered_query = _get_sorted_matching_repositories(
+ lookup_value,
+ repo_kind=repo_kind,
+ search_fields=search_fields,
+ include_private=filter_username is not None,
+ ids_only=filter_username is not None,
+ )
- # Add a filter to the iterator, if necessary.
- if filter_username is not None:
- filter_user = _get_namespace_user(filter_username)
- if filter_user is None:
- return []
+ # Add a filter to the iterator, if necessary.
+ if filter_username is not None:
+ filter_user = _get_namespace_user(filter_username)
+ if filter_user is None:
+ return []
- iterator = _filter_repositories_visible_to_user(unfiltered_query, filter_user.id, limit,
- repo_kind)
- if offset > 0:
- take(offset, iterator)
+ iterator = _filter_repositories_visible_to_user(
+ unfiltered_query, filter_user.id, limit, repo_kind
+ )
+ if offset > 0:
+ take(offset, iterator)
- # Return the results.
- return list(take(limit, iterator))
+ # Return the results.
+ return list(take(limit, iterator))
- return list(unfiltered_query.offset(offset).limit(limit))
+ return list(unfiltered_query.offset(offset).limit(limit))
-def _filter_repositories_visible_to_user(unfiltered_query, filter_user_id, limit, repo_kind):
- encountered = set()
- chunk_count = limit * 2
- unfiltered_page = 0
- iteration_count = 0
+def _filter_repositories_visible_to_user(
+ unfiltered_query, filter_user_id, limit, repo_kind
+):
+ encountered = set()
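+    # Fetch twice the requested limit per page, since permission filtering will drop some results.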
+ chunk_count = limit * 2
+ unfiltered_page = 0
+ iteration_count = 0
- while iteration_count < 10: # Just to be safe
- # Find the next chunk's worth of repository IDs, paginated by the chunk size.
- unfiltered_page = unfiltered_page + 1
- found_ids = [r.id for r in unfiltered_query.paginate(unfiltered_page, chunk_count)]
+ while iteration_count < 10: # Just to be safe
+ # Find the next chunk's worth of repository IDs, paginated by the chunk size.
+ unfiltered_page = unfiltered_page + 1
+ found_ids = [
+ r.id for r in unfiltered_query.paginate(unfiltered_page, chunk_count)
+ ]
- # Make sure we haven't encountered these results before. This code is used to handle
- # the case where we've previously seen a result, as pagination is not necessary
- # stable in SQL databases.
- unfiltered_repository_ids = set(found_ids)
- new_unfiltered_ids = unfiltered_repository_ids - encountered
- if not new_unfiltered_ids:
- break
+ # Make sure we haven't encountered these results before. This code is used to handle
+        # the case where we've previously seen a result, as pagination is not necessarily
+ # stable in SQL databases.
+ unfiltered_repository_ids = set(found_ids)
+ new_unfiltered_ids = unfiltered_repository_ids - encountered
+ if not new_unfiltered_ids:
+ break
- encountered.update(new_unfiltered_ids)
+ encountered.update(new_unfiltered_ids)
- # Filter the repositories found to only those visible to the current user.
- query = (Repository
- .select(Repository, Namespace)
- .distinct()
- .join(Namespace, on=(Namespace.id == Repository.namespace_user)).switch(Repository)
- .join(RepositoryPermission).where(Repository.id << list(new_unfiltered_ids)))
+ # Filter the repositories found to only those visible to the current user.
+ query = (
+ Repository.select(Repository, Namespace)
+ .distinct()
+ .join(Namespace, on=(Namespace.id == Repository.namespace_user))
+ .switch(Repository)
+ .join(RepositoryPermission)
+ .where(Repository.id << list(new_unfiltered_ids))
+ )
- filtered = _basequery.filter_to_repos_for_user(query, filter_user_id, repo_kind=repo_kind)
+ filtered = _basequery.filter_to_repos_for_user(
+ query, filter_user_id, repo_kind=repo_kind
+ )
- # Sort the filtered repositories by their initial order.
- all_filtered_repos = list(filtered)
- all_filtered_repos.sort(key=lambda repo: found_ids.index(repo.id))
+ # Sort the filtered repositories by their initial order.
+ all_filtered_repos = list(filtered)
+ all_filtered_repos.sort(key=lambda repo: found_ids.index(repo.id))
- # Yield the repositories in sorted order.
- for filtered_repo in all_filtered_repos:
- yield filtered_repo
+ # Yield the repositories in sorted order.
+ for filtered_repo in all_filtered_repos:
+ yield filtered_repo
- # If the number of found IDs is less than the chunk count, then we're done.
- if len(found_ids) < chunk_count:
- break
+ # If the number of found IDs is less than the chunk count, then we're done.
+ if len(found_ids) < chunk_count:
+ break
- iteration_count = iteration_count + 1
+ iteration_count = iteration_count + 1
-def _get_sorted_matching_repositories(lookup_value, repo_kind='image', include_private=False,
- search_fields=None, ids_only=False):
- """ Returns a query of repositories matching the given lookup string, with optional inclusion of
+def _get_sorted_matching_repositories(
+ lookup_value,
+ repo_kind="image",
+ include_private=False,
+ search_fields=None,
+ ids_only=False,
+):
+ """ Returns a query of repositories matching the given lookup string, with optional inclusion of
private repositories. Note that this method does *not* filter results based on visibility
to users.
"""
- select_fields = [Repository.id] if ids_only else [Repository, Namespace]
+ select_fields = [Repository.id] if ids_only else [Repository, Namespace]
- if not lookup_value:
- # This is a generic listing of repositories. Simply return the sorted repositories based
- # on RepositorySearchScore.
- query = (Repository
- .select(*select_fields)
- .join(RepositorySearchScore)
- .order_by(RepositorySearchScore.score.desc()))
- else:
- if search_fields is None:
- search_fields = set([SEARCH_FIELDS.description.name, SEARCH_FIELDS.name.name])
+ if not lookup_value:
+ # This is a generic listing of repositories. Simply return the sorted repositories based
+ # on RepositorySearchScore.
+ query = (
+ Repository.select(*select_fields)
+ .join(RepositorySearchScore)
+ .order_by(RepositorySearchScore.score.desc())
+ )
+ else:
+ if search_fields is None:
+ search_fields = set(
+ [SEARCH_FIELDS.description.name, SEARCH_FIELDS.name.name]
+ )
- # Always search at least on name (init clause)
- clause = Repository.name.match(lookup_value)
- computed_score = RepositorySearchScore.score.alias('score')
+ # Always search at least on name (init clause)
+ clause = Repository.name.match(lookup_value)
+ computed_score = RepositorySearchScore.score.alias("score")
- # If the description field is in the search fields, then we need to compute a synthetic score
- # to discount the weight of the description more than the name.
- if SEARCH_FIELDS.description.name in search_fields:
- clause = Repository.description.match(lookup_value) | clause
- cases = [(Repository.name.match(lookup_value), 100 * RepositorySearchScore.score),]
- computed_score = Case(None, cases, RepositorySearchScore.score).alias('score')
+ # If the description field is in the search fields, then we need to compute a synthetic score
+        # so that matches on the description are weighted less heavily than matches on the name.
+ if SEARCH_FIELDS.description.name in search_fields:
+ clause = Repository.description.match(lookup_value) | clause
+ cases = [
+ (Repository.name.match(lookup_value), 100 * RepositorySearchScore.score)
+ ]
+ computed_score = Case(None, cases, RepositorySearchScore.score).alias(
+ "score"
+ )
- select_fields.append(computed_score)
- query = (Repository.select(*select_fields)
- .join(RepositorySearchScore)
- .where(clause)
- .order_by(SQL('score').desc()))
+ select_fields.append(computed_score)
+ query = (
+ Repository.select(*select_fields)
+ .join(RepositorySearchScore)
+ .where(clause)
+ .order_by(SQL("score").desc())
+ )
- if repo_kind is not None:
- query = query.where(Repository.kind == Repository.kind.get_id(repo_kind))
+ if repo_kind is not None:
+ query = query.where(Repository.kind == Repository.kind.get_id(repo_kind))
- if not include_private:
- query = query.where(Repository.visibility == _basequery.get_public_repo_visibility())
+ if not include_private:
+ query = query.where(
+ Repository.visibility == _basequery.get_public_repo_visibility()
+ )
- if not ids_only:
- query = (query
- .switch(Repository)
- .join(Namespace, on=(Namespace.id == Repository.namespace_user)))
+ if not ids_only:
+ query = query.switch(Repository).join(
+ Namespace, on=(Namespace.id == Repository.namespace_user)
+ )
- return query
+ return query
def lookup_repository(repo_id):
- try:
- return Repository.get(Repository.id == repo_id)
- except Repository.DoesNotExist:
- return None
+ try:
+ return Repository.get(Repository.id == repo_id)
+ except Repository.DoesNotExist:
+ return None
def is_repository_public(repository):
- return repository.visibility_id == _basequery.get_public_repo_visibility().id
+ return repository.visibility_id == _basequery.get_public_repo_visibility().id
def repository_is_public(namespace_name, repository_name):
- try:
- (Repository.select().join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .switch(Repository).join(Visibility).where(Namespace.username == namespace_name,
- Repository.name == repository_name,
- Visibility.name == 'public').get())
- return True
- except Repository.DoesNotExist:
- return False
+ try:
+ (
+ Repository.select()
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .switch(Repository)
+ .join(Visibility)
+ .where(
+ Namespace.username == namespace_name,
+ Repository.name == repository_name,
+ Visibility.name == "public",
+ )
+ .get()
+ )
+ return True
+ except Repository.DoesNotExist:
+ return False
def set_repository_visibility(repo, visibility):
- visibility_obj = Visibility.get(name=visibility)
- if not visibility_obj:
- return
+ visibility_obj = Visibility.get(name=visibility)
+ if not visibility_obj:
+ return
- repo.visibility = visibility_obj
- repo.save()
+ repo.visibility = visibility_obj
+ repo.save()
def get_email_authorized_for_repo(namespace, repository, email):
- try:
- return (RepositoryAuthorizedEmail.select(RepositoryAuthorizedEmail, Repository, Namespace)
- .join(Repository).join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .where(Namespace.username == namespace, Repository.name == repository,
- RepositoryAuthorizedEmail.email == email).get())
- except RepositoryAuthorizedEmail.DoesNotExist:
- return None
+ try:
+ return (
+ RepositoryAuthorizedEmail.select(
+ RepositoryAuthorizedEmail, Repository, Namespace
+ )
+ .join(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .where(
+ Namespace.username == namespace,
+ Repository.name == repository,
+ RepositoryAuthorizedEmail.email == email,
+ )
+ .get()
+ )
+ except RepositoryAuthorizedEmail.DoesNotExist:
+ return None
def create_email_authorization_for_repo(namespace_name, repository_name, email):
- try:
- repo = _basequery.get_existing_repository(namespace_name, repository_name)
- except Repository.DoesNotExist:
- raise DataModelException('Invalid repository %s/%s' % (namespace_name, repository_name))
+ try:
+ repo = _basequery.get_existing_repository(namespace_name, repository_name)
+ except Repository.DoesNotExist:
+ raise DataModelException(
+ "Invalid repository %s/%s" % (namespace_name, repository_name)
+ )
- return RepositoryAuthorizedEmail.create(repository=repo, email=email, confirmed=False)
+ return RepositoryAuthorizedEmail.create(
+ repository=repo, email=email, confirmed=False
+ )
def confirm_email_authorization_for_repo(code):
- try:
- found = (RepositoryAuthorizedEmail.select(RepositoryAuthorizedEmail, Repository, Namespace)
- .join(Repository).join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .where(RepositoryAuthorizedEmail.code == code).get())
- except RepositoryAuthorizedEmail.DoesNotExist:
- raise DataModelException('Invalid confirmation code.')
+ try:
+ found = (
+ RepositoryAuthorizedEmail.select(
+ RepositoryAuthorizedEmail, Repository, Namespace
+ )
+ .join(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .where(RepositoryAuthorizedEmail.code == code)
+ .get()
+ )
+ except RepositoryAuthorizedEmail.DoesNotExist:
+ raise DataModelException("Invalid confirmation code.")
- found.confirmed = True
- found.save()
+ found.confirmed = True
+ found.save()
- return found
+ return found
def is_empty(namespace_name, repository_name):
- """ Returns if the repository referenced by the given namespace and name is empty. If the repo
+    """ Returns whether the repository referenced by the given namespace and name is empty. If the repo
doesn't exist, returns True.
"""
- try:
- tag.list_repository_tags(namespace_name, repository_name).limit(1).get()
- return False
- except RepositoryTag.DoesNotExist:
- return True
+ try:
+ tag.list_repository_tags(namespace_name, repository_name).limit(1).get()
+ return False
+ except RepositoryTag.DoesNotExist:
+ return True
def get_repository_state(namespace_name, repository_name):
- """ Return the Repository State if the Repository exists. Otherwise, returns None. """
- repo = get_repository(namespace_name, repository_name)
- if repo:
- return repo.state
+ """ Return the Repository State if the Repository exists. Otherwise, returns None. """
+ repo = get_repository(namespace_name, repository_name)
+ if repo:
+ return repo.state
- return None
+ return None
def set_repository_state(repo, state):
- repo.state = state
- repo.save()
+ repo.state = state
+ repo.save()
diff --git a/data/model/repositoryactioncount.py b/data/model/repositoryactioncount.py
index 759edc093..7e9364479 100644
--- a/data/model/repositoryactioncount.py
+++ b/data/model/repositoryactioncount.py
@@ -4,12 +4,20 @@ from collections import namedtuple
from peewee import IntegrityError
from datetime import date, timedelta, datetime
-from data.database import (Repository, LogEntry, LogEntry2, LogEntry3, RepositoryActionCount,
- RepositorySearchScore, db_random_func, fn)
+from data.database import (
+ Repository,
+ LogEntry,
+ LogEntry2,
+ LogEntry3,
+ RepositoryActionCount,
+ RepositorySearchScore,
+ db_random_func,
+ fn,
+)
logger = logging.getLogger(__name__)
-search_bucket = namedtuple('SearchBucket', ['delta', 'days', 'weight'])
+search_bucket = namedtuple("SearchBucket", ["delta", "days", "weight"])
# Defines the various buckets for search scoring. Each bucket is computed using the given time
# delta from today *minus the previous bucket's time period*. Once all the actions over the
@@ -17,113 +25,132 @@ search_bucket = namedtuple('SearchBucket', ['delta', 'days', 'weight'])
# for this bucket were determined via the integral of (2/((x/183)+1)^2)/183 over the period of days
# in the bucket; this integral over 0..183 has a sum of 1, so we get a good normalize score result.
SEARCH_BUCKETS = [
- search_bucket(timedelta(days=1), 1, 0.010870),
- search_bucket(timedelta(days=7), 6, 0.062815),
- search_bucket(timedelta(days=31), 24, 0.21604),
- search_bucket(timedelta(days=183), 152, 0.71028),
+ search_bucket(timedelta(days=1), 1, 0.010870),
+ search_bucket(timedelta(days=7), 6, 0.062815),
+ search_bucket(timedelta(days=31), 24, 0.21604),
+ search_bucket(timedelta(days=183), 152, 0.71028),
]
+
def find_uncounted_repository():
- """ Returns a repository that has not yet had an entry added into the RepositoryActionCount
+ """ Returns a repository that has not yet had an entry added into the RepositoryActionCount
table for yesterday.
"""
- try:
- # Get a random repository to count.
- today = date.today()
- yesterday = today - timedelta(days=1)
- has_yesterday_actions = (RepositoryActionCount
- .select(RepositoryActionCount.repository)
- .where(RepositoryActionCount.date == yesterday))
+ try:
+ # Get a random repository to count.
+ today = date.today()
+ yesterday = today - timedelta(days=1)
+ has_yesterday_actions = RepositoryActionCount.select(
+ RepositoryActionCount.repository
+ ).where(RepositoryActionCount.date == yesterday)
- to_count = (Repository
- .select()
- .where(~(Repository.id << (has_yesterday_actions)))
- .order_by(db_random_func()).get())
- return to_count
- except Repository.DoesNotExist:
- return None
+ to_count = (
+ Repository.select()
+ .where(~(Repository.id << (has_yesterday_actions)))
+ .order_by(db_random_func())
+ .get()
+ )
+ return to_count
+ except Repository.DoesNotExist:
+ return None
def count_repository_actions(to_count, day):
- """ Aggregates repository actions from the LogEntry table for the specified day. Returns the
+ """ Aggregates repository actions from the LogEntry table for the specified day. Returns the
count or None on error.
"""
- # TODO: Clean this up a bit.
- def lookup_action_count(model):
- return (model
- .select()
- .where(model.repository == to_count,
- model.datetime >= day,
- model.datetime < (day + timedelta(days=1)))
- .count())
+ # TODO: Clean this up a bit.
+ def lookup_action_count(model):
+ return (
+ model.select()
+ .where(
+ model.repository == to_count,
+ model.datetime >= day,
+ model.datetime < (day + timedelta(days=1)),
+ )
+ .count()
+ )
- actions = (lookup_action_count(LogEntry3) + lookup_action_count(LogEntry2) +
- lookup_action_count(LogEntry))
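+    # Log entries are spread across the LogEntry, LogEntry2, and LogEntry3 tables, so total the
+    # counts from each.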
+ actions = (
+ lookup_action_count(LogEntry3)
+ + lookup_action_count(LogEntry2)
+ + lookup_action_count(LogEntry)
+ )
- return actions
+ return actions
def store_repository_action_count(repository, day, action_count):
- """ Stores the action count for a repository for a specific day. Returns False if the
+ """ Stores the action count for a repository for a specific day. Returns False if the
repository already has an entry for the specified day.
"""
- try:
- RepositoryActionCount.create(repository=repository, date=day, count=action_count)
- return True
- except IntegrityError:
- logger.debug('Count already written for repository %s', repository.id)
- return False
+ try:
+ RepositoryActionCount.create(
+ repository=repository, date=day, count=action_count
+ )
+ return True
+ except IntegrityError:
+ logger.debug("Count already written for repository %s", repository.id)
+ return False
def update_repository_score(repo):
- """ Updates the repository score entry for the given table by retrieving information from
+    """ Updates the search score entry for the given repository by retrieving information from
the RepositoryActionCount table. Note that count_repository_actions for the repo should
be called first. Returns True if the row was updated and False otherwise.
"""
- today = date.today()
+ today = date.today()
- # Retrieve the counts for each bucket and calculate the final score.
- final_score = 0.0
- last_end_timedelta = timedelta(days=0)
+ # Retrieve the counts for each bucket and calculate the final score.
+ final_score = 0.0
+ last_end_timedelta = timedelta(days=0)
- for bucket in SEARCH_BUCKETS:
- start_date = today - bucket.delta
- end_date = today - last_end_timedelta
- last_end_timedelta = bucket.delta
+ for bucket in SEARCH_BUCKETS:
+ start_date = today - bucket.delta
+ end_date = today - last_end_timedelta
+ last_end_timedelta = bucket.delta
- query = (RepositoryActionCount
- .select(fn.Sum(RepositoryActionCount.count), fn.Count(RepositoryActionCount.id))
- .where(RepositoryActionCount.date >= start_date,
- RepositoryActionCount.date < end_date,
- RepositoryActionCount.repository == repo))
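+        # Sum the action counts recorded inside this bucket's [start_date, end_date) window.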
+ query = RepositoryActionCount.select(
+ fn.Sum(RepositoryActionCount.count), fn.Count(RepositoryActionCount.id)
+ ).where(
+ RepositoryActionCount.date >= start_date,
+ RepositoryActionCount.date < end_date,
+ RepositoryActionCount.repository == repo,
+ )
- bucket_tuple = query.tuples()[0]
- logger.debug('Got bucket tuple %s for bucket %s for repository %s', bucket_tuple, bucket,
- repo.id)
+ bucket_tuple = query.tuples()[0]
+ logger.debug(
+ "Got bucket tuple %s for bucket %s for repository %s",
+ bucket_tuple,
+ bucket,
+ repo.id,
+ )
- if bucket_tuple[0] is None:
- continue
+ if bucket_tuple[0] is None:
+ continue
- bucket_sum = float(bucket_tuple[0])
- bucket_count = int(bucket_tuple[1])
- if not bucket_count:
- continue
+ bucket_sum = float(bucket_tuple[0])
+ bucket_count = int(bucket_tuple[1])
+ if not bucket_count:
+ continue
- bucket_score = bucket_sum / (bucket_count * 1.0)
- final_score += bucket_score * bucket.weight
+ bucket_score = bucket_sum / (bucket_count * 1.0)
+ final_score += bucket_score * bucket.weight
- # Update the existing repo search score row or create a new one.
- normalized_score = int(final_score * 100.0)
- try:
+ # Update the existing repo search score row or create a new one.
+ normalized_score = int(final_score * 100.0)
try:
- search_score_row = RepositorySearchScore.get(repository=repo)
- search_score_row.last_updated = datetime.now()
- search_score_row.score = normalized_score
- search_score_row.save()
- return True
- except RepositorySearchScore.DoesNotExist:
- RepositorySearchScore.create(repository=repo, score=normalized_score, last_updated=today)
- return True
- except IntegrityError:
- logger.debug('RepositorySearchScore row already existed; skipping')
- return False
+ try:
+ search_score_row = RepositorySearchScore.get(repository=repo)
+ search_score_row.last_updated = datetime.now()
+ search_score_row.score = normalized_score
+ search_score_row.save()
+ return True
+ except RepositorySearchScore.DoesNotExist:
+ RepositorySearchScore.create(
+ repository=repo, score=normalized_score, last_updated=today
+ )
+ return True
+ except IntegrityError:
+ logger.debug("RepositorySearchScore row already existed; skipping")
+ return False
diff --git a/data/model/service_keys.py b/data/model/service_keys.py
index eb460299b..67ca282ae 100644
--- a/data/model/service_keys.py
+++ b/data/model/service_keys.py
@@ -8,198 +8,258 @@ from Crypto.PublicKey import RSA
from jwkest.jwk import RSAKey
from data.database import db_for_update, User, ServiceKey, ServiceKeyApproval
-from data.model import (ServiceKeyDoesNotExist, ServiceKeyAlreadyApproved, ServiceNameInvalid,
- db_transaction, config)
-from data.model.notification import create_notification, delete_all_notifications_by_path_prefix
+from data.model import (
+ ServiceKeyDoesNotExist,
+ ServiceKeyAlreadyApproved,
+ ServiceNameInvalid,
+ db_transaction,
+ config,
+)
+from data.model.notification import (
+ create_notification,
+ delete_all_notifications_by_path_prefix,
+)
from util.security.fingerprint import canonical_kid
-_SERVICE_NAME_REGEX = re.compile(r'^[a-z0-9_]+$')
+_SERVICE_NAME_REGEX = re.compile(r"^[a-z0-9_]+$")
+
def _expired_keys_clause(service):
- return ((ServiceKey.service == service) &
- (ServiceKey.expiration_date <= datetime.utcnow()))
+ return (ServiceKey.service == service) & (
+ ServiceKey.expiration_date <= datetime.utcnow()
+ )
def _stale_expired_keys_service_clause(service):
- return ((ServiceKey.service == service) & _stale_expired_keys_clause())
+ return (ServiceKey.service == service) & _stale_expired_keys_clause()
def _stale_expired_keys_clause():
- expired_ttl = timedelta(seconds=config.app_config['EXPIRED_SERVICE_KEY_TTL_SEC'])
- return (ServiceKey.expiration_date <= (datetime.utcnow() - expired_ttl))
+ expired_ttl = timedelta(seconds=config.app_config["EXPIRED_SERVICE_KEY_TTL_SEC"])
+ return ServiceKey.expiration_date <= (datetime.utcnow() - expired_ttl)
def _stale_unapproved_keys_clause(service):
- unapproved_ttl = timedelta(seconds=config.app_config['UNAPPROVED_SERVICE_KEY_TTL_SEC'])
- return ((ServiceKey.service == service) &
- (ServiceKey.approval >> None) &
- (ServiceKey.created_date <= (datetime.utcnow() - unapproved_ttl)))
+ unapproved_ttl = timedelta(
+ seconds=config.app_config["UNAPPROVED_SERVICE_KEY_TTL_SEC"]
+ )
+ return (
+ (ServiceKey.service == service)
+ & (ServiceKey.approval >> None)
+ & (ServiceKey.created_date <= (datetime.utcnow() - unapproved_ttl))
+ )
def _gc_expired(service):
- ServiceKey.delete().where(_stale_expired_keys_service_clause(service) |
- _stale_unapproved_keys_clause(service)).execute()
+ ServiceKey.delete().where(
+ _stale_expired_keys_service_clause(service)
+ | _stale_unapproved_keys_clause(service)
+ ).execute()
def _verify_service_name(service_name):
- if not _SERVICE_NAME_REGEX.match(service_name):
- raise ServiceNameInvalid
+ if not _SERVICE_NAME_REGEX.match(service_name):
+ raise ServiceNameInvalid
def _notify_superusers(key):
- notification_metadata = {
- 'name': key.name,
- 'kid': key.kid,
- 'service': key.service,
- 'jwk': key.jwk,
- 'metadata': key.metadata,
- 'created_date': timegm(key.created_date.utctimetuple()),
- }
+ notification_metadata = {
+ "name": key.name,
+ "kid": key.kid,
+ "service": key.service,
+ "jwk": key.jwk,
+ "metadata": key.metadata,
+ "created_date": timegm(key.created_date.utctimetuple()),
+ }
- if key.expiration_date is not None:
- notification_metadata['expiration_date'] = timegm(key.expiration_date.utctimetuple())
+ if key.expiration_date is not None:
+ notification_metadata["expiration_date"] = timegm(
+ key.expiration_date.utctimetuple()
+ )
- if len(config.app_config['SUPER_USERS']) > 0:
- superusers = User.select().where(User.username << config.app_config['SUPER_USERS'])
- for superuser in superusers:
- create_notification('service_key_submitted', superuser, metadata=notification_metadata,
- lookup_path='/service_key_approval/{0}/{1}'.format(key.kid, superuser.id))
+ if len(config.app_config["SUPER_USERS"]) > 0:
+ superusers = User.select().where(
+ User.username << config.app_config["SUPER_USERS"]
+ )
+ for superuser in superusers:
+ create_notification(
+ "service_key_submitted",
+ superuser,
+ metadata=notification_metadata,
+ lookup_path="/service_key_approval/{0}/{1}".format(
+ key.kid, superuser.id
+ ),
+ )
-def create_service_key(name, kid, service, jwk, metadata, expiration_date, rotation_duration=None):
- _verify_service_name(service)
- _gc_expired(service)
+def create_service_key(
+ name, kid, service, jwk, metadata, expiration_date, rotation_duration=None
+):
+ _verify_service_name(service)
+ _gc_expired(service)
- key = ServiceKey.create(name=name, kid=kid, service=service, jwk=jwk, metadata=metadata,
- expiration_date=expiration_date, rotation_duration=rotation_duration)
+ key = ServiceKey.create(
+ name=name,
+ kid=kid,
+ service=service,
+ jwk=jwk,
+ metadata=metadata,
+ expiration_date=expiration_date,
+ rotation_duration=rotation_duration,
+ )
- _notify_superusers(key)
- return key
+ _notify_superusers(key)
+ return key
-def generate_service_key(service, expiration_date, kid=None, name='', metadata=None,
- rotation_duration=None):
- private_key = RSA.generate(2048)
- jwk = RSAKey(key=private_key.publickey()).serialize()
- if kid is None:
- kid = canonical_kid(jwk)
+def generate_service_key(
+ service, expiration_date, kid=None, name="", metadata=None, rotation_duration=None
+):
+ private_key = RSA.generate(2048)
+ jwk = RSAKey(key=private_key.publickey()).serialize()
+ if kid is None:
+ kid = canonical_kid(jwk)
- key = create_service_key(name, kid, service, jwk, metadata or {}, expiration_date,
- rotation_duration=rotation_duration)
- return (private_key, key)
+ key = create_service_key(
+ name,
+ kid,
+ service,
+ jwk,
+ metadata or {},
+ expiration_date,
+ rotation_duration=rotation_duration,
+ )
+ return (private_key, key)
def replace_service_key(old_kid, kid, jwk, metadata, expiration_date):
- try:
- with db_transaction():
- key = db_for_update(ServiceKey.select().where(ServiceKey.kid == old_kid)).get()
- key.metadata.update(metadata)
+ try:
+ with db_transaction():
+ key = db_for_update(
+ ServiceKey.select().where(ServiceKey.kid == old_kid)
+ ).get()
+ key.metadata.update(metadata)
- ServiceKey.create(name=key.name, kid=kid, service=key.service, jwk=jwk,
- metadata=key.metadata, expiration_date=expiration_date,
- rotation_duration=key.rotation_duration, approval=key.approval)
- key.delete_instance()
- except ServiceKey.DoesNotExist:
- raise ServiceKeyDoesNotExist
+ ServiceKey.create(
+ name=key.name,
+ kid=kid,
+ service=key.service,
+ jwk=jwk,
+ metadata=key.metadata,
+ expiration_date=expiration_date,
+ rotation_duration=key.rotation_duration,
+ approval=key.approval,
+ )
+ key.delete_instance()
+ except ServiceKey.DoesNotExist:
+ raise ServiceKeyDoesNotExist
- _notify_superusers(key)
- delete_all_notifications_by_path_prefix('/service_key_approval/{0}'.format(old_kid))
- _gc_expired(key.service)
+ _notify_superusers(key)
+ delete_all_notifications_by_path_prefix("/service_key_approval/{0}".format(old_kid))
+ _gc_expired(key.service)
def update_service_key(kid, name=None, metadata=None):
- try:
- with db_transaction():
- key = db_for_update(ServiceKey.select().where(ServiceKey.kid == kid)).get()
- if name is not None:
- key.name = name
+ try:
+ with db_transaction():
+ key = db_for_update(ServiceKey.select().where(ServiceKey.kid == kid)).get()
+ if name is not None:
+ key.name = name
- if metadata is not None:
- key.metadata.update(metadata)
+ if metadata is not None:
+ key.metadata.update(metadata)
- key.save()
- except ServiceKey.DoesNotExist:
- raise ServiceKeyDoesNotExist
+ key.save()
+ except ServiceKey.DoesNotExist:
+ raise ServiceKeyDoesNotExist
def delete_service_key(kid):
- try:
- key = ServiceKey.get(kid=kid)
- ServiceKey.delete().where(ServiceKey.kid == kid).execute()
- except ServiceKey.DoesNotExist:
- raise ServiceKeyDoesNotExist
+ try:
+ key = ServiceKey.get(kid=kid)
+ ServiceKey.delete().where(ServiceKey.kid == kid).execute()
+ except ServiceKey.DoesNotExist:
+ raise ServiceKeyDoesNotExist
- delete_all_notifications_by_path_prefix('/service_key_approval/{0}'.format(kid))
- _gc_expired(key.service)
- return key
+ delete_all_notifications_by_path_prefix("/service_key_approval/{0}".format(kid))
+ _gc_expired(key.service)
+ return key
def set_key_expiration(kid, expiration_date):
- try:
- service_key = get_service_key(kid, alive_only=False, approved_only=False)
- except ServiceKey.DoesNotExist:
- raise ServiceKeyDoesNotExist
+ try:
+ service_key = get_service_key(kid, alive_only=False, approved_only=False)
+ except ServiceKey.DoesNotExist:
+ raise ServiceKeyDoesNotExist
- service_key.expiration_date = expiration_date
- service_key.save()
+ service_key.expiration_date = expiration_date
+ service_key.save()
-def approve_service_key(kid, approval_type, approver=None, notes=''):
- try:
- with db_transaction():
- key = db_for_update(ServiceKey.select().where(ServiceKey.kid == kid)).get()
- if key.approval is not None:
- raise ServiceKeyAlreadyApproved
+def approve_service_key(kid, approval_type, approver=None, notes=""):
+ try:
+ with db_transaction():
+ key = db_for_update(ServiceKey.select().where(ServiceKey.kid == kid)).get()
+ if key.approval is not None:
+ raise ServiceKeyAlreadyApproved
- approval = ServiceKeyApproval.create(approver=approver, approval_type=approval_type,
- notes=notes)
- key.approval = approval
- key.save()
- except ServiceKey.DoesNotExist:
- raise ServiceKeyDoesNotExist
+ approval = ServiceKeyApproval.create(
+ approver=approver, approval_type=approval_type, notes=notes
+ )
+ key.approval = approval
+ key.save()
+ except ServiceKey.DoesNotExist:
+ raise ServiceKeyDoesNotExist
- delete_all_notifications_by_path_prefix('/service_key_approval/{0}'.format(kid))
- return key
+ delete_all_notifications_by_path_prefix("/service_key_approval/{0}".format(kid))
+ return key
-def _list_service_keys_query(kid=None, service=None, approved_only=True, alive_only=True,
- approval_type=None):
- query = ServiceKey.select().join(ServiceKeyApproval, JOIN.LEFT_OUTER)
+def _list_service_keys_query(
+ kid=None, service=None, approved_only=True, alive_only=True, approval_type=None
+):
+ query = ServiceKey.select().join(ServiceKeyApproval, JOIN.LEFT_OUTER)
- if approved_only:
- query = query.where(~(ServiceKey.approval >> None))
+ if approved_only:
+ query = query.where(~(ServiceKey.approval >> None))
- if alive_only:
- query = query.where((ServiceKey.expiration_date > datetime.utcnow()) |
- (ServiceKey.expiration_date >> None))
+ if alive_only:
+ query = query.where(
+ (ServiceKey.expiration_date > datetime.utcnow())
+ | (ServiceKey.expiration_date >> None)
+ )
- if approval_type is not None:
- query = query.where(ServiceKeyApproval.approval_type == approval_type)
+ if approval_type is not None:
+ query = query.where(ServiceKeyApproval.approval_type == approval_type)
- if service is not None:
- query = query.where(ServiceKey.service == service)
- query = query.where(~(_expired_keys_clause(service)) |
- ~(_stale_unapproved_keys_clause(service)))
+ if service is not None:
+ query = query.where(ServiceKey.service == service)
+ query = query.where(
+ ~(_expired_keys_clause(service)) | ~(_stale_unapproved_keys_clause(service))
+ )
- if kid is not None:
- query = query.where(ServiceKey.kid == kid)
+ if kid is not None:
+ query = query.where(ServiceKey.kid == kid)
- query = query.where(~(_stale_expired_keys_clause()) | (ServiceKey.expiration_date >> None))
- return query
+ query = query.where(
+ ~(_stale_expired_keys_clause()) | (ServiceKey.expiration_date >> None)
+ )
+ return query
def list_all_keys():
- return list(_list_service_keys_query(approved_only=False, alive_only=False))
+ return list(_list_service_keys_query(approved_only=False, alive_only=False))
def list_service_keys(service):
- return list(_list_service_keys_query(service=service))
+ return list(_list_service_keys_query(service=service))
def get_service_key(kid, service=None, alive_only=True, approved_only=True):
- try:
- return _list_service_keys_query(kid=kid, service=service, approved_only=approved_only,
- alive_only=alive_only).get()
- except ServiceKey.DoesNotExist:
- raise ServiceKeyDoesNotExist
+ try:
+ return _list_service_keys_query(
+ kid=kid, service=service, approved_only=approved_only, alive_only=alive_only
+ ).get()
+ except ServiceKey.DoesNotExist:
+ raise ServiceKeyDoesNotExist
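
A minimal sketch of the key-material step inside generate_service_key, using the same libraries this module already imports (Crypto.PublicKey.RSA and jwkest.jwk.RSAKey); the real code additionally derives the kid via util.security.fingerprint.canonical_kid and persists the key through create_service_key, neither of which is reproduced here:

from Crypto.PublicKey import RSA
from jwkest.jwk import RSAKey

# Generate a 2048-bit RSA keypair and serialize the public half as a JWK,
# mirroring what generate_service_key does before calling create_service_key.
private_key = RSA.generate(2048)
jwk = RSAKey(key=private_key.publickey()).serialize()
print(sorted(jwk.keys()))  # public JWK members such as 'e', 'kty' and 'n'
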
diff --git a/data/model/sqlalchemybridge.py b/data/model/sqlalchemybridge.py
index e469eff00..780e1643c 100644
--- a/data/model/sqlalchemybridge.py
+++ b/data/model/sqlalchemybridge.py
@@ -1,94 +1,125 @@
-from sqlalchemy import (Table, MetaData, Column, ForeignKey, Integer, String, Boolean, Text,
- DateTime, Date, BigInteger, Index, text)
-from peewee import (PrimaryKeyField, CharField, BooleanField, DateTimeField, TextField,
- ForeignKeyField, BigIntegerField, IntegerField, DateField)
+from sqlalchemy import (
+ Table,
+ MetaData,
+ Column,
+ ForeignKey,
+ Integer,
+ String,
+ Boolean,
+ Text,
+ DateTime,
+ Date,
+ BigInteger,
+ Index,
+ text,
+)
+from peewee import (
+ PrimaryKeyField,
+ CharField,
+ BooleanField,
+ DateTimeField,
+ TextField,
+ ForeignKeyField,
+ BigIntegerField,
+ IntegerField,
+ DateField,
+)
-OPTIONS_TO_COPY = [
- 'null',
- 'default',
- 'primary_key',
-]
+OPTIONS_TO_COPY = ["null", "default", "primary_key"]
-OPTION_TRANSLATIONS = {
- 'null': 'nullable',
-}
+OPTION_TRANSLATIONS = {"null": "nullable"}
+
def gen_sqlalchemy_metadata(peewee_model_list):
- metadata = MetaData(naming_convention={
- "ix": 'ix_%(column_0_label)s',
- "uq": "uq_%(table_name)s_%(column_0_name)s",
- "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
- "pk": "pk_%(table_name)s"
- })
+ metadata = MetaData(
+ naming_convention={
+ "ix": "ix_%(column_0_label)s",
+ "uq": "uq_%(table_name)s_%(column_0_name)s",
+ "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
+ "pk": "pk_%(table_name)s",
+ }
+ )
- for model in peewee_model_list:
- meta = model._meta
+ for model in peewee_model_list:
+ meta = model._meta
- all_indexes = set(meta.indexes)
- fulltext_indexes = []
+ all_indexes = set(meta.indexes)
+ fulltext_indexes = []
- columns = []
- for field in meta.sorted_fields:
- alchemy_type = None
- col_args = []
- col_kwargs = {}
- if isinstance(field, PrimaryKeyField):
- alchemy_type = Integer
- elif isinstance(field, CharField):
- alchemy_type = String(field.max_length)
- elif isinstance(field, BooleanField):
- alchemy_type = Boolean
- elif isinstance(field, DateTimeField):
- alchemy_type = DateTime
- elif isinstance(field, DateField):
- alchemy_type = Date
- elif isinstance(field, TextField):
- alchemy_type = Text
- elif isinstance(field, ForeignKeyField):
- alchemy_type = Integer
- all_indexes.add(((field.name, ), field.unique))
- if not field.deferred:
- target_name = '%s.%s' % (field.rel_model._meta.table_name, field.rel_field.column_name)
- col_args.append(ForeignKey(target_name))
- elif isinstance(field, BigIntegerField):
- alchemy_type = BigInteger
- elif isinstance(field, IntegerField):
- alchemy_type = Integer
- else:
- raise RuntimeError('Unknown column type: %s' % field)
+ columns = []
+ for field in meta.sorted_fields:
+ alchemy_type = None
+ col_args = []
+ col_kwargs = {}
+ if isinstance(field, PrimaryKeyField):
+ alchemy_type = Integer
+ elif isinstance(field, CharField):
+ alchemy_type = String(field.max_length)
+ elif isinstance(field, BooleanField):
+ alchemy_type = Boolean
+ elif isinstance(field, DateTimeField):
+ alchemy_type = DateTime
+ elif isinstance(field, DateField):
+ alchemy_type = Date
+ elif isinstance(field, TextField):
+ alchemy_type = Text
+ elif isinstance(field, ForeignKeyField):
+ alchemy_type = Integer
+ all_indexes.add(((field.name,), field.unique))
+ if not field.deferred:
+ target_name = "%s.%s" % (
+ field.rel_model._meta.table_name,
+ field.rel_field.column_name,
+ )
+ col_args.append(ForeignKey(target_name))
+ elif isinstance(field, BigIntegerField):
+ alchemy_type = BigInteger
+ elif isinstance(field, IntegerField):
+ alchemy_type = Integer
+ else:
+ raise RuntimeError("Unknown column type: %s" % field)
- if hasattr(field, '__fulltext__'):
- # Add the fulltext index for the field, based on whether we are under MySQL or Postgres.
- fulltext_indexes.append(field.name)
+ if hasattr(field, "__fulltext__"):
+ # Add the fulltext index for the field, based on whether we are under MySQL or Postgres.
+ fulltext_indexes.append(field.name)
- for option_name in OPTIONS_TO_COPY:
- alchemy_option_name = (OPTION_TRANSLATIONS[option_name]
- if option_name in OPTION_TRANSLATIONS else option_name)
- if alchemy_option_name not in col_kwargs:
- option_val = getattr(field, option_name)
- col_kwargs[alchemy_option_name] = option_val
+ for option_name in OPTIONS_TO_COPY:
+ alchemy_option_name = (
+ OPTION_TRANSLATIONS[option_name]
+ if option_name in OPTION_TRANSLATIONS
+ else option_name
+ )
+ if alchemy_option_name not in col_kwargs:
+ option_val = getattr(field, option_name)
+ col_kwargs[alchemy_option_name] = option_val
- if field.unique or field.index:
- all_indexes.add(((field.name, ), field.unique))
+ if field.unique or field.index:
+ all_indexes.add(((field.name,), field.unique))
- new_col = Column(field.column_name, alchemy_type, *col_args, **col_kwargs)
- columns.append(new_col)
+ new_col = Column(field.column_name, alchemy_type, *col_args, **col_kwargs)
+ columns.append(new_col)
- new_table = Table(meta.table_name, metadata, *columns)
+ new_table = Table(meta.table_name, metadata, *columns)
- for col_prop_names, unique in all_indexes:
- col_names = [meta.fields[prop_name].column_name for prop_name in col_prop_names]
- index_name = '%s_%s' % (meta.table_name, '_'.join(col_names))
- col_refs = [getattr(new_table.c, col_name) for col_name in col_names]
- Index(index_name, *col_refs, unique=unique)
+ for col_prop_names, unique in all_indexes:
+ col_names = [
+ meta.fields[prop_name].column_name for prop_name in col_prop_names
+ ]
+ index_name = "%s_%s" % (meta.table_name, "_".join(col_names))
+ col_refs = [getattr(new_table.c, col_name) for col_name in col_names]
+ Index(index_name, *col_refs, unique=unique)
- for col_field_name in fulltext_indexes:
- index_name = '%s_%s__fulltext' % (meta.table_name, col_field_name)
- col_ref = getattr(new_table.c, col_field_name)
- Index(index_name, col_ref, postgresql_ops={col_field_name: 'gin_trgm_ops'},
- postgresql_using='gin',
- mysql_prefix='FULLTEXT')
+ for col_field_name in fulltext_indexes:
+ index_name = "%s_%s__fulltext" % (meta.table_name, col_field_name)
+ col_ref = getattr(new_table.c, col_field_name)
+ Index(
+ index_name,
+ col_ref,
+ postgresql_ops={col_field_name: "gin_trgm_ops"},
+ postgresql_using="gin",
+ mysql_prefix="FULLTEXT",
+ )
- return metadata
+ return metadata
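
To illustrate what gen_sqlalchemy_metadata consumes, here is a small sketch with a hypothetical peewee model: the function walks each model's _meta.sorted_fields and indexes and emits equivalent SQLAlchemy Table objects. The commented-out call assumes this repository's import path for the module above:

from peewee import Model, CharField, BooleanField

class Sample(Model):
    name = CharField(max_length=255, index=True)
    public = BooleanField(default=False)

# The per-field metadata the bridge iterates over for each model:
for field in Sample._meta.sorted_fields:
    print(field.column_name, type(field).__name__, field.null, field.index, field.unique)

# from data.model.sqlalchemybridge import gen_sqlalchemy_metadata
# metadata = gen_sqlalchemy_metadata([Sample])
# print(list(metadata.tables))  # expected to contain 'sample'
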
diff --git a/data/model/storage.py b/data/model/storage.py
index adfa54cd9..646f3f9fc 100644
--- a/data/model/storage.py
+++ b/data/model/storage.py
@@ -4,370 +4,450 @@ from peewee import SQL, IntegrityError
from cachetools.func import lru_cache
from collections import namedtuple
-from data.model import (config, db_transaction, InvalidImageException, TorrentInfoDoesNotExist,
- DataModelException, _basequery)
-from data.database import (ImageStorage, Image, ImageStoragePlacement, ImageStorageLocation,
- ImageStorageTransformation, ImageStorageSignature,
- ImageStorageSignatureKind, Repository, Namespace, TorrentInfo, ApprBlob,
- ensure_under_transaction, ManifestBlob)
+from data.model import (
+ config,
+ db_transaction,
+ InvalidImageException,
+ TorrentInfoDoesNotExist,
+ DataModelException,
+ _basequery,
+)
+from data.database import (
+ ImageStorage,
+ Image,
+ ImageStoragePlacement,
+ ImageStorageLocation,
+ ImageStorageTransformation,
+ ImageStorageSignature,
+ ImageStorageSignatureKind,
+ Repository,
+ Namespace,
+ TorrentInfo,
+ ApprBlob,
+ ensure_under_transaction,
+ ManifestBlob,
+)
logger = logging.getLogger(__name__)
-_Location = namedtuple('location', ['id', 'name'])
+_Location = namedtuple("location", ["id", "name"])
+
@lru_cache(maxsize=1)
def get_image_locations():
- location_map = {}
- for location in ImageStorageLocation.select():
- location_tuple = _Location(location.id, location.name)
- location_map[location.id] = location_tuple
- location_map[location.name] = location_tuple
+ location_map = {}
+ for location in ImageStorageLocation.select():
+ location_tuple = _Location(location.id, location.name)
+ location_map[location.id] = location_tuple
+ location_map[location.name] = location_tuple
- return location_map
+ return location_map
def get_image_location_for_name(location_name):
- locations = get_image_locations()
- return locations[location_name]
+ locations = get_image_locations()
+ return locations[location_name]
def get_image_location_for_id(location_id):
- locations = get_image_locations()
- return locations[location_id]
+ locations = get_image_locations()
+ return locations[location_id]
def add_storage_placement(storage, location_name):
- """ Adds a storage placement for the given storage at the given location. """
- location = get_image_location_for_name(location_name)
- try:
- ImageStoragePlacement.create(location=location.id, storage=storage)
- except IntegrityError:
- # Placement already exists. Nothing to do.
- pass
+ """ Adds a storage placement for the given storage at the given location. """
+ location = get_image_location_for_name(location_name)
+ try:
+ ImageStoragePlacement.create(location=location.id, storage=storage)
+ except IntegrityError:
+ # Placement already exists. Nothing to do.
+ pass
def _orphaned_storage_query(candidate_ids):
- """ Returns the subset of the candidate ImageStorage IDs representing storages that are no
+ """ Returns the subset of the candidate ImageStorage IDs representing storages that are no
longer referenced by images.
"""
- # Issue a union query to find all storages that are still referenced by a candidate storage. This
- # is much faster than the group_by and having call we used to use here.
- nonorphaned_queries = []
- for counter, candidate_id in enumerate(candidate_ids):
- query_alias = 'q{0}'.format(counter)
+ # Issue a union query to find which of the candidate storages are still referenced. This
+ # is much faster than the group_by and having call we used to use here.
+ nonorphaned_queries = []
+ for counter, candidate_id in enumerate(candidate_ids):
+ query_alias = "q{0}".format(counter)
- # TODO: remove the join with Image once fully on the OCI data model.
- storage_subq = (ImageStorage
- .select(ImageStorage.id)
- .join(Image)
- .where(ImageStorage.id == candidate_id)
- .limit(1)
- .alias(query_alias))
+ # TODO: remove the join with Image once fully on the OCI data model.
+ storage_subq = (
+ ImageStorage.select(ImageStorage.id)
+ .join(Image)
+ .where(ImageStorage.id == candidate_id)
+ .limit(1)
+ .alias(query_alias)
+ )
- nonorphaned_queries.append(ImageStorage
- .select(SQL('*'))
- .from_(storage_subq))
+ nonorphaned_queries.append(ImageStorage.select(SQL("*")).from_(storage_subq))
- manifest_storage_subq = (ImageStorage
- .select(ImageStorage.id)
- .join(ManifestBlob)
- .where(ImageStorage.id == candidate_id)
- .limit(1)
- .alias(query_alias))
+ manifest_storage_subq = (
+ ImageStorage.select(ImageStorage.id)
+ .join(ManifestBlob)
+ .where(ImageStorage.id == candidate_id)
+ .limit(1)
+ .alias(query_alias)
+ )
- nonorphaned_queries.append(ImageStorage
- .select(SQL('*'))
- .from_(manifest_storage_subq))
+ nonorphaned_queries.append(
+ ImageStorage.select(SQL("*")).from_(manifest_storage_subq)
+ )
- # Build the set of storages that are missing. These storages are orphaned.
- nonorphaned_storage_ids = {storage.id for storage
- in _basequery.reduce_as_tree(nonorphaned_queries)}
- return list(candidate_ids - nonorphaned_storage_ids)
+ # Build the set of storages that are missing. These storages are orphaned.
+ nonorphaned_storage_ids = {
+ storage.id for storage in _basequery.reduce_as_tree(nonorphaned_queries)
+ }
+ return list(candidate_ids - nonorphaned_storage_ids)
def garbage_collect_storage(storage_id_whitelist):
- """ Performs GC on a possible subset of the storage's with the IDs found in the
+ """ Performs GC on a possible subset of the storage's with the IDs found in the
whitelist. The storages in the whitelist will be checked, and any orphaned will
be removed, with those IDs being returned.
"""
- if len(storage_id_whitelist) == 0:
- return []
+ if len(storage_id_whitelist) == 0:
+ return []
- def placements_to_filtered_paths_set(placements_list):
- """ Returns the list of paths to remove from storage, filtered from the given placements
+ def placements_to_filtered_paths_set(placements_list):
+ """ Returns the list of paths to remove from storage, filtered from the given placements
query by removing any CAS paths that are still referenced by storage(s) in the database.
"""
- with ensure_under_transaction():
- if not placements_list:
- return set()
+ with ensure_under_transaction():
+ if not placements_list:
+ return set()
- # Find the content checksums not referenced by other storages. Any that are, we cannot
- # remove.
- content_checksums = set([placement.storage.content_checksum for placement in placements_list
- if placement.storage.cas_path])
+ # Find the content checksums not referenced by other storages. Any that are, we cannot
+ # remove.
+ content_checksums = set(
+ [
+ placement.storage.content_checksum
+ for placement in placements_list
+ if placement.storage.cas_path
+ ]
+ )
- unreferenced_checksums = set()
- if content_checksums:
- # Check the current image storage.
- query = (ImageStorage
- .select(ImageStorage.content_checksum)
- .where(ImageStorage.content_checksum << list(content_checksums)))
- is_referenced_checksums = set([image_storage.content_checksum for image_storage in query])
- if is_referenced_checksums:
- logger.warning('GC attempted to remove CAS checksums %s, which are still IS referenced',
- is_referenced_checksums)
+ unreferenced_checksums = set()
+ if content_checksums:
+ # Check the current image storage.
+ query = ImageStorage.select(ImageStorage.content_checksum).where(
+ ImageStorage.content_checksum << list(content_checksums)
+ )
+ is_referenced_checksums = set(
+ [image_storage.content_checksum for image_storage in query]
+ )
+ if is_referenced_checksums:
+ logger.warning(
+ "GC attempted to remove CAS checksums %s, which are still IS referenced",
+ is_referenced_checksums,
+ )
- # Check the ApprBlob table as well.
- query = ApprBlob.select(ApprBlob.digest).where(ApprBlob.digest << list(content_checksums))
- appr_blob_referenced_checksums = set([blob.digest for blob in query])
- if appr_blob_referenced_checksums:
- logger.warning('GC attempted to remove CAS checksums %s, which are ApprBlob referenced',
- appr_blob_referenced_checksums)
+ # Check the ApprBlob table as well.
+ query = ApprBlob.select(ApprBlob.digest).where(
+ ApprBlob.digest << list(content_checksums)
+ )
+ appr_blob_referenced_checksums = set([blob.digest for blob in query])
+ if appr_blob_referenced_checksums:
+ logger.warning(
+ "GC attempted to remove CAS checksums %s, which are ApprBlob referenced",
+ appr_blob_referenced_checksums,
+ )
- unreferenced_checksums = (content_checksums - appr_blob_referenced_checksums -
- is_referenced_checksums)
+ unreferenced_checksums = (
+ content_checksums
+ - appr_blob_referenced_checksums
+ - is_referenced_checksums
+ )
- # Return all placements for all image storages found not at a CAS path or with a content
- # checksum that is referenced.
- return {(get_image_location_for_id(placement.location_id).name,
- get_layer_path(placement.storage))
- for placement in placements_list
- if not placement.storage.cas_path or
- placement.storage.content_checksum in unreferenced_checksums}
+ # Return the placements for all image storages that are not at a CAS path or whose content
+ # checksum is no longer referenced elsewhere.
+ return {
+ (
+ get_image_location_for_id(placement.location_id).name,
+ get_layer_path(placement.storage),
+ )
+ for placement in placements_list
+ if not placement.storage.cas_path
+ or placement.storage.content_checksum in unreferenced_checksums
+ }
- # Note: Both of these deletes must occur in the same transaction (unfortunately) because a
- # storage without any placement is invalid, and a placement cannot exist without a storage.
- # TODO: We might want to allow for null storages on placements, which would allow us to
- # delete the storages, then delete the placements in a non-transaction.
- logger.debug('Garbage collecting storages from candidates: %s', storage_id_whitelist)
- with db_transaction():
- orphaned_storage_ids = _orphaned_storage_query(storage_id_whitelist)
- if len(orphaned_storage_ids) == 0:
- # Nothing to GC.
- return []
+ # Note: Both of these deletes must occur in the same transaction (unfortunately) because a
+ # storage without any placement is invalid, and a placement cannot exist without a storage.
+ # TODO: We might want to allow for null storages on placements, which would allow us to
+ # delete the storages, then delete the placements in a non-transaction.
+ logger.debug(
+ "Garbage collecting storages from candidates: %s", storage_id_whitelist
+ )
+ with db_transaction():
+ orphaned_storage_ids = _orphaned_storage_query(storage_id_whitelist)
+ if len(orphaned_storage_ids) == 0:
+ # Nothing to GC.
+ return []
- placements_to_remove = list(ImageStoragePlacement
- .select(ImageStoragePlacement, ImageStorage)
- .join(ImageStorage)
- .where(ImageStorage.id << orphaned_storage_ids))
+ placements_to_remove = list(
+ ImageStoragePlacement.select(ImageStoragePlacement, ImageStorage)
+ .join(ImageStorage)
+ .where(ImageStorage.id << orphaned_storage_ids)
+ )
- # Remove the placements for orphaned storages
- if len(placements_to_remove) > 0:
- placement_ids_to_remove = [placement.id for placement in placements_to_remove]
- placements_removed = (ImageStoragePlacement
- .delete()
- .where(ImageStoragePlacement.id << placement_ids_to_remove)
- .execute())
- logger.debug('Removed %s image storage placements', placements_removed)
+ # Remove the placements for orphaned storages
+ if len(placements_to_remove) > 0:
+ placement_ids_to_remove = [
+ placement.id for placement in placements_to_remove
+ ]
+ placements_removed = (
+ ImageStoragePlacement.delete()
+ .where(ImageStoragePlacement.id << placement_ids_to_remove)
+ .execute()
+ )
+ logger.debug("Removed %s image storage placements", placements_removed)
- # Remove all orphaned storages
- torrents_removed = (TorrentInfo
- .delete()
- .where(TorrentInfo.storage << orphaned_storage_ids)
- .execute())
- logger.debug('Removed %s torrent info records', torrents_removed)
+ # Remove all orphaned storages
+ torrents_removed = (
+ TorrentInfo.delete()
+ .where(TorrentInfo.storage << orphaned_storage_ids)
+ .execute()
+ )
+ logger.debug("Removed %s torrent info records", torrents_removed)
- signatures_removed = (ImageStorageSignature
- .delete()
- .where(ImageStorageSignature.storage << orphaned_storage_ids)
- .execute())
- logger.debug('Removed %s image storage signatures', signatures_removed)
+ signatures_removed = (
+ ImageStorageSignature.delete()
+ .where(ImageStorageSignature.storage << orphaned_storage_ids)
+ .execute()
+ )
+ logger.debug("Removed %s image storage signatures", signatures_removed)
- storages_removed = (ImageStorage
- .delete()
- .where(ImageStorage.id << orphaned_storage_ids)
- .execute())
- logger.debug('Removed %s image storage records', storages_removed)
+ storages_removed = (
+ ImageStorage.delete()
+ .where(ImageStorage.id << orphaned_storage_ids)
+ .execute()
+ )
+ logger.debug("Removed %s image storage records", storages_removed)
- # Determine the paths to remove. We cannot simply remove all paths matching storages, as CAS
- # can share the same path. We further filter these paths by checking for any storages still in
- # the database with the same content checksum.
- paths_to_remove = placements_to_filtered_paths_set(placements_to_remove)
+ # Determine the paths to remove. We cannot simply remove all paths matching storages, as CAS
+ # can share the same path. We further filter these paths by checking for any storages still in
+ # the database with the same content checksum.
+ paths_to_remove = placements_to_filtered_paths_set(placements_to_remove)
- # We are going to make the conscious decision to not delete image storage blobs inside
- # transactions.
- # This may end up producing garbage in s3, trading off for higher availability in the database.
- for location_name, image_path in paths_to_remove:
- logger.debug('Removing %s from %s', image_path, location_name)
- config.store.remove({location_name}, image_path)
+ # We are going to make the conscious decision to not delete image storage blobs inside
+ # transactions.
+ # This may end up producing garbage in s3, trading off for higher availability in the database.
+ for location_name, image_path in paths_to_remove:
+ logger.debug("Removing %s from %s", image_path, location_name)
+ config.store.remove({location_name}, image_path)
- return orphaned_storage_ids
+ return orphaned_storage_ids
def create_v1_storage(location_name):
- storage = ImageStorage.create(cas_path=False, uploading=True)
- location = get_image_location_for_name(location_name)
- ImageStoragePlacement.create(location=location.id, storage=storage)
- storage.locations = {location_name}
- return storage
+ storage = ImageStorage.create(cas_path=False, uploading=True)
+ location = get_image_location_for_name(location_name)
+ ImageStoragePlacement.create(location=location.id, storage=storage)
+ storage.locations = {location_name}
+ return storage
def find_or_create_storage_signature(storage, signature_kind_name):
- found = lookup_storage_signature(storage, signature_kind_name)
- if found is None:
- kind = ImageStorageSignatureKind.get(name=signature_kind_name)
- found = ImageStorageSignature.create(storage=storage, kind=kind)
+ found = lookup_storage_signature(storage, signature_kind_name)
+ if found is None:
+ kind = ImageStorageSignatureKind.get(name=signature_kind_name)
+ found = ImageStorageSignature.create(storage=storage, kind=kind)
- return found
+ return found
def lookup_storage_signature(storage, signature_kind_name):
- kind = ImageStorageSignatureKind.get(name=signature_kind_name)
- try:
- return (ImageStorageSignature
- .select()
- .where(ImageStorageSignature.storage == storage, ImageStorageSignature.kind == kind)
- .get())
- except ImageStorageSignature.DoesNotExist:
- return None
+ kind = ImageStorageSignatureKind.get(name=signature_kind_name)
+ try:
+ return (
+ ImageStorageSignature.select()
+ .where(
+ ImageStorageSignature.storage == storage,
+ ImageStorageSignature.kind == kind,
+ )
+ .get()
+ )
+ except ImageStorageSignature.DoesNotExist:
+ return None
def _get_storage(query_modifier):
- query = (ImageStoragePlacement
- .select(ImageStoragePlacement, ImageStorage)
- .switch(ImageStoragePlacement)
- .join(ImageStorage))
+ query = (
+ ImageStoragePlacement.select(ImageStoragePlacement, ImageStorage)
+ .switch(ImageStoragePlacement)
+ .join(ImageStorage)
+ )
- placements = list(query_modifier(query))
+ placements = list(query_modifier(query))
- if not placements:
- raise InvalidImageException()
+ if not placements:
+ raise InvalidImageException()
- found = placements[0].storage
- found.locations = {get_image_location_for_id(placement.location_id).name
- for placement in placements}
- return found
+ found = placements[0].storage
+ found.locations = {
+ get_image_location_for_id(placement.location_id).name
+ for placement in placements
+ }
+ return found
def get_storage_by_uuid(storage_uuid):
- def filter_to_uuid(query):
- return query.where(ImageStorage.uuid == storage_uuid)
+ def filter_to_uuid(query):
+ return query.where(ImageStorage.uuid == storage_uuid)
- try:
- return _get_storage(filter_to_uuid)
- except InvalidImageException:
- raise InvalidImageException('No storage found with uuid: %s', storage_uuid)
+ try:
+ return _get_storage(filter_to_uuid)
+ except InvalidImageException:
+ raise InvalidImageException("No storage found with uuid: %s", storage_uuid)
def get_layer_path(storage_record):
- """ Returns the path in the storage engine to the layer data referenced by the storage row. """
- assert storage_record.cas_path is not None
- return get_layer_path_for_storage(storage_record.uuid, storage_record.cas_path,
- storage_record.content_checksum)
+ """ Returns the path in the storage engine to the layer data referenced by the storage row. """
+ assert storage_record.cas_path is not None
+ return get_layer_path_for_storage(
+ storage_record.uuid, storage_record.cas_path, storage_record.content_checksum
+ )
def get_layer_path_for_storage(storage_uuid, cas_path, content_checksum):
- """ Returns the path in the storage engine to the layer data referenced by the storage
+ """ Returns the path in the storage engine to the layer data referenced by the storage
information. """
- store = config.store
- if not cas_path:
- logger.debug('Serving layer from legacy v1 path for storage %s', storage_uuid)
- return store.v1_image_layer_path(storage_uuid)
+ store = config.store
+ if not cas_path:
+ logger.debug("Serving layer from legacy v1 path for storage %s", storage_uuid)
+ return store.v1_image_layer_path(storage_uuid)
- return store.blob_path(content_checksum)
+ return store.blob_path(content_checksum)
def lookup_repo_storages_by_content_checksum(repo, checksums, by_manifest=False):
- """ Looks up repository storages (without placements) matching the given repository
+ """ Looks up repository storages (without placements) matching the given repository
and checksum. """
- if not checksums:
- return []
+ if not checksums:
+ return []
- # There may be many duplicates of the checksums, so for performance reasons we are going
- # to use a union to select just one storage with each checksum
- queries = []
+ # There may be many duplicates of the checksums, so for performance reasons we are going
+ # to use a union to select just one storage with each checksum
+ queries = []
- for counter, checksum in enumerate(set(checksums)):
- query_alias = 'q{0}'.format(counter)
+ for counter, checksum in enumerate(set(checksums)):
+ query_alias = "q{0}".format(counter)
- # TODO: Remove once we have a new-style model for tracking temp uploaded blobs and
- # all legacy tables have been removed.
- if by_manifest:
- candidate_subq = (ImageStorage
- .select(ImageStorage.id, ImageStorage.content_checksum,
- ImageStorage.image_size, ImageStorage.uuid, ImageStorage.cas_path,
- ImageStorage.uncompressed_size, ImageStorage.uploading)
- .join(ManifestBlob)
- .where(ManifestBlob.repository == repo,
- ImageStorage.content_checksum == checksum)
- .limit(1)
- .alias(query_alias))
- else:
- candidate_subq = (ImageStorage
- .select(ImageStorage.id, ImageStorage.content_checksum,
- ImageStorage.image_size, ImageStorage.uuid, ImageStorage.cas_path,
- ImageStorage.uncompressed_size, ImageStorage.uploading)
- .join(Image)
- .where(Image.repository == repo, ImageStorage.content_checksum == checksum)
- .limit(1)
- .alias(query_alias))
+ # TODO: Remove once we have a new-style model for tracking temp uploaded blobs and
+ # all legacy tables have been removed.
+ if by_manifest:
+ candidate_subq = (
+ ImageStorage.select(
+ ImageStorage.id,
+ ImageStorage.content_checksum,
+ ImageStorage.image_size,
+ ImageStorage.uuid,
+ ImageStorage.cas_path,
+ ImageStorage.uncompressed_size,
+ ImageStorage.uploading,
+ )
+ .join(ManifestBlob)
+ .where(
+ ManifestBlob.repository == repo,
+ ImageStorage.content_checksum == checksum,
+ )
+ .limit(1)
+ .alias(query_alias)
+ )
+ else:
+ candidate_subq = (
+ ImageStorage.select(
+ ImageStorage.id,
+ ImageStorage.content_checksum,
+ ImageStorage.image_size,
+ ImageStorage.uuid,
+ ImageStorage.cas_path,
+ ImageStorage.uncompressed_size,
+ ImageStorage.uploading,
+ )
+ .join(Image)
+ .where(
+ Image.repository == repo, ImageStorage.content_checksum == checksum
+ )
+ .limit(1)
+ .alias(query_alias)
+ )
- queries.append(ImageStorage
- .select(SQL('*'))
- .from_(candidate_subq))
+ queries.append(ImageStorage.select(SQL("*")).from_(candidate_subq))
- return _basequery.reduce_as_tree(queries)
+ return _basequery.reduce_as_tree(queries)
-def set_image_storage_metadata(docker_image_id, namespace_name, repository_name, image_size,
- uncompressed_size):
- """ Sets metadata that is specific to the binary storage of the data, irrespective of how it
+def set_image_storage_metadata(
+ docker_image_id, namespace_name, repository_name, image_size, uncompressed_size
+):
+ """ Sets metadata that is specific to the binary storage of the data, irrespective of how it
is used in the layer tree.
"""
- if image_size is None:
- raise DataModelException('Empty image size field')
+ if image_size is None:
+ raise DataModelException("Empty image size field")
- try:
- image = (Image
- .select(Image, ImageStorage)
- .join(Repository)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .switch(Image)
- .join(ImageStorage)
- .where(Repository.name == repository_name, Namespace.username == namespace_name,
- Image.docker_image_id == docker_image_id)
- .get())
- except ImageStorage.DoesNotExist:
- raise InvalidImageException('No image with specified id and repository')
+ try:
+ image = (
+ Image.select(Image, ImageStorage)
+ .join(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .switch(Image)
+ .join(ImageStorage)
+ .where(
+ Repository.name == repository_name,
+ Namespace.username == namespace_name,
+ Image.docker_image_id == docker_image_id,
+ )
+ .get()
+ )
+ except ImageStorage.DoesNotExist:
+ raise InvalidImageException("No image with specified id and repository")
- # We MUST do this here, it can't be done in the corresponding image call because the storage
- # has not yet been pushed
- image.aggregate_size = _basequery.calculate_image_aggregate_size(image.ancestors, image_size,
- image.parent)
- image.save()
+ # We MUST do this here; it can't be done in the corresponding image call because the storage
+ # has not yet been pushed
+ image.aggregate_size = _basequery.calculate_image_aggregate_size(
+ image.ancestors, image_size, image.parent
+ )
+ image.save()
- image.storage.image_size = image_size
- image.storage.uncompressed_size = uncompressed_size
- image.storage.save()
- return image.storage
+ image.storage.image_size = image_size
+ image.storage.uncompressed_size = uncompressed_size
+ image.storage.save()
+ return image.storage
def get_storage_locations(uuid):
- query = (ImageStoragePlacement
- .select()
- .join(ImageStorage)
- .where(ImageStorage.uuid == uuid))
+ query = (
+ ImageStoragePlacement.select()
+ .join(ImageStorage)
+ .where(ImageStorage.uuid == uuid)
+ )
- return [get_image_location_for_id(placement.location_id).name for placement in query]
+ return [
+ get_image_location_for_id(placement.location_id).name for placement in query
+ ]
def save_torrent_info(storage_object, piece_length, pieces):
- try:
- return TorrentInfo.get(storage=storage_object, piece_length=piece_length)
- except TorrentInfo.DoesNotExist:
try:
- return TorrentInfo.create(storage=storage_object, piece_length=piece_length, pieces=pieces)
- except IntegrityError:
- # TorrentInfo already exists for this storage.
- return TorrentInfo.get(storage=storage_object, piece_length=piece_length)
+ return TorrentInfo.get(storage=storage_object, piece_length=piece_length)
+ except TorrentInfo.DoesNotExist:
+ try:
+ return TorrentInfo.create(
+ storage=storage_object, piece_length=piece_length, pieces=pieces
+ )
+ except IntegrityError:
+ # TorrentInfo already exists for this storage.
+ return TorrentInfo.get(storage=storage_object, piece_length=piece_length)
def get_torrent_info(blob):
- try:
- return (TorrentInfo
- .select()
- .where(TorrentInfo.storage == blob)
- .get())
- except TorrentInfo.DoesNotExist:
- raise TorrentInfoDoesNotExist
+ try:
+ return TorrentInfo.select().where(TorrentInfo.storage == blob).get()
+ except TorrentInfo.DoesNotExist:
+ raise TorrentInfoDoesNotExist
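
save_torrent_info above (and, in simpler form, add_storage_placement) relies on a race-tolerant get-or-create idiom: read first, create on a miss, and re-read if a concurrent writer triggers IntegrityError. A self-contained sketch of that idiom, using an in-memory SQLite database and a hypothetical model:

from peewee import SqliteDatabase, Model, CharField, IntegrityError

db = SqliteDatabase(":memory:")

class Setting(Model):
    key = CharField(unique=True)
    value = CharField()

    class Meta:
        database = db

def save_setting(key, value):
    try:
        return Setting.get(Setting.key == key)
    except Setting.DoesNotExist:
        try:
            return Setting.create(key=key, value=value)
        except IntegrityError:
            # A concurrent writer created the row between our read and our insert.
            return Setting.get(Setting.key == key)

db.create_tables([Setting])
print(save_setting("piece_length", "1048576").value)
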
diff --git a/data/model/tag.py b/data/model/tag.py
index 437a9765b..a9e043dd7 100644
--- a/data/model/tag.py
+++ b/data/model/tag.py
@@ -5,14 +5,39 @@ from datetime import datetime
from uuid import uuid4
from peewee import IntegrityError, JOIN, fn
-from data.model import (image, storage, db_transaction, DataModelException, _basequery,
- InvalidManifestException, TagAlreadyCreatedException, StaleTagException,
- config)
-from data.database import (RepositoryTag, Repository, Image, ImageStorage, Namespace, TagManifest,
- RepositoryNotification, Label, TagManifestLabel, get_epoch_timestamp,
- db_for_update, Manifest, ManifestLabel, ManifestBlob,
- ManifestLegacyImage, TagManifestToManifest,
- TagManifestLabelMap, TagToRepositoryTag, Tag, get_epoch_timestamp_ms)
+from data.model import (
+ image,
+ storage,
+ db_transaction,
+ DataModelException,
+ _basequery,
+ InvalidManifestException,
+ TagAlreadyCreatedException,
+ StaleTagException,
+ config,
+)
+from data.database import (
+ RepositoryTag,
+ Repository,
+ Image,
+ ImageStorage,
+ Namespace,
+ TagManifest,
+ RepositoryNotification,
+ Label,
+ TagManifestLabel,
+ get_epoch_timestamp,
+ db_for_update,
+ Manifest,
+ ManifestLabel,
+ ManifestBlob,
+ ManifestLegacyImage,
+ TagManifestToManifest,
+ TagManifestLabelMap,
+ TagToRepositoryTag,
+ Tag,
+ get_epoch_timestamp_ms,
+)
from util.timedeltastring import convert_to_timedelta
@@ -20,797 +45,971 @@ logger = logging.getLogger(__name__)
def get_max_id_for_sec_scan():
- """ Gets the maximum id for security scanning """
- return RepositoryTag.select(fn.Max(RepositoryTag.id)).scalar()
+ """ Gets the maximum id for security scanning """
+ return RepositoryTag.select(fn.Max(RepositoryTag.id)).scalar()
def get_min_id_for_sec_scan(version):
- """ Gets the minimum id for a security scanning """
- return _tag_alive(RepositoryTag
- .select(fn.Min(RepositoryTag.id))
- .join(Image)
- .where(Image.security_indexed_engine < version)).scalar()
+ """ Gets the minimum id for a security scanning """
+ return _tag_alive(
+ RepositoryTag.select(fn.Min(RepositoryTag.id))
+ .join(Image)
+ .where(Image.security_indexed_engine < version)
+ ).scalar()
def get_tag_pk_field():
- """ Returns the primary key for Image DB model """
- return RepositoryTag.id
+ """ Returns the primary key for Image DB model """
+ return RepositoryTag.id
def get_tags_images_eligible_for_scan(clair_version):
- Parent = Image.alias()
- ParentImageStorage = ImageStorage.alias()
+ Parent = Image.alias()
+ ParentImageStorage = ImageStorage.alias()
- return _tag_alive(RepositoryTag
- .select(Image, ImageStorage, Parent, ParentImageStorage, RepositoryTag)
- .join(Image, on=(RepositoryTag.image == Image.id))
- .join(ImageStorage, on=(Image.storage == ImageStorage.id))
- .switch(Image)
- .join(Parent, JOIN.LEFT_OUTER, on=(Image.parent == Parent.id))
- .join(ParentImageStorage, JOIN.LEFT_OUTER, on=(ParentImageStorage.id == Parent.storage))
- .where(RepositoryTag.hidden == False)
- .where(Image.security_indexed_engine < clair_version))
+ return _tag_alive(
+ RepositoryTag.select(
+ Image, ImageStorage, Parent, ParentImageStorage, RepositoryTag
+ )
+ .join(Image, on=(RepositoryTag.image == Image.id))
+ .join(ImageStorage, on=(Image.storage == ImageStorage.id))
+ .switch(Image)
+ .join(Parent, JOIN.LEFT_OUTER, on=(Image.parent == Parent.id))
+ .join(
+ ParentImageStorage,
+ JOIN.LEFT_OUTER,
+ on=(ParentImageStorage.id == Parent.storage),
+ )
+ .where(RepositoryTag.hidden == False)
+ .where(Image.security_indexed_engine < clair_version)
+ )
def _tag_alive(query, now_ts=None):
- if now_ts is None:
- now_ts = get_epoch_timestamp()
- return query.where((RepositoryTag.lifetime_end_ts >> None) |
- (RepositoryTag.lifetime_end_ts > now_ts))
+ if now_ts is None:
+ now_ts = get_epoch_timestamp()
+ return query.where(
+ (RepositoryTag.lifetime_end_ts >> None)
+ | (RepositoryTag.lifetime_end_ts > now_ts)
+ )
def filter_has_repository_event(query, event):
- """ Filters the query by ensuring the repositories returned have the given event. """
- return (query
- .join(Repository)
- .join(RepositoryNotification)
- .where(RepositoryNotification.event == event))
+ """ Filters the query by ensuring the repositories returned have the given event. """
+ return (
+ query.join(Repository)
+ .join(RepositoryNotification)
+ .where(RepositoryNotification.event == event)
+ )
def filter_tags_have_repository_event(query, event):
- """ Filters the query by ensuring the repository tags live in a repository that has the given
+ """ Filters the query by ensuring the repository tags live in a repository that has the given
event. Also returns the image storage for the tag's image and orders the results by
lifetime_start_ts.
"""
- query = filter_has_repository_event(query, event)
- query = query.switch(RepositoryTag).join(Image).join(ImageStorage)
- query = query.switch(RepositoryTag).order_by(RepositoryTag.lifetime_start_ts.desc())
- return query
+ query = filter_has_repository_event(query, event)
+ query = query.switch(RepositoryTag).join(Image).join(ImageStorage)
+ query = query.switch(RepositoryTag).order_by(RepositoryTag.lifetime_start_ts.desc())
+ return query
_MAX_SUB_QUERIES = 100
_MAX_IMAGE_LOOKUP_COUNT = 500
-def get_matching_tags_for_images(image_pairs, filter_images=None, filter_tags=None,
- selections=None):
- """ Returns all tags that contain the images with the given docker_image_id and storage_uuid,
+
+def get_matching_tags_for_images(
+ image_pairs, filter_images=None, filter_tags=None, selections=None
+):
+ """ Returns all tags that contain the images with the given docker_image_id and storage_uuid,
as specified as an iterable of pairs. """
- if not image_pairs:
- return []
+ if not image_pairs:
+ return []
- image_pairs_set = set(image_pairs)
+ image_pairs_set = set(image_pairs)
- # Find all possible matching image+storages.
- images = []
+ # Find all possible matching image+storages.
+ images = []
- while image_pairs:
- image_pairs_slice = image_pairs[:_MAX_IMAGE_LOOKUP_COUNT]
+ while image_pairs:
+ image_pairs_slice = image_pairs[:_MAX_IMAGE_LOOKUP_COUNT]
- ids = [pair[0] for pair in image_pairs_slice]
- uuids = [pair[1] for pair in image_pairs_slice]
+ ids = [pair[0] for pair in image_pairs_slice]
+ uuids = [pair[1] for pair in image_pairs_slice]
- images_query = (Image
- .select(Image.id, Image.docker_image_id, Image.ancestors, ImageStorage.uuid)
- .join(ImageStorage)
- .where(Image.docker_image_id << ids, ImageStorage.uuid << uuids)
- .switch(Image))
+ images_query = (
+ Image.select(
+ Image.id, Image.docker_image_id, Image.ancestors, ImageStorage.uuid
+ )
+ .join(ImageStorage)
+ .where(Image.docker_image_id << ids, ImageStorage.uuid << uuids)
+ .switch(Image)
+ )
- if filter_images is not None:
- images_query = filter_images(images_query)
+ if filter_images is not None:
+ images_query = filter_images(images_query)
- images.extend(list(images_query))
- image_pairs = image_pairs[_MAX_IMAGE_LOOKUP_COUNT:]
+ images.extend(list(images_query))
+ image_pairs = image_pairs[_MAX_IMAGE_LOOKUP_COUNT:]
- # Filter down to those images actually in the pairs set and build the set of queries to run.
- individual_image_queries = []
+ # Filter down to those images actually in the pairs set and build the set of queries to run.
+ individual_image_queries = []
- for img in images:
- # Make sure the image found is in the set of those requested, and that we haven't already
- # processed it. We need this check because the query above checks for images with matching
- # IDs OR storage UUIDs, rather than the expected ID+UUID pair. We do this for efficiency
- # reasons, and it is highly unlikely we'll find an image with a mismatch, but we need this
- # check to be absolutely sure.
- pair = (img.docker_image_id, img.storage.uuid)
- if pair not in image_pairs_set:
- continue
+ for img in images:
+ # Make sure the image found is in the set of those requested, and that we haven't already
+ # processed it. We need this check because the query above checks for images with matching
+ # IDs OR storage UUIDs, rather than the expected ID+UUID pair. We do this for efficiency
+ # reasons, and it is highly unlikely we'll find an image with a mismatch, but we need this
+ # check to be absolutely sure.
+ pair = (img.docker_image_id, img.storage.uuid)
+ if pair not in image_pairs_set:
+ continue
- # Remove the pair so we don't try it again.
- image_pairs_set.remove(pair)
+ # Remove the pair so we don't try it again.
+ image_pairs_set.remove(pair)
- ancestors_str = '%s%s/%%' % (img.ancestors, img.id)
- query = (Image
- .select(Image.id)
- .where((Image.id == img.id) | (Image.ancestors ** ancestors_str)))
+ ancestors_str = "%s%s/%%" % (img.ancestors, img.id)
+ query = Image.select(Image.id).where(
+ (Image.id == img.id) | (Image.ancestors ** ancestors_str)
+ )
- individual_image_queries.append(query)
+ individual_image_queries.append(query)
- if not individual_image_queries:
- return []
+ if not individual_image_queries:
+ return []
- # Shard based on the max subquery count. This is used to prevent going over the DB's max query
- # size, as well as to prevent the DB from locking up on a massive query.
- sharded_queries = []
- while individual_image_queries:
- shard = individual_image_queries[:_MAX_SUB_QUERIES]
- sharded_queries.append(_basequery.reduce_as_tree(shard))
- individual_image_queries = individual_image_queries[_MAX_SUB_QUERIES:]
+ # Shard based on the max subquery count. This is used to prevent going over the DB's max query
+ # size, as well as to prevent the DB from locking up on a massive query.
+ sharded_queries = []
+ while individual_image_queries:
+ shard = individual_image_queries[:_MAX_SUB_QUERIES]
+ sharded_queries.append(_basequery.reduce_as_tree(shard))
+ individual_image_queries = individual_image_queries[_MAX_SUB_QUERIES:]
- # Collect IDs of the tags found for each query.
- tags = {}
- for query in sharded_queries:
- ImageAlias = Image.alias()
- tag_query = (_tag_alive(RepositoryTag
- .select(*(selections or []))
- .distinct()
- .join(ImageAlias)
- .where(RepositoryTag.hidden == False)
- .where(ImageAlias.id << query)
- .switch(RepositoryTag)))
+ # Collect IDs of the tags found for each query.
+ tags = {}
+ for query in sharded_queries:
+ ImageAlias = Image.alias()
+ tag_query = _tag_alive(
+ RepositoryTag.select(*(selections or []))
+ .distinct()
+ .join(ImageAlias)
+ .where(RepositoryTag.hidden == False)
+ .where(ImageAlias.id << query)
+ .switch(RepositoryTag)
+ )
- if filter_tags is not None:
- tag_query = filter_tags(tag_query)
+ if filter_tags is not None:
+ tag_query = filter_tags(tag_query)
- for tag in tag_query:
- tags[tag.id] = tag
+ for tag in tag_query:
+ tags[tag.id] = tag
- return tags.values()
+ return tags.values()
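
The slicing loops above (_MAX_IMAGE_LOOKUP_COUNT for the image lookups, _MAX_SUB_QUERIES for the reduced tag queries) both follow the same consume-in-fixed-slices pattern, kept small so no single SQL statement exceeds the database's query-size limits. A trivial, self-contained sketch of that pattern (the item count is hypothetical):

_MAX_SUB_QUERIES = 100

def sharded(items, size=_MAX_SUB_QUERIES):
    # Consume the list front-to-back in fixed-size slices, as the loops above do.
    while items:
        yield items[:size]
        items = items[size:]

print([len(shard) for shard in sharded(list(range(250)))])  # [100, 100, 50]
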
def get_matching_tags(docker_image_id, storage_uuid, *args):
- """ Returns a query pointing to all tags that contain the image with the
+ """ Returns a query pointing to all tags that contain the image with the
given docker_image_id and storage_uuid. """
- image_row = image.get_image_with_storage(docker_image_id, storage_uuid)
- if image_row is None:
- return RepositoryTag.select().where(RepositoryTag.id < 0) # Empty query.
+ image_row = image.get_image_with_storage(docker_image_id, storage_uuid)
+ if image_row is None:
+ return RepositoryTag.select().where(RepositoryTag.id < 0) # Empty query.
- ancestors_str = '%s%s/%%' % (image_row.ancestors, image_row.id)
- return _tag_alive(RepositoryTag
- .select(*args)
- .distinct()
- .join(Image)
- .join(ImageStorage)
- .where(RepositoryTag.hidden == False)
- .where((Image.id == image_row.id) |
- (Image.ancestors ** ancestors_str)))
+ ancestors_str = "%s%s/%%" % (image_row.ancestors, image_row.id)
+ return _tag_alive(
+ RepositoryTag.select(*args)
+ .distinct()
+ .join(Image)
+ .join(ImageStorage)
+ .where(RepositoryTag.hidden == False)
+ .where((Image.id == image_row.id) | (Image.ancestors ** ancestors_str))
+ )
def get_tags_for_image(image_id, *args):
- return _tag_alive(RepositoryTag
- .select(*args)
- .distinct()
- .where(RepositoryTag.image == image_id,
- RepositoryTag.hidden == False))
+ return _tag_alive(
+ RepositoryTag.select(*args)
+ .distinct()
+ .where(RepositoryTag.image == image_id, RepositoryTag.hidden == False)
+ )
def get_tag_manifest_digests(tags):
- """ Returns a map from tag ID to its associated manifest digest, if any. """
- if not tags:
- return dict()
+ """ Returns a map from tag ID to its associated manifest digest, if any. """
+ if not tags:
+ return dict()
- manifests = (TagManifest
- .select(TagManifest.tag, TagManifest.digest)
- .where(TagManifest.tag << [t.id for t in tags]))
+ manifests = TagManifest.select(TagManifest.tag, TagManifest.digest).where(
+ TagManifest.tag << [t.id for t in tags]
+ )
- return {manifest.tag_id: manifest.digest for manifest in manifests}
+ return {manifest.tag_id: manifest.digest for manifest in manifests}
def list_active_repo_tags(repo, start_id=None, limit=None, include_images=True):
- """ Returns all of the active, non-hidden tags in a repository, joined to they images
+ """ Returns all of the active, non-hidden tags in a repository, joined to they images
and (if present), their manifest.
"""
- if include_images:
- query = _tag_alive(RepositoryTag
- .select(RepositoryTag, Image, ImageStorage, TagManifest.digest)
- .join(Image)
- .join(ImageStorage)
- .where(RepositoryTag.repository == repo, RepositoryTag.hidden == False)
- .switch(RepositoryTag)
- .join(TagManifest, JOIN.LEFT_OUTER)
- .order_by(RepositoryTag.id))
- else:
- query = _tag_alive(RepositoryTag
- .select(RepositoryTag)
- .where(RepositoryTag.repository == repo, RepositoryTag.hidden == False)
- .order_by(RepositoryTag.id))
+ if include_images:
+ query = _tag_alive(
+ RepositoryTag.select(RepositoryTag, Image, ImageStorage, TagManifest.digest)
+ .join(Image)
+ .join(ImageStorage)
+ .where(RepositoryTag.repository == repo, RepositoryTag.hidden == False)
+ .switch(RepositoryTag)
+ .join(TagManifest, JOIN.LEFT_OUTER)
+ .order_by(RepositoryTag.id)
+ )
+ else:
+ query = _tag_alive(
+ RepositoryTag.select(RepositoryTag)
+ .where(RepositoryTag.repository == repo, RepositoryTag.hidden == False)
+ .order_by(RepositoryTag.id)
+ )
- if start_id is not None:
- query = query.where(RepositoryTag.id >= start_id)
+ if start_id is not None:
+ query = query.where(RepositoryTag.id >= start_id)
- if limit is not None:
- query = query.limit(limit)
+ if limit is not None:
+ query = query.limit(limit)
- return query
+ return query
-def list_repository_tags(namespace_name, repository_name, include_hidden=False,
- include_storage=False):
- to_select = (RepositoryTag, Image)
- if include_storage:
- to_select = (RepositoryTag, Image, ImageStorage)
+def list_repository_tags(
+ namespace_name, repository_name, include_hidden=False, include_storage=False
+):
+ to_select = (RepositoryTag, Image)
+ if include_storage:
+ to_select = (RepositoryTag, Image, ImageStorage)
- query = _tag_alive(RepositoryTag
- .select(*to_select)
- .join(Repository)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .switch(RepositoryTag)
- .join(Image)
- .where(Repository.name == repository_name,
- Namespace.username == namespace_name))
+ query = _tag_alive(
+ RepositoryTag.select(*to_select)
+ .join(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .switch(RepositoryTag)
+ .join(Image)
+ .where(Repository.name == repository_name, Namespace.username == namespace_name)
+ )
- if not include_hidden:
- query = query.where(RepositoryTag.hidden == False)
+ if not include_hidden:
+ query = query.where(RepositoryTag.hidden == False)
- if include_storage:
- query = query.switch(Image).join(ImageStorage)
+ if include_storage:
+ query = query.switch(Image).join(ImageStorage)
- return query
+ return query
-def create_or_update_tag(namespace_name, repository_name, tag_name, tag_docker_image_id,
- reversion=False, now_ms=None):
- try:
- repo = _basequery.get_existing_repository(namespace_name, repository_name)
- except Repository.DoesNotExist:
- raise DataModelException('Invalid repository %s/%s' % (namespace_name, repository_name))
-
- return create_or_update_tag_for_repo(repo.id, tag_name, tag_docker_image_id, reversion=reversion,
- now_ms=now_ms)
-
-def create_or_update_tag_for_repo(repository_id, tag_name, tag_docker_image_id, reversion=False,
- oci_manifest=None, now_ms=None):
- now_ms = now_ms or get_epoch_timestamp_ms()
- now_ts = int(now_ms / 1000)
-
- with db_transaction():
+def create_or_update_tag(
+ namespace_name,
+ repository_name,
+ tag_name,
+ tag_docker_image_id,
+ reversion=False,
+ now_ms=None,
+):
try:
- tag = db_for_update(_tag_alive(RepositoryTag
- .select()
- .where(RepositoryTag.repository == repository_id,
- RepositoryTag.name == tag_name), now_ts)).get()
- tag.lifetime_end_ts = now_ts
- tag.save()
+ repo = _basequery.get_existing_repository(namespace_name, repository_name)
+ except Repository.DoesNotExist:
+ raise DataModelException(
+ "Invalid repository %s/%s" % (namespace_name, repository_name)
+ )
- # Check for an OCI tag.
- try:
- oci_tag = db_for_update(Tag
- .select()
- .join(TagToRepositoryTag)
- .where(TagToRepositoryTag.repository_tag == tag)).get()
- oci_tag.lifetime_end_ms = now_ms
- oci_tag.save()
- except Tag.DoesNotExist:
- pass
- except RepositoryTag.DoesNotExist:
- pass
- except IntegrityError:
- msg = 'Tag with name %s was stale when we tried to update it; Please retry the push'
- raise StaleTagException(msg % tag_name)
+ return create_or_update_tag_for_repo(
+ repo.id, tag_name, tag_docker_image_id, reversion=reversion, now_ms=now_ms
+ )
- try:
- image_obj = Image.get(Image.docker_image_id == tag_docker_image_id,
- Image.repository == repository_id)
- except Image.DoesNotExist:
- raise DataModelException('Invalid image with id: %s' % tag_docker_image_id)
- try:
- created = RepositoryTag.create(repository=repository_id, image=image_obj, name=tag_name,
- lifetime_start_ts=now_ts, reversion=reversion)
- if oci_manifest:
- # Create the OCI tag as well.
- oci_tag = Tag.create(repository=repository_id, manifest=oci_manifest, name=tag_name,
- lifetime_start_ms=now_ms, reversion=reversion,
- tag_kind=Tag.tag_kind.get_id('tag'))
- TagToRepositoryTag.create(tag=oci_tag, repository_tag=created, repository=repository_id)
+def create_or_update_tag_for_repo(
+ repository_id,
+ tag_name,
+ tag_docker_image_id,
+ reversion=False,
+ oci_manifest=None,
+ now_ms=None,
+):
+ now_ms = now_ms or get_epoch_timestamp_ms()
+ now_ts = int(now_ms / 1000)
- return created
- except IntegrityError:
- msg = 'Tag with name %s and lifetime start %s already exists'
- raise TagAlreadyCreatedException(msg % (tag_name, now_ts))
+ with db_transaction():
+ try:
+ tag = db_for_update(
+ _tag_alive(
+ RepositoryTag.select().where(
+ RepositoryTag.repository == repository_id,
+ RepositoryTag.name == tag_name,
+ ),
+ now_ts,
+ )
+ ).get()
+ tag.lifetime_end_ts = now_ts
+ tag.save()
+
+ # Check for an OCI tag.
+ try:
+ oci_tag = db_for_update(
+ Tag.select()
+ .join(TagToRepositoryTag)
+ .where(TagToRepositoryTag.repository_tag == tag)
+ ).get()
+ oci_tag.lifetime_end_ms = now_ms
+ oci_tag.save()
+ except Tag.DoesNotExist:
+ pass
+ except RepositoryTag.DoesNotExist:
+ pass
+ except IntegrityError:
+ msg = "Tag with name %s was stale when we tried to update it; Please retry the push"
+ raise StaleTagException(msg % tag_name)
+
+ try:
+ image_obj = Image.get(
+ Image.docker_image_id == tag_docker_image_id,
+ Image.repository == repository_id,
+ )
+ except Image.DoesNotExist:
+ raise DataModelException("Invalid image with id: %s" % tag_docker_image_id)
+
+ try:
+ created = RepositoryTag.create(
+ repository=repository_id,
+ image=image_obj,
+ name=tag_name,
+ lifetime_start_ts=now_ts,
+ reversion=reversion,
+ )
+ if oci_manifest:
+ # Create the OCI tag as well.
+ oci_tag = Tag.create(
+ repository=repository_id,
+ manifest=oci_manifest,
+ name=tag_name,
+ lifetime_start_ms=now_ms,
+ reversion=reversion,
+ tag_kind=Tag.tag_kind.get_id("tag"),
+ )
+ TagToRepositoryTag.create(
+ tag=oci_tag, repository_tag=created, repository=repository_id
+ )
+
+ return created
+ except IntegrityError:
+ msg = "Tag with name %s and lifetime start %s already exists"
+ raise TagAlreadyCreatedException(msg % (tag_name, now_ts))
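+ # Usage sketch (hypothetical values): re-pointing a tag never rewrites the old row in
+ # place; the previous RepositoryTag gets lifetime_end_ts = now_ts and a fresh row starts
+ # at the same timestamp, e.g.
+ #   create_or_update_tag_for_repo(repo.id, "latest", "abc123", now_ms=1500000000000)
+ # which is what lets list_repository_tag_history and the restore_tag_to_* helpers below
+ # see older states of the tag.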
def create_temporary_hidden_tag(repo, image_obj, expiration_s):
- """ Create a tag with a defined timeline, that will not appear in the UI or CLI. Returns the name
+ """ Create a tag with a defined timeline, that will not appear in the UI or CLI. Returns the name
of the temporary tag. """
- now_ts = get_epoch_timestamp()
- expire_ts = now_ts + expiration_s
- tag_name = str(uuid4())
- RepositoryTag.create(repository=repo, image=image_obj, name=tag_name, lifetime_start_ts=now_ts,
- lifetime_end_ts=expire_ts, hidden=True)
- return tag_name
+ now_ts = get_epoch_timestamp()
+ expire_ts = now_ts + expiration_s
+ tag_name = str(uuid4())
+ RepositoryTag.create(
+ repository=repo,
+ image=image_obj,
+ name=tag_name,
+ lifetime_start_ts=now_ts,
+ lifetime_end_ts=expire_ts,
+ hidden=True,
+ )
+ return tag_name
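+ # Note: hidden tags created here still count as "alive" for _tag_alive, which is
+ # presumably how callers keep freshly pushed images referenced (and safe from cleanup)
+ # until a real tag is applied or expiration_s elapses.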
def lookup_unrecoverable_tags(repo):
- """ Returns the tags in a repository that are expired and past their time machine recovery
+ """ Returns the tags in a repository that are expired and past their time machine recovery
period. """
- expired_clause = get_epoch_timestamp() - Namespace.removed_tag_expiration_s
- return (RepositoryTag
- .select()
- .join(Repository)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .where(RepositoryTag.repository == repo)
- .where(~(RepositoryTag.lifetime_end_ts >> None),
- RepositoryTag.lifetime_end_ts <= expired_clause))
+ expired_clause = get_epoch_timestamp() - Namespace.removed_tag_expiration_s
+ return (
+ RepositoryTag.select()
+ .join(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .where(RepositoryTag.repository == repo)
+ .where(
+ ~(RepositoryTag.lifetime_end_ts >> None),
+ RepositoryTag.lifetime_end_ts <= expired_clause,
+ )
+ )
def delete_tag(namespace_name, repository_name, tag_name, now_ms=None):
- now_ms = now_ms or get_epoch_timestamp_ms()
- now_ts = int(now_ms / 1000)
+ now_ms = now_ms or get_epoch_timestamp_ms()
+ now_ts = int(now_ms / 1000)
- with db_transaction():
- try:
- query = _tag_alive(RepositoryTag
- .select(RepositoryTag, Repository)
- .join(Repository)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .where(Repository.name == repository_name,
- Namespace.username == namespace_name,
- RepositoryTag.name == tag_name), now_ts)
- found = db_for_update(query).get()
- except RepositoryTag.DoesNotExist:
- msg = ('Invalid repository tag \'%s\' on repository \'%s/%s\'' %
- (tag_name, namespace_name, repository_name))
- raise DataModelException(msg)
+ with db_transaction():
+ try:
+ query = _tag_alive(
+ RepositoryTag.select(RepositoryTag, Repository)
+ .join(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .where(
+ Repository.name == repository_name,
+ Namespace.username == namespace_name,
+ RepositoryTag.name == tag_name,
+ ),
+ now_ts,
+ )
+ found = db_for_update(query).get()
+ except RepositoryTag.DoesNotExist:
+ msg = "Invalid repository tag '%s' on repository '%s/%s'" % (
+ tag_name,
+ namespace_name,
+ repository_name,
+ )
+ raise DataModelException(msg)
- found.lifetime_end_ts = now_ts
- found.save()
+ found.lifetime_end_ts = now_ts
+ found.save()
- try:
- oci_tag_query = TagToRepositoryTag.select().where(TagToRepositoryTag.repository_tag == found)
- oci_tag = db_for_update(oci_tag_query).get().tag
- oci_tag.lifetime_end_ms = now_ms
- oci_tag.save()
- except TagToRepositoryTag.DoesNotExist:
- pass
+ try:
+ oci_tag_query = TagToRepositoryTag.select().where(
+ TagToRepositoryTag.repository_tag == found
+ )
+ oci_tag = db_for_update(oci_tag_query).get().tag
+ oci_tag.lifetime_end_ms = now_ms
+ oci_tag.save()
+ except TagToRepositoryTag.DoesNotExist:
+ pass
- return found
+ return found
def _get_repo_tag_image(tag_name, include_storage, modifier):
- query = Image.select().join(RepositoryTag)
+ query = Image.select().join(RepositoryTag)
- if include_storage:
- query = (Image
- .select(Image, ImageStorage)
- .join(ImageStorage)
- .switch(Image)
- .join(RepositoryTag))
+ if include_storage:
+ query = (
+ Image.select(Image, ImageStorage)
+ .join(ImageStorage)
+ .switch(Image)
+ .join(RepositoryTag)
+ )
- images = _tag_alive(modifier(query.where(RepositoryTag.name == tag_name)))
- if not images:
- raise DataModelException('Unable to find image for tag.')
- else:
- return images[0]
+ images = _tag_alive(modifier(query.where(RepositoryTag.name == tag_name)))
+ if not images:
+ raise DataModelException("Unable to find image for tag.")
+ else:
+ return images[0]
def get_repo_tag_image(repo, tag_name, include_storage=False):
- def modifier(query):
- return query.where(RepositoryTag.repository == repo)
+ def modifier(query):
+ return query.where(RepositoryTag.repository == repo)
- return _get_repo_tag_image(tag_name, include_storage, modifier)
+ return _get_repo_tag_image(tag_name, include_storage, modifier)
def get_tag_image(namespace_name, repository_name, tag_name, include_storage=False):
- def modifier(query):
- return (query
- .switch(RepositoryTag)
+ def modifier(query):
+ return (
+ query.switch(RepositoryTag)
.join(Repository)
.join(Namespace)
- .where(Namespace.username == namespace_name, Repository.name == repository_name))
+ .where(
+ Namespace.username == namespace_name, Repository.name == repository_name
+ )
+ )
- return _get_repo_tag_image(tag_name, include_storage, modifier)
+ return _get_repo_tag_image(tag_name, include_storage, modifier)
-def list_repository_tag_history(repo_obj, page=1, size=100, specific_tag=None, active_tags_only=False, since_time=None):
- # Only available on OCI model
- if since_time is not None:
- raise NotImplementedError
+def list_repository_tag_history(
+ repo_obj,
+ page=1,
+ size=100,
+ specific_tag=None,
+ active_tags_only=False,
+ since_time=None,
+):
+ # Only available on OCI model
+ if since_time is not None:
+ raise NotImplementedError
- query = (RepositoryTag
- .select(RepositoryTag, Image, ImageStorage)
- .join(Image)
- .join(ImageStorage)
- .switch(RepositoryTag)
- .where(RepositoryTag.repository == repo_obj)
- .where(RepositoryTag.hidden == False)
- .order_by(RepositoryTag.lifetime_start_ts.desc(), RepositoryTag.name)
- .limit(size + 1)
- .offset(size * (page - 1)))
+ query = (
+ RepositoryTag.select(RepositoryTag, Image, ImageStorage)
+ .join(Image)
+ .join(ImageStorage)
+ .switch(RepositoryTag)
+ .where(RepositoryTag.repository == repo_obj)
+ .where(RepositoryTag.hidden == False)
+ .order_by(RepositoryTag.lifetime_start_ts.desc(), RepositoryTag.name)
+ .limit(size + 1)
+ .offset(size * (page - 1))
+ )
- if active_tags_only:
- query = _tag_alive(query)
+ if active_tags_only:
+ query = _tag_alive(query)
- if specific_tag:
- query = query.where(RepositoryTag.name == specific_tag)
+ if specific_tag:
+ query = query.where(RepositoryTag.name == specific_tag)
- tags = list(query)
- if not tags:
- return [], {}, False
+ tags = list(query)
+ if not tags:
+ return [], {}, False
- manifest_map = get_tag_manifest_digests(tags)
- return tags[0:size], manifest_map, len(tags) > size
+ manifest_map = get_tag_manifest_digests(tags)
+ return tags[0:size], manifest_map, len(tags) > size
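+ # Note: the query above fetches size + 1 rows so the final element of the returned tuple
+ # can report whether another page exists without a separate COUNT query; callers
+ # presumably request page + 1 with the same size to continue paging.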
def restore_tag_to_manifest(repo_obj, tag_name, manifest_digest):
- """ Restores a tag to a specific manifest digest. """
- with db_transaction():
- # Verify that the manifest digest already existed under this repository under the
- # tag.
- try:
- tag_manifest = (TagManifest
- .select(TagManifest, RepositoryTag, Image)
- .join(RepositoryTag)
- .join(Image)
- .where(RepositoryTag.repository == repo_obj)
- .where(RepositoryTag.name == tag_name)
- .where(TagManifest.digest == manifest_digest)
- .get())
- except TagManifest.DoesNotExist:
- raise DataModelException('Cannot restore to unknown or invalid digest')
+ """ Restores a tag to a specific manifest digest. """
+ with db_transaction():
+ # Verify that the manifest digest already existed under this tag in the
+ # repository.
+ try:
+ tag_manifest = (
+ TagManifest.select(TagManifest, RepositoryTag, Image)
+ .join(RepositoryTag)
+ .join(Image)
+ .where(RepositoryTag.repository == repo_obj)
+ .where(RepositoryTag.name == tag_name)
+ .where(TagManifest.digest == manifest_digest)
+ .get()
+ )
+ except TagManifest.DoesNotExist:
+ raise DataModelException("Cannot restore to unknown or invalid digest")
- # Lookup the existing image, if any.
- try:
- existing_image = get_repo_tag_image(repo_obj, tag_name)
- except DataModelException:
- existing_image = None
+ # Lookup the existing image, if any.
+ try:
+ existing_image = get_repo_tag_image(repo_obj, tag_name)
+ except DataModelException:
+ existing_image = None
- docker_image_id = tag_manifest.tag.image.docker_image_id
- oci_manifest = None
- try:
- oci_manifest = Manifest.get(repository=repo_obj, digest=manifest_digest)
- except Manifest.DoesNotExist:
- pass
+ docker_image_id = tag_manifest.tag.image.docker_image_id
+ oci_manifest = None
+ try:
+ oci_manifest = Manifest.get(repository=repo_obj, digest=manifest_digest)
+ except Manifest.DoesNotExist:
+ pass
- # Change the tag and tag manifest to point to the updated image.
- updated_tag = create_or_update_tag_for_repo(repo_obj, tag_name, docker_image_id,
- reversion=True, oci_manifest=oci_manifest)
- tag_manifest.tag = updated_tag
- tag_manifest.save()
- return existing_image
+ # Change the tag and tag manifest to point to the updated image.
+ updated_tag = create_or_update_tag_for_repo(
+ repo_obj,
+ tag_name,
+ docker_image_id,
+ reversion=True,
+ oci_manifest=oci_manifest,
+ )
+ tag_manifest.tag = updated_tag
+ tag_manifest.save()
+ return existing_image
def restore_tag_to_image(repo_obj, tag_name, docker_image_id):
- """ Restores a tag to a specific image ID. """
- with db_transaction():
- # Verify that the image ID already existed under this repository under the
- # tag.
- try:
- (RepositoryTag
- .select()
- .join(Image)
- .where(RepositoryTag.repository == repo_obj)
- .where(RepositoryTag.name == tag_name)
- .where(Image.docker_image_id == docker_image_id)
- .get())
- except RepositoryTag.DoesNotExist:
- raise DataModelException('Cannot restore to unknown or invalid image')
+ """ Restores a tag to a specific image ID. """
+ with db_transaction():
+ # Verify that the image ID already existed under this tag in the
+ # repository.
+ try:
+ (
+ RepositoryTag.select()
+ .join(Image)
+ .where(RepositoryTag.repository == repo_obj)
+ .where(RepositoryTag.name == tag_name)
+ .where(Image.docker_image_id == docker_image_id)
+ .get()
+ )
+ except RepositoryTag.DoesNotExist:
+ raise DataModelException("Cannot restore to unknown or invalid image")
- # Lookup the existing image, if any.
- try:
- existing_image = get_repo_tag_image(repo_obj, tag_name)
- except DataModelException:
- existing_image = None
+ # Lookup the existing image, if any.
+ try:
+ existing_image = get_repo_tag_image(repo_obj, tag_name)
+ except DataModelException:
+ existing_image = None
- create_or_update_tag_for_repo(repo_obj, tag_name, docker_image_id, reversion=True)
- return existing_image
+ create_or_update_tag_for_repo(
+ repo_obj, tag_name, docker_image_id, reversion=True
+ )
+ return existing_image
-def store_tag_manifest_for_testing(namespace_name, repository_name, tag_name, manifest,
- leaf_layer_id, storage_id_map):
- """ Stores a tag manifest for a specific tag name in the database. Returns the TagManifest
+def store_tag_manifest_for_testing(
+ namespace_name, repository_name, tag_name, manifest, leaf_layer_id, storage_id_map
+):
+ """ Stores a tag manifest for a specific tag name in the database. Returns the TagManifest
object, as well as a boolean indicating whether the TagManifest was created.
"""
- try:
- repo = _basequery.get_existing_repository(namespace_name, repository_name)
- except Repository.DoesNotExist:
- raise DataModelException('Invalid repository %s/%s' % (namespace_name, repository_name))
+ try:
+ repo = _basequery.get_existing_repository(namespace_name, repository_name)
+ except Repository.DoesNotExist:
+ raise DataModelException(
+ "Invalid repository %s/%s" % (namespace_name, repository_name)
+ )
- return store_tag_manifest_for_repo(repo.id, tag_name, manifest, leaf_layer_id, storage_id_map)
+ return store_tag_manifest_for_repo(
+ repo.id, tag_name, manifest, leaf_layer_id, storage_id_map
+ )
-def store_tag_manifest_for_repo(repository_id, tag_name, manifest, leaf_layer_id, storage_id_map,
- reversion=False):
- """ Stores a tag manifest for a specific tag name in the database. Returns the TagManifest
+def store_tag_manifest_for_repo(
+ repository_id, tag_name, manifest, leaf_layer_id, storage_id_map, reversion=False
+):
+ """ Stores a tag manifest for a specific tag name in the database. Returns the TagManifest
object, as well as a boolean indicating whether the TagManifest was created.
"""
- # Create the new-style OCI manifest and its blobs.
- oci_manifest = _populate_manifest_and_blobs(repository_id, manifest, storage_id_map,
- leaf_layer_id=leaf_layer_id)
+ # Create the new-style OCI manifest and its blobs.
+ oci_manifest = _populate_manifest_and_blobs(
+ repository_id, manifest, storage_id_map, leaf_layer_id=leaf_layer_id
+ )
- # Create the tag for the tag manifest.
- tag = create_or_update_tag_for_repo(repository_id, tag_name, leaf_layer_id,
- reversion=reversion, oci_manifest=oci_manifest)
+ # Create the tag for the tag manifest.
+ tag = create_or_update_tag_for_repo(
+ repository_id,
+ tag_name,
+ leaf_layer_id,
+ reversion=reversion,
+ oci_manifest=oci_manifest,
+ )
- # Add a tag manifest pointing to that tag.
- try:
- manifest = TagManifest.get(digest=manifest.digest)
- manifest.tag = tag
- manifest.save()
- return manifest, False
- except TagManifest.DoesNotExist:
- created = _associate_manifest(tag, oci_manifest)
- return created, True
+ # Add a tag manifest pointing to that tag.
+ try:
+ manifest = TagManifest.get(digest=manifest.digest)
+ manifest.tag = tag
+ manifest.save()
+ return manifest, False
+ except TagManifest.DoesNotExist:
+ created = _associate_manifest(tag, oci_manifest)
+ return created, True
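+ # Note: this path writes both data models: the new-style Manifest rows come from
+ # _populate_manifest_and_blobs and the legacy TagManifest row from _associate_manifest,
+ # presumably to keep the pre-OCI and OCI tables in sync during the migration.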
def get_active_tag(namespace, repo_name, tag_name):
- return _tag_alive(RepositoryTag
- .select()
- .join(Repository)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .where(RepositoryTag.name == tag_name, Repository.name == repo_name,
- Namespace.username == namespace)).get()
+ return _tag_alive(
+ RepositoryTag.select()
+ .join(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .where(
+ RepositoryTag.name == tag_name,
+ Repository.name == repo_name,
+ Namespace.username == namespace,
+ )
+ ).get()
+
def get_active_tag_for_repo(repo, tag_name):
- try:
- return _tag_alive(RepositoryTag
- .select(RepositoryTag, Image, ImageStorage)
- .join(Image)
- .join(ImageStorage)
- .where(RepositoryTag.name == tag_name,
- RepositoryTag.repository == repo,
- RepositoryTag.hidden == False)).get()
- except RepositoryTag.DoesNotExist:
- return None
+ try:
+ return _tag_alive(
+ RepositoryTag.select(RepositoryTag, Image, ImageStorage)
+ .join(Image)
+ .join(ImageStorage)
+ .where(
+ RepositoryTag.name == tag_name,
+ RepositoryTag.repository == repo,
+ RepositoryTag.hidden == False,
+ )
+ ).get()
+ except RepositoryTag.DoesNotExist:
+ return None
+
def get_expired_tag_in_repo(repo, tag_name):
- return (RepositoryTag
- .select()
- .where(RepositoryTag.name == tag_name, RepositoryTag.repository == repo)
- .where(~(RepositoryTag.lifetime_end_ts >> None))
- .where(RepositoryTag.lifetime_end_ts <= get_epoch_timestamp())
- .get())
+ return (
+ RepositoryTag.select()
+ .where(RepositoryTag.name == tag_name, RepositoryTag.repository == repo)
+ .where(~(RepositoryTag.lifetime_end_ts >> None))
+ .where(RepositoryTag.lifetime_end_ts <= get_epoch_timestamp())
+ .get()
+ )
def get_possibly_expired_tag(namespace, repo_name, tag_name):
- return (RepositoryTag
- .select()
- .join(Repository)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .where(RepositoryTag.name == tag_name, Repository.name == repo_name,
- Namespace.username == namespace)).get()
+ return (
+ RepositoryTag.select()
+ .join(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .where(
+ RepositoryTag.name == tag_name,
+ Repository.name == repo_name,
+ Namespace.username == namespace,
+ )
+ ).get()
+
def associate_generated_tag_manifest_with_tag(tag, manifest, storage_id_map):
- oci_manifest = _populate_manifest_and_blobs(tag.repository, manifest, storage_id_map)
+ oci_manifest = _populate_manifest_and_blobs(
+ tag.repository, manifest, storage_id_map
+ )
- with db_transaction():
- try:
- (Tag
- .select()
- .join(TagToRepositoryTag)
- .where(TagToRepositoryTag.repository_tag == tag)).get()
- except Tag.DoesNotExist:
- oci_tag = Tag.create(repository=tag.repository, manifest=oci_manifest, name=tag.name,
- reversion=tag.reversion,
- lifetime_start_ms=tag.lifetime_start_ts * 1000,
- lifetime_end_ms=(tag.lifetime_end_ts * 1000
- if tag.lifetime_end_ts else None),
- tag_kind=Tag.tag_kind.get_id('tag'))
- TagToRepositoryTag.create(tag=oci_tag, repository_tag=tag, repository=tag.repository)
+ with db_transaction():
+ try:
+ (
+ Tag.select()
+ .join(TagToRepositoryTag)
+ .where(TagToRepositoryTag.repository_tag == tag)
+ ).get()
+ except Tag.DoesNotExist:
+ oci_tag = Tag.create(
+ repository=tag.repository,
+ manifest=oci_manifest,
+ name=tag.name,
+ reversion=tag.reversion,
+ lifetime_start_ms=tag.lifetime_start_ts * 1000,
+ lifetime_end_ms=(
+ tag.lifetime_end_ts * 1000 if tag.lifetime_end_ts else None
+ ),
+ tag_kind=Tag.tag_kind.get_id("tag"),
+ )
+ TagToRepositoryTag.create(
+ tag=oci_tag, repository_tag=tag, repository=tag.repository
+ )
- return _associate_manifest(tag, oci_manifest)
+ return _associate_manifest(tag, oci_manifest)
def _associate_manifest(tag, oci_manifest):
- with db_transaction():
- tag_manifest = TagManifest.create(tag=tag, digest=oci_manifest.digest,
- json_data=oci_manifest.manifest_bytes)
- TagManifestToManifest.create(tag_manifest=tag_manifest, manifest=oci_manifest)
- return tag_manifest
+ with db_transaction():
+ tag_manifest = TagManifest.create(
+ tag=tag, digest=oci_manifest.digest, json_data=oci_manifest.manifest_bytes
+ )
+ TagManifestToManifest.create(tag_manifest=tag_manifest, manifest=oci_manifest)
+ return tag_manifest
-def _populate_manifest_and_blobs(repository, manifest, storage_id_map, leaf_layer_id=None):
- leaf_layer_id = leaf_layer_id or manifest.leaf_layer_v1_image_id
- try:
- legacy_image = Image.get(Image.docker_image_id == leaf_layer_id,
- Image.repository == repository)
- except Image.DoesNotExist:
- raise DataModelException('Invalid image with id: %s' % leaf_layer_id)
+def _populate_manifest_and_blobs(
+ repository, manifest, storage_id_map, leaf_layer_id=None
+):
+ leaf_layer_id = leaf_layer_id or manifest.leaf_layer_v1_image_id
+ try:
+ legacy_image = Image.get(
+ Image.docker_image_id == leaf_layer_id, Image.repository == repository
+ )
+ except Image.DoesNotExist:
+ raise DataModelException("Invalid image with id: %s" % leaf_layer_id)
- storage_ids = set()
- for blob_digest in manifest.local_blob_digests:
- image_storage_id = storage_id_map.get(blob_digest)
- if image_storage_id is None:
- logger.error('Missing blob for manifest `%s` in: %s', blob_digest, storage_id_map)
- raise DataModelException('Missing blob for manifest `%s`' % blob_digest)
+ storage_ids = set()
+ for blob_digest in manifest.local_blob_digests:
+ image_storage_id = storage_id_map.get(blob_digest)
+ if image_storage_id is None:
+ logger.error(
+ "Missing blob for manifest `%s` in: %s", blob_digest, storage_id_map
+ )
+ raise DataModelException("Missing blob for manifest `%s`" % blob_digest)
- if image_storage_id in storage_ids:
- continue
+ if image_storage_id in storage_ids:
+ continue
- storage_ids.add(image_storage_id)
+ storage_ids.add(image_storage_id)
- return populate_manifest(repository, manifest, legacy_image, storage_ids)
+ return populate_manifest(repository, manifest, legacy_image, storage_ids)
def populate_manifest(repository, manifest, legacy_image, storage_ids):
- """ Populates the rows for the manifest, including its blobs and legacy image. """
- media_type = Manifest.media_type.get_id(manifest.media_type)
+ """ Populates the rows for the manifest, including its blobs and legacy image. """
+ media_type = Manifest.media_type.get_id(manifest.media_type)
- # Check for an existing manifest. If present, return it.
- try:
- return Manifest.get(repository=repository, digest=manifest.digest)
- except Manifest.DoesNotExist:
- pass
-
- with db_transaction():
+ # Check for an existing manifest. If present, return it.
try:
- manifest_row = Manifest.create(digest=manifest.digest, repository=repository,
- manifest_bytes=manifest.bytes.as_encoded_str(),
- media_type=media_type)
- except IntegrityError as ie:
- logger.debug('Got integrity error when trying to write manifest: %s', ie)
- return Manifest.get(repository=repository, digest=manifest.digest)
+ return Manifest.get(repository=repository, digest=manifest.digest)
+ except Manifest.DoesNotExist:
+ pass
- ManifestLegacyImage.create(manifest=manifest_row, repository=repository, image=legacy_image)
+ with db_transaction():
+ try:
+ manifest_row = Manifest.create(
+ digest=manifest.digest,
+ repository=repository,
+ manifest_bytes=manifest.bytes.as_encoded_str(),
+ media_type=media_type,
+ )
+ except IntegrityError as ie:
+ logger.debug("Got integrity error when trying to write manifest: %s", ie)
+ return Manifest.get(repository=repository, digest=manifest.digest)
- blobs_to_insert = [dict(manifest=manifest_row, repository=repository,
- blob=storage_id) for storage_id in storage_ids]
- if blobs_to_insert:
- ManifestBlob.insert_many(blobs_to_insert).execute()
+ ManifestLegacyImage.create(
+ manifest=manifest_row, repository=repository, image=legacy_image
+ )
- return manifest_row
+ blobs_to_insert = [
+ dict(manifest=manifest_row, repository=repository, blob=storage_id)
+ for storage_id in storage_ids
+ ]
+ if blobs_to_insert:
+ ManifestBlob.insert_many(blobs_to_insert).execute()
+
+ return manifest_row
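+ # Note: populate_manifest is effectively idempotent. The early Manifest.get handles the
+ # common duplicate case, and the IntegrityError branch inside the transaction covers the
+ # race where another writer inserts the same digest between the check and the create.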
def get_tag_manifest(tag):
- try:
- return TagManifest.get(tag=tag)
- except TagManifest.DoesNotExist:
- return None
+ try:
+ return TagManifest.get(tag=tag)
+ except TagManifest.DoesNotExist:
+ return None
def load_tag_manifest(namespace, repo_name, tag_name):
- try:
- return (_load_repo_manifests(namespace, repo_name)
+ try:
+ return (
+ _load_repo_manifests(namespace, repo_name)
.where(RepositoryTag.name == tag_name)
- .get())
- except TagManifest.DoesNotExist:
- msg = 'Manifest not found for tag {0} in repo {1}/{2}'.format(tag_name, namespace, repo_name)
- raise InvalidManifestException(msg)
+ .get()
+ )
+ except TagManifest.DoesNotExist:
+ msg = "Manifest not found for tag {0} in repo {1}/{2}".format(
+ tag_name, namespace, repo_name
+ )
+ raise InvalidManifestException(msg)
def delete_manifest_by_digest(namespace, repo_name, digest):
- tag_manifests = list(_load_repo_manifests(namespace, repo_name)
- .where(TagManifest.digest == digest))
+ tag_manifests = list(
+ _load_repo_manifests(namespace, repo_name).where(TagManifest.digest == digest)
+ )
- now_ms = get_epoch_timestamp_ms()
- for tag_manifest in tag_manifests:
- try:
- tag = _tag_alive(RepositoryTag.select().where(RepositoryTag.id == tag_manifest.tag_id)).get()
- delete_tag(namespace, repo_name, tag_manifest.tag.name, now_ms)
- except RepositoryTag.DoesNotExist:
- pass
+ now_ms = get_epoch_timestamp_ms()
+ for tag_manifest in tag_manifests:
+ try:
+ tag = _tag_alive(
+ RepositoryTag.select().where(RepositoryTag.id == tag_manifest.tag_id)
+ ).get()
+ delete_tag(namespace, repo_name, tag_manifest.tag.name, now_ms)
+ except RepositoryTag.DoesNotExist:
+ pass
- return [tag_manifest.tag for tag_manifest in tag_manifests]
+ return [tag_manifest.tag for tag_manifest in tag_manifests]
def load_manifest_by_digest(namespace, repo_name, digest, allow_dead=False):
- try:
- return (_load_repo_manifests(namespace, repo_name, allow_dead=allow_dead)
+ try:
+ return (
+ _load_repo_manifests(namespace, repo_name, allow_dead=allow_dead)
.where(TagManifest.digest == digest)
- .get())
- except TagManifest.DoesNotExist:
- msg = 'Manifest not found with digest {0} in repo {1}/{2}'.format(digest, namespace, repo_name)
- raise InvalidManifestException(msg)
+ .get()
+ )
+ except TagManifest.DoesNotExist:
+ msg = "Manifest not found with digest {0} in repo {1}/{2}".format(
+ digest, namespace, repo_name
+ )
+ raise InvalidManifestException(msg)
def _load_repo_manifests(namespace, repo_name, allow_dead=False):
- query = (TagManifest
- .select(TagManifest, RepositoryTag)
- .join(RepositoryTag)
- .join(Image)
- .join(Repository)
- .join(Namespace, on=(Namespace.id == Repository.namespace_user))
- .where(Repository.name == repo_name, Namespace.username == namespace))
+ query = (
+ TagManifest.select(TagManifest, RepositoryTag)
+ .join(RepositoryTag)
+ .join(Image)
+ .join(Repository)
+ .join(Namespace, on=(Namespace.id == Repository.namespace_user))
+ .where(Repository.name == repo_name, Namespace.username == namespace)
+ )
- if not allow_dead:
- query = _tag_alive(query)
+ if not allow_dead:
+ query = _tag_alive(query)
- return query
+ return query
-def change_repository_tag_expiration(namespace_name, repo_name, tag_name, expiration_date):
- """ Changes the expiration of the tag with the given name to the given expiration datetime. If
+
+def change_repository_tag_expiration(
+ namespace_name, repo_name, tag_name, expiration_date
+):
+ """ Changes the expiration of the tag with the given name to the given expiration datetime. If
the expiration datetime is None, then the tag is marked as not expiring.
"""
- try:
- tag = get_active_tag(namespace_name, repo_name, tag_name)
- return change_tag_expiration(tag, expiration_date)
- except RepositoryTag.DoesNotExist:
- return (None, False)
+ try:
+ tag = get_active_tag(namespace_name, repo_name, tag_name)
+ return change_tag_expiration(tag, expiration_date)
+ except RepositoryTag.DoesNotExist:
+ return (None, False)
def set_tag_expiration_for_manifest(tag_manifest, expiration_sec):
- """
+ """
Changes the expiration of the tag that points to the given manifest to be its lifetime start +
the expiration seconds.
"""
- expiration_time_ts = tag_manifest.tag.lifetime_start_ts + expiration_sec
- expiration_date = datetime.utcfromtimestamp(expiration_time_ts)
- return change_tag_expiration(tag_manifest.tag, expiration_date)
+ expiration_time_ts = tag_manifest.tag.lifetime_start_ts + expiration_sec
+ expiration_date = datetime.utcfromtimestamp(expiration_time_ts)
+ return change_tag_expiration(tag_manifest.tag, expiration_date)
def change_tag_expiration(tag, expiration_date):
- """ Changes the expiration of the given tag to the given expiration datetime. If
+ """ Changes the expiration of the given tag to the given expiration datetime. If
the expiration datetime is None, then the tag is marked as not expiring.
"""
- end_ts = None
- min_expire_sec = convert_to_timedelta(config.app_config.get('LABELED_EXPIRATION_MINIMUM', '1h'))
- max_expire_sec = convert_to_timedelta(config.app_config.get('LABELED_EXPIRATION_MAXIMUM', '104w'))
+ end_ts = None
+ min_expire_sec = convert_to_timedelta(
+ config.app_config.get("LABELED_EXPIRATION_MINIMUM", "1h")
+ )
+ max_expire_sec = convert_to_timedelta(
+ config.app_config.get("LABELED_EXPIRATION_MAXIMUM", "104w")
+ )
- if expiration_date is not None:
- offset = timegm(expiration_date.utctimetuple()) - tag.lifetime_start_ts
- offset = min(max(offset, min_expire_sec.total_seconds()), max_expire_sec.total_seconds())
- end_ts = tag.lifetime_start_ts + offset
+ if expiration_date is not None:
+ offset = timegm(expiration_date.utctimetuple()) - tag.lifetime_start_ts
+ offset = min(
+ max(offset, min_expire_sec.total_seconds()), max_expire_sec.total_seconds()
+ )
+ end_ts = tag.lifetime_start_ts + offset
- if end_ts == tag.lifetime_end_ts:
- return (None, True)
+ if end_ts == tag.lifetime_end_ts:
+ return (None, True)
- return set_tag_end_ts(tag, end_ts)
+ return set_tag_end_ts(tag, end_ts)
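+ # Worked example (hypothetical numbers): with the default bounds of "1h" and "104w", a
+ # requested expiration 10 seconds after lifetime_start_ts is clamped up to
+ # lifetime_start_ts + 3600, and anything past 104 weeks is clamped down to
+ # lifetime_start_ts + 104 weeks before set_tag_end_ts writes it.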
def set_tag_end_ts(tag, end_ts):
- """ Sets the end timestamp for a tag. Should only be called by change_tag_expiration
+ """ Sets the end timestamp for a tag. Should only be called by change_tag_expiration
or tests.
"""
- end_ms = end_ts * 1000 if end_ts is not None else None
+ end_ms = end_ts * 1000 if end_ts is not None else None
- with db_transaction():
- # Note: We check not just the ID of the tag but also its lifetime_end_ts, to ensure that it has
- # not changed while we were updating it expiration.
- result = (RepositoryTag
- .update(lifetime_end_ts=end_ts)
- .where(RepositoryTag.id == tag.id,
- RepositoryTag.lifetime_end_ts == tag.lifetime_end_ts)
- .execute())
+ with db_transaction():
+ # Note: We check not just the ID of the tag but also its lifetime_end_ts, to ensure that it has
+ # not changed while we were updating its expiration.
+ result = (
+ RepositoryTag.update(lifetime_end_ts=end_ts)
+ .where(
+ RepositoryTag.id == tag.id,
+ RepositoryTag.lifetime_end_ts == tag.lifetime_end_ts,
+ )
+ .execute()
+ )
- # Check for a mapping to an OCI tag.
- try:
- oci_tag = (Tag
- .select()
- .join(TagToRepositoryTag)
- .where(TagToRepositoryTag.repository_tag == tag)
- .get())
+ # Check for a mapping to an OCI tag.
+ try:
+ oci_tag = (
+ Tag.select()
+ .join(TagToRepositoryTag)
+ .where(TagToRepositoryTag.repository_tag == tag)
+ .get()
+ )
- (Tag
- .update(lifetime_end_ms=end_ms)
- .where(Tag.id == oci_tag.id,
- Tag.lifetime_end_ms == oci_tag.lifetime_end_ms)
- .execute())
- except Tag.DoesNotExist:
- pass
+ (
+ Tag.update(lifetime_end_ms=end_ms)
+ .where(
+ Tag.id == oci_tag.id, Tag.lifetime_end_ms == oci_tag.lifetime_end_ms
+ )
+ .execute()
+ )
+ except Tag.DoesNotExist:
+ pass
- return (tag.lifetime_end_ts, result > 0)
+ return (tag.lifetime_end_ts, result > 0)
def find_matching_tag(repo_id, tag_names):
- """ Finds the most recently pushed alive tag in the repository with one of the given names,
+ """ Finds the most recently pushed alive tag in the repository with one of the given names,
if any.
"""
- try:
- return (_tag_alive(RepositoryTag
- .select()
- .where(RepositoryTag.repository == repo_id,
- RepositoryTag.name << list(tag_names))
- .order_by(RepositoryTag.lifetime_start_ts.desc()))
- .get())
- except RepositoryTag.DoesNotExist:
- return None
+ try:
+ return _tag_alive(
+ RepositoryTag.select()
+ .where(
+ RepositoryTag.repository == repo_id,
+ RepositoryTag.name << list(tag_names),
+ )
+ .order_by(RepositoryTag.lifetime_start_ts.desc())
+ ).get()
+ except RepositoryTag.DoesNotExist:
+ return None
def get_most_recent_tag(repo_id):
- """ Returns the most recently pushed alive tag in the repository, or None if none. """
- try:
- return (_tag_alive(RepositoryTag
- .select()
- .where(RepositoryTag.repository == repo_id, RepositoryTag.hidden == False)
- .order_by(RepositoryTag.lifetime_start_ts.desc()))
- .get())
- except RepositoryTag.DoesNotExist:
- return None
+ """ Returns the most recently pushed alive tag in the repository, or None if none. """
+ try:
+ return _tag_alive(
+ RepositoryTag.select()
+ .where(RepositoryTag.repository == repo_id, RepositoryTag.hidden == False)
+ .order_by(RepositoryTag.lifetime_start_ts.desc())
+ ).get()
+ except RepositoryTag.DoesNotExist:
+ return None
diff --git a/data/model/team.py b/data/model/team.py
index 4988d74ac..3dca6fd15 100644
--- a/data/model/team.py
+++ b/data/model/team.py
@@ -5,10 +5,26 @@ import uuid
from datetime import datetime
from peewee import fn
-from data.database import (Team, TeamMember, TeamRole, User, TeamMemberInvite, RepositoryPermission,
- TeamSync, LoginService, FederatedLogin, db_random_func, db_transaction)
-from data.model import (DataModelException, InvalidTeamException, UserAlreadyInTeam,
- InvalidTeamMemberException, _basequery)
+from data.database import (
+ Team,
+ TeamMember,
+ TeamRole,
+ User,
+ TeamMemberInvite,
+ RepositoryPermission,
+ TeamSync,
+ LoginService,
+ FederatedLogin,
+ db_random_func,
+ db_transaction,
+)
+from data.model import (
+ DataModelException,
+ InvalidTeamException,
+ UserAlreadyInTeam,
+ InvalidTeamMemberException,
+ _basequery,
+)
from data.text import prefix_search
from util.validation import validate_username
from util.morecollections import AttrDict
@@ -17,503 +33,567 @@ from util.morecollections import AttrDict
MIN_TEAMNAME_LENGTH = 2
MAX_TEAMNAME_LENGTH = 255
-VALID_TEAMNAME_REGEX = r'^([a-z0-9]+(?:[._-][a-z0-9]+)*)$'
+VALID_TEAMNAME_REGEX = r"^([a-z0-9]+(?:[._-][a-z0-9]+)*)$"
def validate_team_name(teamname):
- if not re.match(VALID_TEAMNAME_REGEX, teamname):
- return (False, 'Namespace must match expression ' + VALID_TEAMNAME_REGEX)
+ if not re.match(VALID_TEAMNAME_REGEX, teamname):
+ return (False, "Namespace must match expression " + VALID_TEAMNAME_REGEX)
- length_match = (len(teamname) >= MIN_TEAMNAME_LENGTH and len(teamname) <= MAX_TEAMNAME_LENGTH)
- if not length_match:
- return (False, 'Team must be between %s and %s characters in length' %
- (MIN_TEAMNAME_LENGTH, MAX_TEAMNAME_LENGTH))
+ length_match = (
+ len(teamname) >= MIN_TEAMNAME_LENGTH and len(teamname) <= MAX_TEAMNAME_LENGTH
+ )
+ if not length_match:
+ return (
+ False,
+ "Team must be between %s and %s characters in length"
+ % (MIN_TEAMNAME_LENGTH, MAX_TEAMNAME_LENGTH),
+ )
- return (True, '')
+ return (True, "")
-def create_team(name, org_obj, team_role_name, description=''):
- (teamname_valid, teamname_issue) = validate_team_name(name)
- if not teamname_valid:
- raise InvalidTeamException('Invalid team name %s: %s' % (name, teamname_issue))
+def create_team(name, org_obj, team_role_name, description=""):
+ (teamname_valid, teamname_issue) = validate_team_name(name)
+ if not teamname_valid:
+ raise InvalidTeamException("Invalid team name %s: %s" % (name, teamname_issue))
- if not org_obj.organization:
- raise InvalidTeamException('Specified organization %s was not an organization' %
- org_obj.username)
+ if not org_obj.organization:
+ raise InvalidTeamException(
+ "Specified organization %s was not an organization" % org_obj.username
+ )
- team_role = TeamRole.get(TeamRole.name == team_role_name)
- return Team.create(name=name, organization=org_obj, role=team_role,
- description=description)
+ team_role = TeamRole.get(TeamRole.name == team_role_name)
+ return Team.create(
+ name=name, organization=org_obj, role=team_role, description=description
+ )
def add_user_to_team(user_obj, team):
- try:
- return TeamMember.create(user=user_obj, team=team)
- except Exception:
- raise UserAlreadyInTeam('User %s is already a member of team %s' %
- (user_obj.username, team.name))
+ try:
+ return TeamMember.create(user=user_obj, team=team)
+ except Exception:
+ raise UserAlreadyInTeam(
+ "User %s is already a member of team %s" % (user_obj.username, team.name)
+ )
def remove_user_from_team(org_name, team_name, username, removed_by_username):
- Org = User.alias()
- joined = TeamMember.select().join(User).switch(TeamMember).join(Team)
- with_role = joined.join(TeamRole)
- with_org = with_role.switch(Team).join(Org,
- on=(Org.id == Team.organization))
- found = list(with_org.where(User.username == username,
- Org.username == org_name,
- Team.name == team_name))
+ Org = User.alias()
+ joined = TeamMember.select().join(User).switch(TeamMember).join(Team)
+ with_role = joined.join(TeamRole)
+ with_org = with_role.switch(Team).join(Org, on=(Org.id == Team.organization))
+ found = list(
+ with_org.where(
+ User.username == username, Org.username == org_name, Team.name == team_name
+ )
+ )
- if not found:
- raise DataModelException('User %s does not belong to team %s' %
- (username, team_name))
+ if not found:
+ raise DataModelException(
+ "User %s does not belong to team %s" % (username, team_name)
+ )
- if username == removed_by_username:
- admin_team_query = __get_user_admin_teams(org_name, username)
- admin_team_names = {team.name for team in admin_team_query}
- if team_name in admin_team_names and len(admin_team_names) <= 1:
- msg = 'User cannot remove themselves from their only admin team.'
- raise DataModelException(msg)
+ if username == removed_by_username:
+ admin_team_query = __get_user_admin_teams(org_name, username)
+ admin_team_names = {team.name for team in admin_team_query}
+ if team_name in admin_team_names and len(admin_team_names) <= 1:
+ msg = "User cannot remove themselves from their only admin team."
+ raise DataModelException(msg)
- user_in_team = found[0]
- user_in_team.delete_instance()
+ user_in_team = found[0]
+ user_in_team.delete_instance()
def set_team_org_permission(team, team_role_name, set_by_username):
- if team.role.name == 'admin' and team_role_name != 'admin':
- # We need to make sure we're not removing the users only admin role
- user_admin_teams = __get_user_admin_teams(team.organization.username, set_by_username)
- admin_team_set = {admin_team.name for admin_team in user_admin_teams}
- if team.name in admin_team_set and len(admin_team_set) <= 1:
- msg = (('Cannot remove admin from team \'%s\' because calling user ' +
- 'would no longer have admin on org \'%s\'') %
- (team.name, team.organization.username))
- raise DataModelException(msg)
+ if team.role.name == "admin" and team_role_name != "admin":
+ # We need to make sure we're not removing the user's only admin role
+ user_admin_teams = __get_user_admin_teams(
+ team.organization.username, set_by_username
+ )
+ admin_team_set = {admin_team.name for admin_team in user_admin_teams}
+ if team.name in admin_team_set and len(admin_team_set) <= 1:
+ msg = (
+ "Cannot remove admin from team '%s' because calling user "
+ + "would no longer have admin on org '%s'"
+ ) % (team.name, team.organization.username)
+ raise DataModelException(msg)
- new_role = TeamRole.get(TeamRole.name == team_role_name)
- team.role = new_role
- team.save()
- return team
+ new_role = TeamRole.get(TeamRole.name == team_role_name)
+ team.role = new_role
+ team.save()
+ return team
def __get_user_admin_teams(org_name, username):
- Org = User.alias()
- user_teams = Team.select().join(TeamMember).join(User)
- with_org = user_teams.switch(Team).join(Org,
- on=(Org.id == Team.organization))
- with_role = with_org.switch(Team).join(TeamRole)
- admin_teams = with_role.where(User.username == username,
- Org.username == org_name,
- TeamRole.name == 'admin')
- return admin_teams
+ Org = User.alias()
+ user_teams = Team.select().join(TeamMember).join(User)
+ with_org = user_teams.switch(Team).join(Org, on=(Org.id == Team.organization))
+ with_role = with_org.switch(Team).join(TeamRole)
+ admin_teams = with_role.where(
+ User.username == username, Org.username == org_name, TeamRole.name == "admin"
+ )
+ return admin_teams
def remove_team(org_name, team_name, removed_by_username):
- joined = Team.select(Team, TeamRole).join(User).switch(Team).join(TeamRole)
+ joined = Team.select(Team, TeamRole).join(User).switch(Team).join(TeamRole)
- found = list(joined.where(User.organization == True,
- User.username == org_name,
- Team.name == team_name))
- if not found:
- raise InvalidTeamException('Team \'%s\' is not a team in org \'%s\'' %
- (team_name, org_name))
+ found = list(
+ joined.where(
+ User.organization == True, User.username == org_name, Team.name == team_name
+ )
+ )
+ if not found:
+ raise InvalidTeamException(
+ "Team '%s' is not a team in org '%s'" % (team_name, org_name)
+ )
- team = found[0]
- if team.role.name == 'admin':
- admin_teams = list(__get_user_admin_teams(org_name, removed_by_username))
- if len(admin_teams) <= 1:
- # The team we are trying to remove is the only admin team containing this user.
- msg = "Deleting team '%s' would remove admin ability for user '%s' in organization '%s'"
- raise DataModelException(msg % (team_name, removed_by_username, org_name))
+ team = found[0]
+ if team.role.name == "admin":
+ admin_teams = list(__get_user_admin_teams(org_name, removed_by_username))
+ if len(admin_teams) <= 1:
+ # The team we are trying to remove is the only admin team containing this user.
+ msg = "Deleting team '%s' would remove admin ability for user '%s' in organization '%s'"
+ raise DataModelException(msg % (team_name, removed_by_username, org_name))
- team.delete_instance(recursive=True, delete_nullable=True)
+ team.delete_instance(recursive=True, delete_nullable=True)
-def add_or_invite_to_team(inviter, team, user_obj=None, email=None, requires_invite=True):
- # If the user is a member of the organization, then we simply add the
- # user directly to the team. Otherwise, an invite is created for the user/email.
- # We return None if the user was directly added and the invite object if the user was invited.
- if user_obj and requires_invite:
- orgname = team.organization.username
+def add_or_invite_to_team(
+ inviter, team, user_obj=None, email=None, requires_invite=True
+):
+ # If the user is a member of the organization, then we simply add the
+ # user directly to the team. Otherwise, an invite is created for the user/email.
+ # We return None if the user was directly added and the invite object if the user was invited.
+ if user_obj and requires_invite:
+ orgname = team.organization.username
- # If the user is part of the organization (or a robot), then no invite is required.
- if user_obj.robot:
- requires_invite = False
- if not user_obj.username.startswith(orgname + '+'):
- raise InvalidTeamMemberException('Cannot add the specified robot to this team, ' +
- 'as it is not a member of the organization')
- else:
- query = (TeamMember
- .select()
- .where(TeamMember.user == user_obj)
- .join(Team)
- .join(User)
- .where(User.username == orgname, User.organization == True))
- requires_invite = not any(query)
+ # If the user is part of the organization (or a robot), then no invite is required.
+ if user_obj.robot:
+ requires_invite = False
+ if not user_obj.username.startswith(orgname + "+"):
+ raise InvalidTeamMemberException(
+ "Cannot add the specified robot to this team, "
+ + "as it is not a member of the organization"
+ )
+ else:
+ query = (
+ TeamMember.select()
+ .where(TeamMember.user == user_obj)
+ .join(Team)
+ .join(User)
+ .where(User.username == orgname, User.organization == True)
+ )
+ requires_invite = not any(query)
- # If we have a valid user and no invite is required, simply add the user to the team.
- if user_obj and not requires_invite:
- add_user_to_team(user_obj, team)
- return None
+ # If we have a valid user and no invite is required, simply add the user to the team.
+ if user_obj and not requires_invite:
+ add_user_to_team(user_obj, team)
+ return None
- email_address = email if not user_obj else None
- return TeamMemberInvite.create(user=user_obj, email=email_address, team=team, inviter=inviter)
+ email_address = email if not user_obj else None
+ return TeamMemberInvite.create(
+ user=user_obj, email=email_address, team=team, inviter=inviter
+ )
def get_matching_user_teams(team_prefix, user_obj, limit=10):
- team_prefix_search = prefix_search(Team.name, team_prefix)
- query = (Team
- .select(Team.id.distinct(), Team)
- .join(User)
- .switch(Team)
- .join(TeamMember)
- .where(TeamMember.user == user_obj, team_prefix_search)
- .limit(limit))
+ team_prefix_search = prefix_search(Team.name, team_prefix)
+ query = (
+ Team.select(Team.id.distinct(), Team)
+ .join(User)
+ .switch(Team)
+ .join(TeamMember)
+ .where(TeamMember.user == user_obj, team_prefix_search)
+ .limit(limit)
+ )
- return query
+ return query
def get_organization_team(orgname, teamname):
- joined = Team.select().join(User)
- query = joined.where(Team.name == teamname, User.organization == True,
- User.username == orgname).limit(1)
- result = list(query)
- if not result:
- raise InvalidTeamException('Team does not exist: %s/%s', orgname,
- teamname)
+ joined = Team.select().join(User)
+ query = joined.where(
+ Team.name == teamname, User.organization == True, User.username == orgname
+ ).limit(1)
+ result = list(query)
+ if not result:
+ raise InvalidTeamException("Team does not exist: %s/%s", orgname, teamname)
- return result[0]
+ return result[0]
def get_matching_admined_teams(team_prefix, user_obj, limit=10):
- team_prefix_search = prefix_search(Team.name, team_prefix)
- admined_orgs = (_basequery.get_user_organizations(user_obj.username)
- .switch(Team)
- .join(TeamRole)
- .where(TeamRole.name == 'admin'))
+ team_prefix_search = prefix_search(Team.name, team_prefix)
+ admined_orgs = (
+ _basequery.get_user_organizations(user_obj.username)
+ .switch(Team)
+ .join(TeamRole)
+ .where(TeamRole.name == "admin")
+ )
- query = (Team
- .select(Team.id.distinct(), Team)
- .join(User)
- .switch(Team)
- .join(TeamMember)
- .where(team_prefix_search, Team.organization << (admined_orgs))
- .limit(limit))
+ query = (
+ Team.select(Team.id.distinct(), Team)
+ .join(User)
+ .switch(Team)
+ .join(TeamMember)
+ .where(team_prefix_search, Team.organization << (admined_orgs))
+ .limit(limit)
+ )
- return query
+ return query
def get_matching_teams(team_prefix, organization):
- team_prefix_search = prefix_search(Team.name, team_prefix)
- query = Team.select().where(team_prefix_search, Team.organization == organization)
- return query.limit(10)
+ team_prefix_search = prefix_search(Team.name, team_prefix)
+ query = Team.select().where(team_prefix_search, Team.organization == organization)
+ return query.limit(10)
def get_teams_within_org(organization, has_external_auth=False):
- """ Returns a AttrDict of team info (id, name, description), its role under the org,
+ """ Returns a AttrDict of team info (id, name, description), its role under the org,
the number of repositories on which it has permission, and the number of members.
"""
- query = (Team.select()
- .where(Team.organization == organization)
- .join(TeamRole))
+ query = Team.select().where(Team.organization == organization).join(TeamRole)
- def _team_view(team):
- return {
- 'id': team.id,
- 'name': team.name,
- 'description': team.description,
- 'role_name': Team.role.get_name(team.role_id),
+ def _team_view(team):
+ return {
+ "id": team.id,
+ "name": team.name,
+ "description": team.description,
+ "role_name": Team.role.get_name(team.role_id),
+ "repo_count": 0,
+ "member_count": 0,
+ "is_synced": False,
+ }
- 'repo_count': 0,
- 'member_count': 0,
+ teams = {team.id: _team_view(team) for team in query}
+ if not teams:
+ # Just in case. Should ideally never happen.
+ return []
- 'is_synced': False,
- }
+ # Add repository permissions count.
+ permission_tuples = (
+ RepositoryPermission.select(
+ RepositoryPermission.team, fn.Count(RepositoryPermission.id)
+ )
+ .where(RepositoryPermission.team << teams.keys())
+ .group_by(RepositoryPermission.team)
+ .tuples()
+ )
- teams = {team.id: _team_view(team) for team in query}
- if not teams:
- # Just in case. Should ideally never happen.
- return []
+ for perm_tuple in permission_tuples:
+ teams[perm_tuple[0]]["repo_count"] = perm_tuple[1]
- # Add repository permissions count.
- permission_tuples = (RepositoryPermission.select(RepositoryPermission.team,
- fn.Count(RepositoryPermission.id))
- .where(RepositoryPermission.team << teams.keys())
- .group_by(RepositoryPermission.team)
- .tuples())
+ # Add the member count.
+ members_tuples = (
+ TeamMember.select(TeamMember.team, fn.Count(TeamMember.id))
+ .where(TeamMember.team << teams.keys())
+ .group_by(TeamMember.team)
+ .tuples()
+ )
- for perm_tuple in permission_tuples:
- teams[perm_tuple[0]]['repo_count'] = perm_tuple[1]
+ for member_tuple in members_tuples:
+ teams[member_tuple[0]]["member_count"] = member_tuple[1]
- # Add the member count.
- members_tuples = (TeamMember.select(TeamMember.team,
- fn.Count(TeamMember.id))
- .where(TeamMember.team << teams.keys())
- .group_by(TeamMember.team)
- .tuples())
+ # Add syncing information.
+ if has_external_auth:
+ sync_query = TeamSync.select(TeamSync.team).where(TeamSync.team << teams.keys())
+ for team_sync in sync_query:
+ teams[team_sync.team_id]["is_synced"] = True
- for member_tuple in members_tuples:
- teams[member_tuple[0]]['member_count'] = member_tuple[1]
-
- # Add syncing information.
- if has_external_auth:
- sync_query = TeamSync.select(TeamSync.team).where(TeamSync.team << teams.keys())
- for team_sync in sync_query:
- teams[team_sync.team_id]['is_synced'] = True
-
- return [AttrDict(team_info) for team_info in teams.values()]
+ return [AttrDict(team_info) for team_info in teams.values()]
def get_user_teams_within_org(username, organization):
- joined = Team.select().join(TeamMember).join(User)
- return joined.where(Team.organization == organization,
- User.username == username)
+ joined = Team.select().join(TeamMember).join(User)
+ return joined.where(Team.organization == organization, User.username == username)
def list_organization_members_by_teams(organization):
- query = (TeamMember
- .select(Team, User)
- .join(Team)
- .switch(TeamMember)
- .join(User)
- .where(Team.organization == organization))
- return query
+ query = (
+ TeamMember.select(Team, User)
+ .join(Team)
+ .switch(TeamMember)
+ .join(User)
+ .where(Team.organization == organization)
+ )
+ return query
def get_organization_team_member_invites(teamid):
- joined = TeamMemberInvite.select().join(Team).join(User)
- query = joined.where(Team.id == teamid)
- return query
+ joined = TeamMemberInvite.select().join(Team).join(User)
+ query = joined.where(Team.id == teamid)
+ return query
def delete_team_email_invite(team, email):
- try:
- found = TeamMemberInvite.get(TeamMemberInvite.email == email, TeamMemberInvite.team == team)
- except TeamMemberInvite.DoesNotExist:
- return False
+ try:
+ found = TeamMemberInvite.get(
+ TeamMemberInvite.email == email, TeamMemberInvite.team == team
+ )
+ except TeamMemberInvite.DoesNotExist:
+ return False
- found.delete_instance()
- return True
+ found.delete_instance()
+ return True
def delete_team_user_invite(team, user_obj):
- try:
- found = TeamMemberInvite.get(TeamMemberInvite.user == user_obj, TeamMemberInvite.team == team)
- except TeamMemberInvite.DoesNotExist:
- return False
+ try:
+ found = TeamMemberInvite.get(
+ TeamMemberInvite.user == user_obj, TeamMemberInvite.team == team
+ )
+ except TeamMemberInvite.DoesNotExist:
+ return False
- found.delete_instance()
- return True
+ found.delete_instance()
+ return True
def lookup_team_invites_by_email(email):
- return TeamMemberInvite.select().where(TeamMemberInvite.email == email)
+ return TeamMemberInvite.select().where(TeamMemberInvite.email == email)
def lookup_team_invites(user_obj):
- return TeamMemberInvite.select().where(TeamMemberInvite.user == user_obj)
+ return TeamMemberInvite.select().where(TeamMemberInvite.user == user_obj)
def lookup_team_invite(code, user_obj=None):
- # Lookup the invite code.
- try:
- found = TeamMemberInvite.get(TeamMemberInvite.invite_token == code)
- except TeamMemberInvite.DoesNotExist:
- raise DataModelException('Invalid confirmation code.')
+ # Lookup the invite code.
+ try:
+ found = TeamMemberInvite.get(TeamMemberInvite.invite_token == code)
+ except TeamMemberInvite.DoesNotExist:
+ raise DataModelException("Invalid confirmation code.")
- if user_obj and found.user != user_obj:
- raise DataModelException('Invalid confirmation code.')
+ if user_obj and found.user != user_obj:
+ raise DataModelException("Invalid confirmation code.")
- return found
+ return found
def delete_team_invite(code, user_obj=None):
- found = lookup_team_invite(code, user_obj)
+ found = lookup_team_invite(code, user_obj)
- team = found.team
- inviter = found.inviter
+ team = found.team
+ inviter = found.inviter
- found.delete_instance()
+ found.delete_instance()
- return (team, inviter)
+ return (team, inviter)
def find_matching_team_invite(code, user_obj):
- """ Finds a team invite with the given code that applies to the given user and returns it or
+ """ Finds a team invite with the given code that applies to the given user and returns it or
raises a DataModelException if not found. """
- found = lookup_team_invite(code)
+ found = lookup_team_invite(code)
- # If the invite is for a specific user, we have to confirm that here.
- if found.user is not None and found.user != user_obj:
- message = """This invite is intended for user "%s".
- Please login to that account and try again.""" % found.user.username
- raise DataModelException(message)
+ # If the invite is for a specific user, we have to confirm that here.
+ if found.user is not None and found.user != user_obj:
+ message = (
+ """This invite is intended for user "%s".
+ Please login to that account and try again."""
+ % found.user.username
+ )
+ raise DataModelException(message)
- return found
+ return found
def find_organization_invites(organization, user_obj):
- """ Finds all organization team invites for the given user under the given organization. """
- invite_check = (TeamMemberInvite.user == user_obj)
- if user_obj.verified:
- invite_check = invite_check | (TeamMemberInvite.email == user_obj.email)
+ """ Finds all organization team invites for the given user under the given organization. """
+ invite_check = TeamMemberInvite.user == user_obj
+ if user_obj.verified:
+ invite_check = invite_check | (TeamMemberInvite.email == user_obj.email)
- query = (TeamMemberInvite
- .select()
- .join(Team)
- .where(invite_check, Team.organization == organization))
- return query
+ query = (
+ TeamMemberInvite.select()
+ .join(Team)
+ .where(invite_check, Team.organization == organization)
+ )
+ return query
def confirm_team_invite(code, user_obj):
- """ Confirms the given team invite code for the given user by adding the user to the team
+ """ Confirms the given team invite code for the given user by adding the user to the team
and deleting the code. Raises a DataModelException if the code was not found or does
not apply to the given user. If the user is invited to two or more teams under the
same organization, they are automatically confirmed for all of them. """
- found = find_matching_team_invite(code, user_obj)
+ found = find_matching_team_invite(code, user_obj)
- # Find all matching invitations for the user under the organization.
- code_found = False
- for invite in find_organization_invites(found.team.organization, user_obj):
- # Add the user to the team.
- try:
- code_found = True
- add_user_to_team(user_obj, invite.team)
- except UserAlreadyInTeam:
- # Ignore.
- pass
+ # Find all matching invitations for the user under the organization.
+ code_found = False
+ for invite in find_organization_invites(found.team.organization, user_obj):
+ # Add the user to the team.
+ try:
+ code_found = True
+ add_user_to_team(user_obj, invite.team)
+ except UserAlreadyInTeam:
+ # Ignore.
+ pass
- # Delete the invite and return the team.
- invite.delete_instance()
+ # Delete the invite and return the team.
+ invite.delete_instance()
- if not code_found:
- if found.user:
- message = """This invite is intended for user "%s".
- Please login to that account and try again.""" % found.user.username
- raise DataModelException(message)
- else:
- message = """This invite is intended for email "%s".
- Please login to that account and try again.""" % found.email
- raise DataModelException(message)
+ if not code_found:
+ if found.user:
+ message = (
+ """This invite is intended for user "%s".
+ Please login to that account and try again."""
+ % found.user.username
+ )
+ raise DataModelException(message)
+ else:
+ message = (
+ """This invite is intended for email "%s".
+ Please login to that account and try again."""
+ % found.email
+ )
+ raise DataModelException(message)
- team = found.team
- inviter = found.inviter
- return (team, inviter)
+ team = found.team
+ inviter = found.inviter
+ return (team, inviter)
def get_federated_team_member_mapping(team, login_service_name):
- """ Returns a dict of all federated IDs for all team members in the team whose users are
+ """ Returns a dict of all federated IDs for all team members in the team whose users are
 bound to the login service with the given name. The dictionary maps from federated service
identifier (username) to their Quay User table ID.
"""
- login_service = LoginService.get(name=login_service_name)
+ login_service = LoginService.get(name=login_service_name)
- query = (FederatedLogin
- .select(FederatedLogin.service_ident, User.id)
- .join(User)
- .join(TeamMember)
- .join(Team)
- .where(Team.id == team, User.robot == False, FederatedLogin.service == login_service))
- return dict(query.tuples())
+ query = (
+ FederatedLogin.select(FederatedLogin.service_ident, User.id)
+ .join(User)
+ .join(TeamMember)
+ .join(Team)
+ .where(
+ Team.id == team,
+ User.robot == False,
+ FederatedLogin.service == login_service,
+ )
+ )
+ return dict(query.tuples())
def list_team_users(team):
- """ Returns an iterator of all the *users* found in a team. Does not include robots. """
- return (User
- .select()
- .join(TeamMember)
- .join(Team)
- .where(Team.id == team, User.robot == False))
+ """ Returns an iterator of all the *users* found in a team. Does not include robots. """
+ return (
+ User.select()
+ .join(TeamMember)
+ .join(Team)
+ .where(Team.id == team, User.robot == False)
+ )
def list_team_robots(team):
- """ Returns an iterator of all the *robots* found in a team. Does not include users. """
- return (User
- .select()
- .join(TeamMember)
- .join(Team)
- .where(Team.id == team, User.robot == True))
+ """ Returns an iterator of all the *robots* found in a team. Does not include users. """
+ return (
+ User.select()
+ .join(TeamMember)
+ .join(Team)
+ .where(Team.id == team, User.robot == True)
+ )
def set_team_syncing(team, login_service_name, config):
- """ Sets the given team to sync to the given service using the given config. """
- login_service = LoginService.get(name=login_service_name)
- return TeamSync.create(team=team, transaction_id='', service=login_service,
- config=json.dumps(config))
+ """ Sets the given team to sync to the given service using the given config. """
+ login_service = LoginService.get(name=login_service_name)
+ return TeamSync.create(
+ team=team, transaction_id="", service=login_service, config=json.dumps(config)
+ )
def remove_team_syncing(orgname, teamname):
- """ Removes syncing on the team matching the given organization name and team name. """
- existing = get_team_sync_information(orgname, teamname)
- if existing:
- existing.delete_instance()
+ """ Removes syncing on the team matching the given organization name and team name. """
+ existing = get_team_sync_information(orgname, teamname)
+ if existing:
+ existing.delete_instance()
def get_stale_team(stale_timespan):
- """ Returns a team that is setup to sync to an external group, and who has not been synced in
+ """ Returns a team that is setup to sync to an external group, and who has not been synced in
now - stale_timespan. Returns None if none found.
"""
- stale_at = datetime.now() - stale_timespan
+ stale_at = datetime.now() - stale_timespan
- try:
- candidates = (TeamSync
- .select(TeamSync.id)
- .where((TeamSync.last_updated <= stale_at) | (TeamSync.last_updated >> None))
- .limit(500)
- .alias('candidates'))
+ try:
+ candidates = (
+ TeamSync.select(TeamSync.id)
+ .where(
+ (TeamSync.last_updated <= stale_at) | (TeamSync.last_updated >> None)
+ )
+ .limit(500)
+ .alias("candidates")
+ )
- found = (TeamSync
- .select(candidates.c.id)
- .from_(candidates)
- .order_by(db_random_func())
- .get())
+ found = (
+ TeamSync.select(candidates.c.id)
+ .from_(candidates)
+ .order_by(db_random_func())
+ .get()
+ )
- if found is None:
- return
+ if found is None:
+ return
- return TeamSync.select(TeamSync, Team).join(Team).where(TeamSync.id == found.id).get()
- except TeamSync.DoesNotExist:
- return None
+ return (
+ TeamSync.select(TeamSync, Team)
+ .join(Team)
+ .where(TeamSync.id == found.id)
+ .get()
+ )
+ except TeamSync.DoesNotExist:
+ return None
def get_team_sync_information(orgname, teamname):
- """ Returns the team syncing information for the team with the given name under the organization
+ """ Returns the team syncing information for the team with the given name under the organization
 with the given name, or None if none exists.
"""
- query = (TeamSync
- .select(TeamSync, LoginService)
- .join(Team)
- .join(User)
- .switch(TeamSync)
- .join(LoginService)
- .where(Team.name == teamname, User.organization == True, User.username == orgname))
+ query = (
+ TeamSync.select(TeamSync, LoginService)
+ .join(Team)
+ .join(User)
+ .switch(TeamSync)
+ .join(LoginService)
+ .where(
+ Team.name == teamname, User.organization == True, User.username == orgname
+ )
+ )
- try:
- return query.get()
- except TeamSync.DoesNotExist:
- return None
+ try:
+ return query.get()
+ except TeamSync.DoesNotExist:
+ return None
def update_sync_status(team_sync_info):
- """ Attempts to update the transaction ID and last updated time on a TeamSync object. If the
+ """ Attempts to update the transaction ID and last updated time on a TeamSync object. If the
transaction ID on the entry in the DB does not match that found on the object, this method
returns False, which indicates another caller updated it first.
"""
- new_transaction_id = str(uuid.uuid4())
- query = (TeamSync
- .update(transaction_id=new_transaction_id, last_updated=datetime.now())
- .where(TeamSync.id == team_sync_info.id,
- TeamSync.transaction_id == team_sync_info.transaction_id))
- return query.execute() == 1
+ new_transaction_id = str(uuid.uuid4())
+ query = TeamSync.update(
+ transaction_id=new_transaction_id, last_updated=datetime.now()
+ ).where(
+ TeamSync.id == team_sync_info.id,
+ TeamSync.transaction_id == team_sync_info.transaction_id,
+ )
+ return query.execute() == 1
def delete_members_not_present(team, member_id_set):
- """ Deletes all members of the given team that are not found in the member ID set. """
- with db_transaction():
- user_ids = set([u.id for u in list_team_users(team)])
- to_delete = list(user_ids - member_id_set)
- if to_delete:
- query = TeamMember.delete().where(TeamMember.team == team, TeamMember.user << to_delete)
- return query.execute()
+ """ Deletes all members of the given team that are not found in the member ID set. """
+ with db_transaction():
+ user_ids = set([u.id for u in list_team_users(team)])
+ to_delete = list(user_ids - member_id_set)
+ if to_delete:
+ query = TeamMember.delete().where(
+ TeamMember.team == team, TeamMember.user << to_delete
+ )
+ return query.execute()
- return 0
+ return 0
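
The most subtle piece of the team-sync code above is the compare-and-swap in update_sync_status: a sync is only claimed if the transaction ID read earlier is still the one stored in the row. A minimal, self-contained sketch of that pattern, using a plain dict in place of the TeamSync table (fake_team_sync_table and try_claim_sync are illustrative names, not part of the codebase):

import uuid
from datetime import datetime

# Stand-in for the TeamSync table: id -> row (illustrative only, not the peewee model).
fake_team_sync_table = {1: {"transaction_id": "abc", "last_updated": None}}


def try_claim_sync(team_sync_id, seen_transaction_id):
    """Claim the sync only if no other worker has updated the row since we read it."""
    row = fake_team_sync_table.get(team_sync_id)
    if row is None or row["transaction_id"] != seen_transaction_id:
        return False  # another caller updated it first (query.execute() != 1 above)
    row["transaction_id"] = str(uuid.uuid4())  # the UPDATE writes a fresh transaction ID
    row["last_updated"] = datetime.now()
    return True


assert try_claim_sync(1, "abc") is True   # first caller claims the sync
assert try_claim_sync(1, "abc") is False  # a stale transaction ID is rejected
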
diff --git a/data/model/test/test_appspecifictoken.py b/data/model/test/test_appspecifictoken.py
index 96a7491f5..abfea13b7 100644
--- a/data/model/test/test_appspecifictoken.py
+++ b/data/model/test/test_appspecifictoken.py
@@ -12,115 +12,112 @@ from util.timedeltastring import convert_to_timedelta
from test.fixtures import *
-@pytest.mark.parametrize('expiration', [
- (None),
- ('-1m'),
- ('-1d'),
- ('-1w'),
- ('10m'),
- ('10d'),
- ('10w'),
-])
+
+@pytest.mark.parametrize(
+ "expiration", [(None), ("-1m"), ("-1d"), ("-1w"), ("10m"), ("10d"), ("10w")]
+)
def test_gc(expiration, initialized_db):
- user = model.user.get_user('devtable')
+ user = model.user.get_user("devtable")
- expiration_date = None
- is_expired = False
- if expiration:
- if expiration[0] == '-':
- is_expired = True
- expiration_date = datetime.now() - convert_to_timedelta(expiration[1:])
- else:
- expiration_date = datetime.now() + convert_to_timedelta(expiration)
+ expiration_date = None
+ is_expired = False
+ if expiration:
+ if expiration[0] == "-":
+ is_expired = True
+ expiration_date = datetime.now() - convert_to_timedelta(expiration[1:])
+ else:
+ expiration_date = datetime.now() + convert_to_timedelta(expiration)
- # Create a token.
- token = create_token(user, 'Some token', expiration=expiration_date)
+ # Create a token.
+ token = create_token(user, "Some token", expiration=expiration_date)
- # GC tokens.
- gc_expired_tokens(timedelta(seconds=0))
+ # GC tokens.
+ gc_expired_tokens(timedelta(seconds=0))
- # Ensure the token was GCed if expired and not if it wasn't.
- assert (access_valid_token(get_full_token_string(token)) is None) == is_expired
+ # Ensure the token was GCed if expired and not if it wasn't.
+ assert (access_valid_token(get_full_token_string(token)) is None) == is_expired
def test_access_token(initialized_db):
- user = model.user.get_user('devtable')
+ user = model.user.get_user("devtable")
- # Create a token.
- token = create_token(user, 'Some token')
- assert token.last_accessed is None
+ # Create a token.
+ token = create_token(user, "Some token")
+ assert token.last_accessed is None
- # Lookup the token.
- token = access_valid_token(get_full_token_string(token))
- assert token.last_accessed is not None
+ # Lookup the token.
+ token = access_valid_token(get_full_token_string(token))
+ assert token.last_accessed is not None
- # Revoke the token.
- revoke_token(token)
+ # Revoke the token.
+ revoke_token(token)
- # Ensure it cannot be accessed
- assert access_valid_token(get_full_token_string(token)) is None
+ # Ensure it cannot be accessed
+ assert access_valid_token(get_full_token_string(token)) is None
def test_expiring_soon(initialized_db):
- user = model.user.get_user('devtable')
+ user = model.user.get_user("devtable")
- # Create some tokens.
- create_token(user, 'Some token')
- exp_token = create_token(user, 'Some expiring token', datetime.now() + convert_to_timedelta('1d'))
- create_token(user, 'Some other token', expiration=datetime.now() + convert_to_timedelta('2d'))
+ # Create some tokens.
+ create_token(user, "Some token")
+ exp_token = create_token(
+ user, "Some expiring token", datetime.now() + convert_to_timedelta("1d")
+ )
+ create_token(
+ user, "Some other token", expiration=datetime.now() + convert_to_timedelta("2d")
+ )
- # Get the token expiring soon.
- expiring_soon = get_expiring_tokens(user, convert_to_timedelta('25h'))
- assert expiring_soon
- assert len(expiring_soon) == 1
- assert expiring_soon[0].id == exp_token.id
+ # Get the token expiring soon.
+ expiring_soon = get_expiring_tokens(user, convert_to_timedelta("25h"))
+ assert expiring_soon
+ assert len(expiring_soon) == 1
+ assert expiring_soon[0].id == exp_token.id
- expiring_soon = get_expiring_tokens(user, convert_to_timedelta('49h'))
- assert expiring_soon
- assert len(expiring_soon) == 2
+ expiring_soon = get_expiring_tokens(user, convert_to_timedelta("49h"))
+ assert expiring_soon
+ assert len(expiring_soon) == 2
-@pytest.fixture(scope='function')
+@pytest.fixture(scope="function")
def app_config():
- with patch.dict(_config.app_config, {}, clear=True):
- yield _config.app_config
-
-@pytest.mark.parametrize('expiration', [
- (None),
- ('10m'),
- ('10d'),
- ('10w'),
-])
-@pytest.mark.parametrize('default_expiration', [
- (None),
- ('10m'),
- ('10d'),
- ('10w'),
-])
-def test_create_access_token(expiration, default_expiration, initialized_db, app_config):
- user = model.user.get_user('devtable')
- expiration_date = datetime.now() + convert_to_timedelta(expiration) if expiration else None
- with patch.dict(_config.app_config, {}, clear=True):
- app_config['APP_SPECIFIC_TOKEN_EXPIRATION'] = default_expiration
- if expiration:
- exp_token = create_token(user, 'Some token', expiration=expiration_date)
- assert exp_token.expiration == expiration_date
- else:
- exp_token = create_token(user, 'Some token')
- assert (exp_token.expiration is None) == (default_expiration is None)
+ with patch.dict(_config.app_config, {}, clear=True):
+ yield _config.app_config
-@pytest.mark.parametrize('invalid_token', [
- '',
- 'foo',
- 'a' * 40,
- 'b' * 40,
- '%s%s' % ('b' * 40, 'a' * 40),
- '%s%s' % ('a' * 39, 'b' * 40),
- '%s%s' % ('a' * 40, 'b' * 39),
- '%s%s' % ('a' * 40, 'b' * 41),
-])
+@pytest.mark.parametrize("expiration", [(None), ("10m"), ("10d"), ("10w")])
+@pytest.mark.parametrize("default_expiration", [(None), ("10m"), ("10d"), ("10w")])
+def test_create_access_token(
+ expiration, default_expiration, initialized_db, app_config
+):
+ user = model.user.get_user("devtable")
+ expiration_date = (
+ datetime.now() + convert_to_timedelta(expiration) if expiration else None
+ )
+ with patch.dict(_config.app_config, {}, clear=True):
+ app_config["APP_SPECIFIC_TOKEN_EXPIRATION"] = default_expiration
+ if expiration:
+ exp_token = create_token(user, "Some token", expiration=expiration_date)
+ assert exp_token.expiration == expiration_date
+ else:
+ exp_token = create_token(user, "Some token")
+ assert (exp_token.expiration is None) == (default_expiration is None)
+
+
+@pytest.mark.parametrize(
+ "invalid_token",
+ [
+ "",
+ "foo",
+ "a" * 40,
+ "b" * 40,
+ "%s%s" % ("b" * 40, "a" * 40),
+ "%s%s" % ("a" * 39, "b" * 40),
+ "%s%s" % ("a" * 40, "b" * 39),
+ "%s%s" % ("a" * 40, "b" * 41),
+ ],
+)
def test_invalid_access_token(invalid_token, initialized_db):
- user = model.user.get_user('devtable')
- token = access_valid_token(invalid_token)
- assert token is None
+ user = model.user.get_user("devtable")
+ token = access_valid_token(invalid_token)
+ assert token is None
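
The expiration strings parametrized in test_gc encode both a duration and, via a leading "-", whether the token should already be expired when garbage collection runs. A small sketch of that decoding; convert_to_timedelta below is a simplified stand-in for util.timedeltastring.convert_to_timedelta that only covers the suffixes these tests use:

from datetime import datetime, timedelta


def convert_to_timedelta(spec):
    # Simplified stand-in: supports the m/h/d/w suffixes used in these tests.
    units = {"m": "minutes", "h": "hours", "d": "days", "w": "weeks"}
    return timedelta(**{units[spec[-1]]: int(spec[:-1])})


def expiration_from_spec(spec):
    """Return (expiration_date, is_expired) for a spec like "10d", "-1w" or None."""
    if spec is None:
        return None, False
    if spec.startswith("-"):
        return datetime.now() - convert_to_timedelta(spec[1:]), True
    return datetime.now() + convert_to_timedelta(spec), False


date, is_expired = expiration_from_spec("-1d")
assert is_expired and date < datetime.now()
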
diff --git a/data/model/test/test_basequery.py b/data/model/test/test_basequery.py
index 84e248327..9a3d9567a 100644
--- a/data/model/test/test_basequery.py
+++ b/data/model/test/test_basequery.py
@@ -11,97 +11,101 @@ from util.names import parse_robot_username
from test.fixtures import *
-def _is_team_member(team, user):
- return user.id in [member.user_id for member in
- TeamMember.select().where(TeamMember.team == team)]
-def _get_visible_repositories_for_user(user, repo_kind='image', include_public=False,
- namespace=None):
- """ Returns all repositories directly visible to the given user, by either repo permission,
+def _is_team_member(team, user):
+ return user.id in [
+ member.user_id for member in TeamMember.select().where(TeamMember.team == team)
+ ]
+
+
+def _get_visible_repositories_for_user(
+ user, repo_kind="image", include_public=False, namespace=None
+):
+ """ Returns all repositories directly visible to the given user, by either repo permission,
or the user being the admin of a namespace.
"""
- for repo in Repository.select():
- if repo_kind is not None and repo.kind.name != repo_kind:
- continue
+ for repo in Repository.select():
+ if repo_kind is not None and repo.kind.name != repo_kind:
+ continue
- if namespace is not None and repo.namespace_user.username != namespace:
- continue
+ if namespace is not None and repo.namespace_user.username != namespace:
+ continue
- if include_public and repo.visibility.name == 'public':
- yield repo
- continue
+ if include_public and repo.visibility.name == "public":
+ yield repo
+ continue
- # Direct repo permission.
- try:
- RepositoryPermission.get(repository=repo, user=user).get()
- yield repo
- continue
- except RepositoryPermission.DoesNotExist:
- pass
+ # Direct repo permission.
+ try:
+ RepositoryPermission.get(repository=repo, user=user).get()
+ yield repo
+ continue
+ except RepositoryPermission.DoesNotExist:
+ pass
- # Team permission.
- found_in_team = False
- for perm in RepositoryPermission.select().where(RepositoryPermission.repository == repo):
- if perm.team and _is_team_member(perm.team, user):
- found_in_team = True
- break
+ # Team permission.
+ found_in_team = False
+ for perm in RepositoryPermission.select().where(
+ RepositoryPermission.repository == repo
+ ):
+ if perm.team and _is_team_member(perm.team, user):
+ found_in_team = True
+ break
- if found_in_team:
- yield repo
- continue
+ if found_in_team:
+ yield repo
+ continue
- # Org namespace admin permission.
- if user in get_admin_users(repo.namespace_user):
- yield repo
- continue
+ # Org namespace admin permission.
+ if user in get_admin_users(repo.namespace_user):
+ yield repo
+ continue
-@pytest.mark.parametrize('username', [
- 'devtable',
- 'devtable+dtrobot',
- 'public',
- 'reader',
-])
-@pytest.mark.parametrize('include_public', [
- True,
- False
-])
-@pytest.mark.parametrize('filter_to_namespace', [
- True,
- False
-])
-@pytest.mark.parametrize('repo_kind', [
- None,
- 'image',
- 'application',
-])
-def test_filter_repositories(username, include_public, filter_to_namespace, repo_kind,
- initialized_db):
- namespace = username if filter_to_namespace else None
- if '+' in username and filter_to_namespace:
- namespace, _ = parse_robot_username(username)
+@pytest.mark.parametrize(
+ "username", ["devtable", "devtable+dtrobot", "public", "reader"]
+)
+@pytest.mark.parametrize("include_public", [True, False])
+@pytest.mark.parametrize("filter_to_namespace", [True, False])
+@pytest.mark.parametrize("repo_kind", [None, "image", "application"])
+def test_filter_repositories(
+ username, include_public, filter_to_namespace, repo_kind, initialized_db
+):
+ namespace = username if filter_to_namespace else None
+ if "+" in username and filter_to_namespace:
+ namespace, _ = parse_robot_username(username)
- user = get_namespace_user(username)
- query = (Repository
- .select()
- .distinct()
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .switch(Repository)
- .join(RepositoryPermission, JOIN.LEFT_OUTER))
+ user = get_namespace_user(username)
+ query = (
+ Repository.select()
+ .distinct()
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .switch(Repository)
+ .join(RepositoryPermission, JOIN.LEFT_OUTER)
+ )
- # Prime the cache.
- Repository.kind.get_id('image')
+ # Prime the cache.
+ Repository.kind.get_id("image")
- with assert_query_count(1):
- found = list(filter_to_repos_for_user(query, user.id,
- namespace=namespace,
- include_public=include_public,
- repo_kind=repo_kind))
+ with assert_query_count(1):
+ found = list(
+ filter_to_repos_for_user(
+ query,
+ user.id,
+ namespace=namespace,
+ include_public=include_public,
+ repo_kind=repo_kind,
+ )
+ )
- expected = list(_get_visible_repositories_for_user(user,
- repo_kind=repo_kind,
- namespace=namespace,
- include_public=include_public))
+ expected = list(
+ _get_visible_repositories_for_user(
+ user,
+ repo_kind=repo_kind,
+ namespace=namespace,
+ include_public=include_public,
+ )
+ )
- assert len(found) == len(expected)
- assert {r.id for r in found} == {r.id for r in expected}
+ assert len(found) == len(expected)
+ assert {r.id for r in found} == {r.id for r in expected}
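
test_filter_repositories follows a reference-implementation pattern: run the optimized single-query code path, then compare its result set against the deliberately naive walk in _get_visible_repositories_for_user. A toy illustration of the same pattern with plain dicts (optimized_visible and naive_visible are made-up stand-ins, not the real filter_to_repos_for_user):

def optimized_visible(repos, username):
    # Pretend this is the fast, single-query production code path.
    return [r for r in repos if username in r["allowed"]]


def naive_visible(repos, username):
    # Pretend this is the slow-but-obvious reference walk over every repository.
    for r in repos:
        if username in r["allowed"]:
            yield r


repos = [{"id": 1, "allowed": {"devtable"}}, {"id": 2, "allowed": {"public"}}]
found = optimized_visible(repos, "devtable")
expected = list(naive_visible(repos, "devtable"))
assert {r["id"] for r in found} == {r["id"] for r in expected}
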
diff --git a/data/model/test/test_build.py b/data/model/test/test_build.py
index c43d6e683..7fe95a7f9 100644
--- a/data/model/test/test_build.py
+++ b/data/model/test/test_build.py
@@ -3,105 +3,137 @@ import pytest
from mock import patch
from data.database import BUILD_PHASE, RepositoryBuildTrigger, RepositoryBuild
-from data.model.build import (update_trigger_disable_status, create_repository_build,
- get_repository_build, update_phase_then_close)
+from data.model.build import (
+ update_trigger_disable_status,
+ create_repository_build,
+ get_repository_build,
+ update_phase_then_close,
+)
from test.fixtures import *
TEST_FAIL_THRESHOLD = 5
TEST_INTERNAL_ERROR_THRESHOLD = 2
-@pytest.mark.parametrize('starting_failure_count, starting_error_count, status, expected_reason', [
- (0, 0, BUILD_PHASE.COMPLETE, None),
- (10, 10, BUILD_PHASE.COMPLETE, None),
- (TEST_FAIL_THRESHOLD - 1, TEST_INTERNAL_ERROR_THRESHOLD - 1, BUILD_PHASE.COMPLETE, None),
- (TEST_FAIL_THRESHOLD - 1, 0, BUILD_PHASE.ERROR, 'successive_build_failures'),
- (0, TEST_INTERNAL_ERROR_THRESHOLD - 1, BUILD_PHASE.INTERNAL_ERROR,
- 'successive_build_internal_errors'),
-])
-def test_update_trigger_disable_status(starting_failure_count, starting_error_count, status,
- expected_reason, initialized_db):
- test_config = {
- 'SUCCESSIVE_TRIGGER_FAILURE_DISABLE_THRESHOLD': TEST_FAIL_THRESHOLD,
- 'SUCCESSIVE_TRIGGER_INTERNAL_ERROR_DISABLE_THRESHOLD': TEST_INTERNAL_ERROR_THRESHOLD,
- }
+@pytest.mark.parametrize(
+ "starting_failure_count, starting_error_count, status, expected_reason",
+ [
+ (0, 0, BUILD_PHASE.COMPLETE, None),
+ (10, 10, BUILD_PHASE.COMPLETE, None),
+ (
+ TEST_FAIL_THRESHOLD - 1,
+ TEST_INTERNAL_ERROR_THRESHOLD - 1,
+ BUILD_PHASE.COMPLETE,
+ None,
+ ),
+ (TEST_FAIL_THRESHOLD - 1, 0, BUILD_PHASE.ERROR, "successive_build_failures"),
+ (
+ 0,
+ TEST_INTERNAL_ERROR_THRESHOLD - 1,
+ BUILD_PHASE.INTERNAL_ERROR,
+ "successive_build_internal_errors",
+ ),
+ ],
+)
+def test_update_trigger_disable_status(
+ starting_failure_count,
+ starting_error_count,
+ status,
+ expected_reason,
+ initialized_db,
+):
+ test_config = {
+ "SUCCESSIVE_TRIGGER_FAILURE_DISABLE_THRESHOLD": TEST_FAIL_THRESHOLD,
+ "SUCCESSIVE_TRIGGER_INTERNAL_ERROR_DISABLE_THRESHOLD": TEST_INTERNAL_ERROR_THRESHOLD,
+ }
- trigger = model.build.list_build_triggers('devtable', 'building')[0]
- trigger.successive_failure_count = starting_failure_count
- trigger.successive_internal_error_count = starting_error_count
- trigger.enabled = True
- trigger.save()
+ trigger = model.build.list_build_triggers("devtable", "building")[0]
+ trigger.successive_failure_count = starting_failure_count
+ trigger.successive_internal_error_count = starting_error_count
+ trigger.enabled = True
+ trigger.save()
- with patch('data.model.config.app_config', test_config):
- update_trigger_disable_status(trigger, status)
- updated_trigger = RepositoryBuildTrigger.get(uuid=trigger.uuid)
+ with patch("data.model.config.app_config", test_config):
+ update_trigger_disable_status(trigger, status)
+ updated_trigger = RepositoryBuildTrigger.get(uuid=trigger.uuid)
- assert updated_trigger.enabled == (expected_reason is None)
+ assert updated_trigger.enabled == (expected_reason is None)
- if expected_reason is not None:
- assert updated_trigger.disabled_reason.name == expected_reason
- else:
- assert updated_trigger.disabled_reason is None
- assert updated_trigger.successive_failure_count == 0
- assert updated_trigger.successive_internal_error_count == 0
+ if expected_reason is not None:
+ assert updated_trigger.disabled_reason.name == expected_reason
+ else:
+ assert updated_trigger.disabled_reason is None
+ assert updated_trigger.successive_failure_count == 0
+ assert updated_trigger.successive_internal_error_count == 0
def test_archivable_build_logs(initialized_db):
- # Make sure there are no archivable logs.
- result = model.build.get_archivable_build()
- assert result is None
+ # Make sure there are no archivable logs.
+ result = model.build.get_archivable_build()
+ assert result is None
- # Add a build that cannot (yet) be archived.
- repo = model.repository.get_repository('devtable', 'simple')
- token = model.token.create_access_token(repo, 'write')
- created = RepositoryBuild.create(repository=repo, access_token=token,
- phase=model.build.BUILD_PHASE.WAITING,
- logs_archived=False, job_config='{}',
- display_name='')
+ # Add a build that cannot (yet) be archived.
+ repo = model.repository.get_repository("devtable", "simple")
+ token = model.token.create_access_token(repo, "write")
+ created = RepositoryBuild.create(
+ repository=repo,
+ access_token=token,
+ phase=model.build.BUILD_PHASE.WAITING,
+ logs_archived=False,
+ job_config="{}",
+ display_name="",
+ )
- # Make sure there are no archivable logs.
- result = model.build.get_archivable_build()
- assert result is None
+ # Make sure there are no archivable logs.
+ result = model.build.get_archivable_build()
+ assert result is None
- # Change the build to being complete.
- created.phase = model.build.BUILD_PHASE.COMPLETE
- created.save()
+ # Change the build to being complete.
+ created.phase = model.build.BUILD_PHASE.COMPLETE
+ created.save()
- # Make sure we now find an archivable build.
- result = model.build.get_archivable_build()
- assert result.id == created.id
+ # Make sure we now find an archivable build.
+ result = model.build.get_archivable_build()
+ assert result.id == created.id
def test_update_build_phase(initialized_db):
- build = create_build(model.repository.get_repository("devtable", "building"))
+ build = create_build(model.repository.get_repository("devtable", "building"))
- repo_build = get_repository_build(build.uuid)
+ repo_build = get_repository_build(build.uuid)
- assert repo_build.phase == BUILD_PHASE.WAITING
- assert update_phase_then_close(build.uuid, BUILD_PHASE.COMPLETE)
+ assert repo_build.phase == BUILD_PHASE.WAITING
+ assert update_phase_then_close(build.uuid, BUILD_PHASE.COMPLETE)
- repo_build = get_repository_build(build.uuid)
- assert repo_build.phase == BUILD_PHASE.COMPLETE
+ repo_build = get_repository_build(build.uuid)
+ assert repo_build.phase == BUILD_PHASE.COMPLETE
- repo_build.delete_instance()
- assert not update_phase_then_close(repo_build.uuid, BUILD_PHASE.PULLING)
+ repo_build.delete_instance()
+ assert not update_phase_then_close(repo_build.uuid, BUILD_PHASE.PULLING)
def create_build(repository):
- new_token = model.token.create_access_token(repository, 'write', 'build-worker')
- repo = 'ci.devtable.com:5000/%s/%s' % (repository.namespace_user.username, repository.name)
- job_config = {
- 'repository': repo,
- 'docker_tags': ['latest'],
- 'build_subdir': '',
- 'trigger_metadata': {
- 'commit': '3482adc5822c498e8f7db2e361e8d57b3d77ddd9',
- 'ref': 'refs/heads/master',
- 'default_branch': 'master'
+ new_token = model.token.create_access_token(repository, "write", "build-worker")
+ repo = "ci.devtable.com:5000/%s/%s" % (
+ repository.namespace_user.username,
+ repository.name,
+ )
+ job_config = {
+ "repository": repo,
+ "docker_tags": ["latest"],
+ "build_subdir": "",
+ "trigger_metadata": {
+ "commit": "3482adc5822c498e8f7db2e361e8d57b3d77ddd9",
+ "ref": "refs/heads/master",
+ "default_branch": "master",
+ },
}
- }
- build = create_repository_build(repository, new_token, job_config,
- '68daeebd-a5b9-457f-80a0-4363b882f8ea',
- "build_name")
- build.save()
- return build
+ build = create_repository_build(
+ repository,
+ new_token,
+ job_config,
+ "68daeebd-a5b9-457f-80a0-4363b882f8ea",
+ "build_name",
+ )
+ build.save()
+ return build
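
The expectations in test_update_trigger_disable_status are easier to read as a small decision table: a completed build leaves the trigger enabled, while a failure or internal error disables it once the corresponding successive counter reaches its threshold. A sketch of that table as plain Python; expected_disable_reason models the parametrized cases and is not the data.model.build API:

TEST_FAIL_THRESHOLD = 5
TEST_INTERNAL_ERROR_THRESHOLD = 2


def expected_disable_reason(failure_count, internal_error_count, status):
    """Return the expected disable reason for a trigger, or None if it stays enabled."""
    if status == "complete":
        return None  # success resets the counters and keeps the trigger enabled
    if status == "error" and failure_count + 1 >= TEST_FAIL_THRESHOLD:
        return "successive_build_failures"
    if status == "internal_error" and internal_error_count + 1 >= TEST_INTERNAL_ERROR_THRESHOLD:
        return "successive_build_internal_errors"
    return None


assert expected_disable_reason(0, 0, "complete") is None
assert expected_disable_reason(TEST_FAIL_THRESHOLD - 1, 0, "error") == "successive_build_failures"
assert (
    expected_disable_reason(0, TEST_INTERNAL_ERROR_THRESHOLD - 1, "internal_error")
    == "successive_build_internal_errors"
)
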
diff --git a/data/model/test/test_gc.py b/data/model/test/test_gc.py
index 79d13779e..656b75bbb 100644
--- a/data/model/test/test_gc.py
+++ b/data/model/test/test_gc.py
@@ -13,9 +13,19 @@ from playhouse.test_utils import assert_query_count
from freezegun import freeze_time
from data import model, database
-from data.database import (Image, ImageStorage, DerivedStorageForImage, Label, TagManifestLabel,
- ApprBlob, Manifest, TagManifestToManifest, ManifestBlob, Tag,
- TagToRepositoryTag)
+from data.database import (
+ Image,
+ ImageStorage,
+ DerivedStorageForImage,
+ Label,
+ TagManifestLabel,
+ ApprBlob,
+ Manifest,
+ TagManifestToManifest,
+ ManifestBlob,
+ Tag,
+ TagToRepositoryTag,
+)
from data.model.oci.test.test_oci_manifest import create_manifest_for_testing
from image.docker.schema1 import DockerSchema1ManifestBuilder
from image.docker.schema2.manifest import DockerSchema2ManifestBuilder
@@ -25,701 +35,787 @@ from util.bytes import Bytes
from test.fixtures import *
-ADMIN_ACCESS_USER = 'devtable'
-PUBLIC_USER = 'public'
+ADMIN_ACCESS_USER = "devtable"
+PUBLIC_USER = "public"
+
+REPO = "somerepo"
-REPO = 'somerepo'
def _set_tag_expiration_policy(namespace, expiration_s):
- namespace_user = model.user.get_user(namespace)
- model.user.change_user_tag_expiration(namespace_user, expiration_s)
+ namespace_user = model.user.get_user(namespace)
+ model.user.change_user_tag_expiration(namespace_user, expiration_s)
@pytest.fixture()
def default_tag_policy(initialized_db):
- _set_tag_expiration_policy(ADMIN_ACCESS_USER, 0)
- _set_tag_expiration_policy(PUBLIC_USER, 0)
+ _set_tag_expiration_policy(ADMIN_ACCESS_USER, 0)
+ _set_tag_expiration_policy(PUBLIC_USER, 0)
def create_image(docker_image_id, repository_obj, username):
- preferred = storage.preferred_locations[0]
- image = model.image.find_create_or_link_image(docker_image_id, repository_obj, username, {},
- preferred)
- image.storage.uploading = False
- image.storage.save()
+ preferred = storage.preferred_locations[0]
+ image = model.image.find_create_or_link_image(
+ docker_image_id, repository_obj, username, {}, preferred
+ )
+ image.storage.uploading = False
+ image.storage.save()
- # Create derived images as well.
- model.image.find_or_create_derived_storage(image, 'squash', preferred)
- model.image.find_or_create_derived_storage(image, 'aci', preferred)
-
- # Add some torrent info.
- try:
- database.TorrentInfo.get(storage=image.storage)
- except database.TorrentInfo.DoesNotExist:
- model.storage.save_torrent_info(image.storage, 1, 'helloworld')
-
- # Add some additional placements to the image.
- for location_name in ['local_eu']:
- location = database.ImageStorageLocation.get(name=location_name)
+ # Create derived images as well.
+ model.image.find_or_create_derived_storage(image, "squash", preferred)
+ model.image.find_or_create_derived_storage(image, "aci", preferred)
+ # Add some torrent info.
try:
- database.ImageStoragePlacement.get(location=location, storage=image.storage)
- except:
- continue
+ database.TorrentInfo.get(storage=image.storage)
+ except database.TorrentInfo.DoesNotExist:
+ model.storage.save_torrent_info(image.storage, 1, "helloworld")
- database.ImageStoragePlacement.create(location=location, storage=image.storage)
+ # Add some additional placements to the image.
+ for location_name in ["local_eu"]:
+ location = database.ImageStorageLocation.get(name=location_name)
- return image.storage
+ try:
+ database.ImageStoragePlacement.get(location=location, storage=image.storage)
+ except:
+ continue
+
+ database.ImageStoragePlacement.create(location=location, storage=image.storage)
+
+ return image.storage
def store_tag_manifest(namespace, repo_name, tag_name, image_id):
- builder = DockerSchema1ManifestBuilder(namespace, repo_name, tag_name)
- storage_id_map = {}
- try:
- image_storage = ImageStorage.select().where(~(ImageStorage.content_checksum >> None)).get()
- builder.add_layer(image_storage.content_checksum, '{"id": "foo"}')
- storage_id_map[image_storage.content_checksum] = image_storage.id
- except ImageStorage.DoesNotExist:
- pass
+ builder = DockerSchema1ManifestBuilder(namespace, repo_name, tag_name)
+ storage_id_map = {}
+ try:
+ image_storage = (
+ ImageStorage.select().where(~(ImageStorage.content_checksum >> None)).get()
+ )
+ builder.add_layer(image_storage.content_checksum, '{"id": "foo"}')
+ storage_id_map[image_storage.content_checksum] = image_storage.id
+ except ImageStorage.DoesNotExist:
+ pass
- manifest = builder.build(docker_v2_signing_key)
- manifest_row, _ = model.tag.store_tag_manifest_for_testing(namespace, repo_name, tag_name,
- manifest, image_id, storage_id_map)
- return manifest_row
+ manifest = builder.build(docker_v2_signing_key)
+ manifest_row, _ = model.tag.store_tag_manifest_for_testing(
+ namespace, repo_name, tag_name, manifest, image_id, storage_id_map
+ )
+ return manifest_row
def create_repository(namespace=ADMIN_ACCESS_USER, name=REPO, **kwargs):
- user = model.user.get_user(namespace)
- repo = model.repository.create_repository(namespace, name, user)
+ user = model.user.get_user(namespace)
+ repo = model.repository.create_repository(namespace, name, user)
- # Populate the repository with the tags.
- image_map = {}
- for tag_name in kwargs:
- image_ids = kwargs[tag_name]
- parent = None
+ # Populate the repository with the tags.
+ image_map = {}
+ for tag_name in kwargs:
+ image_ids = kwargs[tag_name]
+ parent = None
- for image_id in image_ids:
- if not image_id in image_map:
- image_map[image_id] = create_image(image_id, repo, namespace)
+ for image_id in image_ids:
+ if not image_id in image_map:
+ image_map[image_id] = create_image(image_id, repo, namespace)
- v1_metadata = {
- 'id': image_id,
- }
- if parent is not None:
- v1_metadata['parent'] = parent.docker_image_id
+ v1_metadata = {"id": image_id}
+ if parent is not None:
+ v1_metadata["parent"] = parent.docker_image_id
- # Set the ancestors for the image.
- parent = model.image.set_image_metadata(image_id, namespace, name, '', '', '', v1_metadata,
- parent=parent)
+ # Set the ancestors for the image.
+ parent = model.image.set_image_metadata(
+ image_id, namespace, name, "", "", "", v1_metadata, parent=parent
+ )
- # Set the tag for the image.
- tag_manifest = store_tag_manifest(namespace, name, tag_name, image_ids[-1])
+ # Set the tag for the image.
+ tag_manifest = store_tag_manifest(namespace, name, tag_name, image_ids[-1])
- # Add some labels to the tag.
- model.label.create_manifest_label(tag_manifest, 'foo', 'bar', 'manifest')
- model.label.create_manifest_label(tag_manifest, 'meh', 'grah', 'manifest')
+ # Add some labels to the tag.
+ model.label.create_manifest_label(tag_manifest, "foo", "bar", "manifest")
+ model.label.create_manifest_label(tag_manifest, "meh", "grah", "manifest")
- return repo
+ return repo
def gc_now(repository):
- assert model.gc.garbage_collect_repo(repository)
+ assert model.gc.garbage_collect_repo(repository)
def delete_tag(repository, tag, perform_gc=True, expect_gc=True):
- model.tag.delete_tag(repository.namespace_user.username, repository.name, tag)
- if perform_gc:
- assert model.gc.garbage_collect_repo(repository) == expect_gc
+ model.tag.delete_tag(repository.namespace_user.username, repository.name, tag)
+ if perform_gc:
+ assert model.gc.garbage_collect_repo(repository) == expect_gc
def move_tag(repository, tag, docker_image_id, expect_gc=True):
- model.tag.create_or_update_tag(repository.namespace_user.username, repository.name, tag,
- docker_image_id)
- assert model.gc.garbage_collect_repo(repository) == expect_gc
+ model.tag.create_or_update_tag(
+ repository.namespace_user.username, repository.name, tag, docker_image_id
+ )
+ assert model.gc.garbage_collect_repo(repository) == expect_gc
def assert_not_deleted(repository, *args):
- for docker_image_id in args:
- assert model.image.get_image_by_id(repository.namespace_user.username, repository.name,
- docker_image_id)
+ for docker_image_id in args:
+ assert model.image.get_image_by_id(
+ repository.namespace_user.username, repository.name, docker_image_id
+ )
def assert_deleted(repository, *args):
- for docker_image_id in args:
- try:
- # Verify the image is missing when accessed by the repository.
- model.image.get_image_by_id(repository.namespace_user.username, repository.name,
- docker_image_id)
- except model.DataModelException:
- return
+ for docker_image_id in args:
+ try:
+ # Verify the image is missing when accessed by the repository.
+ model.image.get_image_by_id(
+ repository.namespace_user.username, repository.name, docker_image_id
+ )
+ except model.DataModelException:
+ return
- assert False, 'Expected image %s to be deleted' % docker_image_id
+ assert False, "Expected image %s to be deleted" % docker_image_id
def _get_dangling_storage_count():
- storage_ids = set([current.id for current in ImageStorage.select()])
- referenced_by_image = set([image.storage_id for image in Image.select()])
- referenced_by_manifest = set([blob.blob_id for blob in ManifestBlob.select()])
- referenced_by_derived = set([derived.derivative_id
- for derived in DerivedStorageForImage.select()])
- return len(storage_ids - referenced_by_image - referenced_by_derived - referenced_by_manifest)
+ storage_ids = set([current.id for current in ImageStorage.select()])
+ referenced_by_image = set([image.storage_id for image in Image.select()])
+ referenced_by_manifest = set([blob.blob_id for blob in ManifestBlob.select()])
+ referenced_by_derived = set(
+ [derived.derivative_id for derived in DerivedStorageForImage.select()]
+ )
+ return len(
+ storage_ids
+ - referenced_by_image
+ - referenced_by_derived
+ - referenced_by_manifest
+ )
def _get_dangling_label_count():
- return len(_get_dangling_labels())
+ return len(_get_dangling_labels())
def _get_dangling_labels():
- label_ids = set([current.id for current in Label.select()])
- referenced_by_manifest = set([mlabel.label_id for mlabel in TagManifestLabel.select()])
- return label_ids - referenced_by_manifest
+ label_ids = set([current.id for current in Label.select()])
+ referenced_by_manifest = set(
+ [mlabel.label_id for mlabel in TagManifestLabel.select()]
+ )
+ return label_ids - referenced_by_manifest
def _get_dangling_manifest_count():
- manifest_ids = set([current.id for current in Manifest.select()])
- referenced_by_tag_manifest = set([tmt.manifest_id for tmt in TagManifestToManifest.select()])
- return len(manifest_ids - referenced_by_tag_manifest)
-
+ manifest_ids = set([current.id for current in Manifest.select()])
+ referenced_by_tag_manifest = set(
+ [tmt.manifest_id for tmt in TagManifestToManifest.select()]
+ )
+ return len(manifest_ids - referenced_by_tag_manifest)
@contextmanager
def assert_gc_integrity(expect_storage_removed=True, check_oci_tags=True):
- """ Specialized assertion for ensuring that GC cleans up all dangling storages
+ """ Specialized assertion for ensuring that GC cleans up all dangling storages
and labels, invokes the callback for images removed and doesn't invoke the
callback for images *not* removed.
"""
- # Add a callback for when images are removed.
- removed_image_storages = []
- model.config.register_image_cleanup_callback(removed_image_storages.extend)
+ # Add a callback for when images are removed.
+ removed_image_storages = []
+ model.config.register_image_cleanup_callback(removed_image_storages.extend)
- # Store the number of dangling storages and labels.
- existing_storage_count = _get_dangling_storage_count()
- existing_label_count = _get_dangling_label_count()
- existing_manifest_count = _get_dangling_manifest_count()
- yield
+ # Store the number of dangling storages and labels.
+ existing_storage_count = _get_dangling_storage_count()
+ existing_label_count = _get_dangling_label_count()
+ existing_manifest_count = _get_dangling_manifest_count()
+ yield
- # Ensure the number of dangling storages, manifests and labels has not changed.
- updated_storage_count = _get_dangling_storage_count()
- assert updated_storage_count == existing_storage_count
+ # Ensure the number of dangling storages, manifests and labels has not changed.
+ updated_storage_count = _get_dangling_storage_count()
+ assert updated_storage_count == existing_storage_count
- updated_label_count = _get_dangling_label_count()
- assert updated_label_count == existing_label_count, _get_dangling_labels()
+ updated_label_count = _get_dangling_label_count()
+ assert updated_label_count == existing_label_count, _get_dangling_labels()
- updated_manifest_count = _get_dangling_manifest_count()
- assert updated_manifest_count == existing_manifest_count
+ updated_manifest_count = _get_dangling_manifest_count()
+ assert updated_manifest_count == existing_manifest_count
- # Ensure that for each call to the image+storage cleanup callback, the image and its
- # storage is not found *anywhere* in the database.
- for removed_image_and_storage in removed_image_storages:
- with pytest.raises(Image.DoesNotExist):
- Image.get(id=removed_image_and_storage.id)
+ # Ensure that for each call to the image+storage cleanup callback, the image and its
+ # storage is not found *anywhere* in the database.
+ for removed_image_and_storage in removed_image_storages:
+ with pytest.raises(Image.DoesNotExist):
+ Image.get(id=removed_image_and_storage.id)
- # Ensure that image storages are only removed if not shared.
- shared = Image.select().where(Image.storage == removed_image_and_storage.storage_id).count()
- if shared == 0:
- shared = (ManifestBlob
- .select()
+ # Ensure that image storages are only removed if not shared.
+ shared = (
+ Image.select()
+ .where(Image.storage == removed_image_and_storage.storage_id)
+ .count()
+ )
+ if shared == 0:
+ shared = (
+ ManifestBlob.select()
.where(ManifestBlob.blob == removed_image_and_storage.storage_id)
- .count())
+ .count()
+ )
- if shared == 0:
- with pytest.raises(ImageStorage.DoesNotExist):
- ImageStorage.get(id=removed_image_and_storage.storage_id)
+ if shared == 0:
+ with pytest.raises(ImageStorage.DoesNotExist):
+ ImageStorage.get(id=removed_image_and_storage.storage_id)
- with pytest.raises(ImageStorage.DoesNotExist):
- ImageStorage.get(uuid=removed_image_and_storage.storage.uuid)
+ with pytest.raises(ImageStorage.DoesNotExist):
+ ImageStorage.get(uuid=removed_image_and_storage.storage.uuid)
- # Ensure all CAS storage is in the storage engine.
- preferred = storage.preferred_locations[0]
- for storage_row in ImageStorage.select():
- if storage_row.cas_path:
- storage.get_content({preferred}, storage.blob_path(storage_row.content_checksum))
+ # Ensure all CAS storage is in the storage engine.
+ preferred = storage.preferred_locations[0]
+ for storage_row in ImageStorage.select():
+ if storage_row.cas_path:
+ storage.get_content(
+ {preferred}, storage.blob_path(storage_row.content_checksum)
+ )
- for blob_row in ApprBlob.select():
- storage.get_content({preferred}, storage.blob_path(blob_row.digest))
+ for blob_row in ApprBlob.select():
+ storage.get_content({preferred}, storage.blob_path(blob_row.digest))
- # Ensure there are no danglings OCI tags.
- if check_oci_tags:
- oci_tags = {t.id for t in Tag.select()}
- referenced_oci_tags = {t.tag_id for t in TagToRepositoryTag.select()}
- assert not oci_tags - referenced_oci_tags
+ # Ensure there are no dangling OCI tags.
+ if check_oci_tags:
+ oci_tags = {t.id for t in Tag.select()}
+ referenced_oci_tags = {t.tag_id for t in TagToRepositoryTag.select()}
+ assert not oci_tags - referenced_oci_tags
- # Ensure all tags have valid manifests.
- for manifest in {t.manifest for t in Tag.select()}:
- # Ensure that the manifest's blobs all exist.
- found_blobs = {b.blob.content_checksum
- for b in ManifestBlob.select().where(ManifestBlob.manifest == manifest)}
+ # Ensure all tags have valid manifests.
+ for manifest in {t.manifest for t in Tag.select()}:
+ # Ensure that the manifest's blobs all exist.
+ found_blobs = {
+ b.blob.content_checksum
+ for b in ManifestBlob.select().where(ManifestBlob.manifest == manifest)
+ }
- parsed = parse_manifest_from_bytes(Bytes.for_string_or_unicode(manifest.manifest_bytes),
- manifest.media_type.name)
- assert set(parsed.local_blob_digests) == found_blobs
+ parsed = parse_manifest_from_bytes(
+ Bytes.for_string_or_unicode(manifest.manifest_bytes),
+ manifest.media_type.name,
+ )
+ assert set(parsed.local_blob_digests) == found_blobs
def test_has_garbage(default_tag_policy, initialized_db):
- """ Remove all existing repositories, then add one without garbage, check, then add one with
+ """ Remove all existing repositories, then add one without garbage, check, then add one with
garbage, and check again.
"""
- # Delete all existing repos.
- for repo in database.Repository.select().order_by(database.Repository.id):
- assert model.gc.purge_repository(repo.namespace_user.username, repo.name)
+ # Delete all existing repos.
+ for repo in database.Repository.select().order_by(database.Repository.id):
+ assert model.gc.purge_repository(repo.namespace_user.username, repo.name)
- # Change the time machine expiration on the namespace.
- (database.User
- .update(removed_tag_expiration_s=1000000000)
- .where(database.User.username == ADMIN_ACCESS_USER)
- .execute())
+ # Change the time machine expiration on the namespace.
+ (
+ database.User.update(removed_tag_expiration_s=1000000000)
+ .where(database.User.username == ADMIN_ACCESS_USER)
+ .execute()
+ )
- # Create a repository without any garbage.
- repository = create_repository(latest=['i1', 'i2', 'i3'])
+ # Create a repository without any garbage.
+ repository = create_repository(latest=["i1", "i2", "i3"])
- # Ensure that no repositories are returned by the has garbage check.
- assert model.repository.find_repository_with_garbage(1000000000) is None
+ # Ensure that no repositories are returned by the has garbage check.
+ assert model.repository.find_repository_with_garbage(1000000000) is None
- # Delete a tag.
- delete_tag(repository, 'latest', perform_gc=False)
+ # Delete a tag.
+ delete_tag(repository, "latest", perform_gc=False)
- # There should still not be any repositories with garbage, due to time machine.
- assert model.repository.find_repository_with_garbage(1000000000) is None
+ # There should still not be any repositories with garbage, due to time machine.
+ assert model.repository.find_repository_with_garbage(1000000000) is None
- # Change the time machine expiration on the namespace.
- (database.User
- .update(removed_tag_expiration_s=0)
- .where(database.User.username == ADMIN_ACCESS_USER)
- .execute())
+ # Change the time machine expiration on the namespace.
+ (
+ database.User.update(removed_tag_expiration_s=0)
+ .where(database.User.username == ADMIN_ACCESS_USER)
+ .execute()
+ )
- # Now we should find the repository for GC.
- repository = model.repository.find_repository_with_garbage(0)
- assert repository is not None
- assert repository.name == REPO
+ # Now we should find the repository for GC.
+ repository = model.repository.find_repository_with_garbage(0)
+ assert repository is not None
+ assert repository.name == REPO
- # GC the repository.
- assert model.gc.garbage_collect_repo(repository)
+ # GC the repository.
+ assert model.gc.garbage_collect_repo(repository)
- # There should now be no repositories with garbage.
- assert model.repository.find_repository_with_garbage(0) is None
+ # There should now be no repositories with garbage.
+ assert model.repository.find_repository_with_garbage(0) is None
def test_find_garbage_policy_functions(default_tag_policy, initialized_db):
- with assert_query_count(1):
- one_policy = model.repository.get_random_gc_policy()
- all_policies = model.repository._get_gc_expiration_policies()
- assert one_policy in all_policies
+ with assert_query_count(1):
+ one_policy = model.repository.get_random_gc_policy()
+ all_policies = model.repository._get_gc_expiration_policies()
+ assert one_policy in all_policies
def test_one_tag(default_tag_policy, initialized_db):
- """ Create a repository with a single tag, then remove that tag and verify that the repository
+ """ Create a repository with a single tag, then remove that tag and verify that the repository
is now empty. """
- with assert_gc_integrity():
- repository = create_repository(latest=['i1', 'i2', 'i3'])
- delete_tag(repository, 'latest')
- assert_deleted(repository, 'i1', 'i2', 'i3')
+ with assert_gc_integrity():
+ repository = create_repository(latest=["i1", "i2", "i3"])
+ delete_tag(repository, "latest")
+ assert_deleted(repository, "i1", "i2", "i3")
def test_two_tags_unshared_images(default_tag_policy, initialized_db):
- """ Repository has two tags with no shared images between them. """
- with assert_gc_integrity():
- repository = create_repository(latest=['i1', 'i2', 'i3'], other=['f1', 'f2'])
- delete_tag(repository, 'latest')
- assert_deleted(repository, 'i1', 'i2', 'i3')
- assert_not_deleted(repository, 'f1', 'f2')
+ """ Repository has two tags with no shared images between them. """
+ with assert_gc_integrity():
+ repository = create_repository(latest=["i1", "i2", "i3"], other=["f1", "f2"])
+ delete_tag(repository, "latest")
+ assert_deleted(repository, "i1", "i2", "i3")
+ assert_not_deleted(repository, "f1", "f2")
def test_two_tags_shared_images(default_tag_policy, initialized_db):
- """ Repository has two tags with shared images. Deleting the tag should only remove the
+ """ Repository has two tags with shared images. Deleting the tag should only remove the
unshared images.
"""
- with assert_gc_integrity():
- repository = create_repository(latest=['i1', 'i2', 'i3'], other=['i1', 'f1'])
- delete_tag(repository, 'latest')
- assert_deleted(repository, 'i2', 'i3')
- assert_not_deleted(repository, 'i1', 'f1')
+ with assert_gc_integrity():
+ repository = create_repository(latest=["i1", "i2", "i3"], other=["i1", "f1"])
+ delete_tag(repository, "latest")
+ assert_deleted(repository, "i2", "i3")
+ assert_not_deleted(repository, "i1", "f1")
def test_unrelated_repositories(default_tag_policy, initialized_db):
- """ Two repositories with different images. Removing the tag from one leaves the other's
+ """ Two repositories with different images. Removing the tag from one leaves the other's
images intact.
"""
- with assert_gc_integrity():
- repository1 = create_repository(latest=['i1', 'i2', 'i3'], name='repo1')
- repository2 = create_repository(latest=['j1', 'j2', 'j3'], name='repo2')
+ with assert_gc_integrity():
+ repository1 = create_repository(latest=["i1", "i2", "i3"], name="repo1")
+ repository2 = create_repository(latest=["j1", "j2", "j3"], name="repo2")
- delete_tag(repository1, 'latest')
+ delete_tag(repository1, "latest")
- assert_deleted(repository1, 'i1', 'i2', 'i3')
- assert_not_deleted(repository2, 'j1', 'j2', 'j3')
+ assert_deleted(repository1, "i1", "i2", "i3")
+ assert_not_deleted(repository2, "j1", "j2", "j3")
def test_related_repositories(default_tag_policy, initialized_db):
- """ Two repositories with shared images. Removing the tag from one leaves the other's
+ """ Two repositories with shared images. Removing the tag from one leaves the other's
images intact.
"""
- with assert_gc_integrity():
- repository1 = create_repository(latest=['i1', 'i2', 'i3'], name='repo1')
- repository2 = create_repository(latest=['i1', 'i2', 'j1'], name='repo2')
+ with assert_gc_integrity():
+ repository1 = create_repository(latest=["i1", "i2", "i3"], name="repo1")
+ repository2 = create_repository(latest=["i1", "i2", "j1"], name="repo2")
- delete_tag(repository1, 'latest')
+ delete_tag(repository1, "latest")
- assert_deleted(repository1, 'i3')
- assert_not_deleted(repository2, 'i1', 'i2', 'j1')
+ assert_deleted(repository1, "i3")
+ assert_not_deleted(repository2, "i1", "i2", "j1")
def test_inaccessible_repositories(default_tag_policy, initialized_db):
- """ Two repositories under different namespaces should result in the images being deleted
+ """ Two repositories under different namespaces should result in the images being deleted
but not completely removed from the database.
"""
- with assert_gc_integrity():
- repository1 = create_repository(namespace=ADMIN_ACCESS_USER, latest=['i1', 'i2', 'i3'])
- repository2 = create_repository(namespace=PUBLIC_USER, latest=['i1', 'i2', 'i3'])
+ with assert_gc_integrity():
+ repository1 = create_repository(
+ namespace=ADMIN_ACCESS_USER, latest=["i1", "i2", "i3"]
+ )
+ repository2 = create_repository(
+ namespace=PUBLIC_USER, latest=["i1", "i2", "i3"]
+ )
- delete_tag(repository1, 'latest')
- assert_deleted(repository1, 'i1', 'i2', 'i3')
- assert_not_deleted(repository2, 'i1', 'i2', 'i3')
+ delete_tag(repository1, "latest")
+ assert_deleted(repository1, "i1", "i2", "i3")
+ assert_not_deleted(repository2, "i1", "i2", "i3")
def test_many_multiple_shared_images(default_tag_policy, initialized_db):
- """ Repository has multiple tags with shared images. Delete all but one tag.
+ """ Repository has multiple tags with shared images. Delete all but one tag.
"""
- with assert_gc_integrity():
- repository = create_repository(latest=['i1', 'i2', 'i3', 'i4', 'i5', 'i6', 'i7', 'i8', 'j0'],
- master=['i1', 'i2', 'i3', 'i4', 'i5', 'i6', 'i7', 'i8', 'j1'])
+ with assert_gc_integrity():
+ repository = create_repository(
+ latest=["i1", "i2", "i3", "i4", "i5", "i6", "i7", "i8", "j0"],
+ master=["i1", "i2", "i3", "i4", "i5", "i6", "i7", "i8", "j1"],
+ )
- # Delete tag latest. Should only delete j0, since it is not shared.
- delete_tag(repository, 'latest')
+ # Delete tag latest. Should only delete j0, since it is not shared.
+ delete_tag(repository, "latest")
- assert_deleted(repository, 'j0')
- assert_not_deleted(repository, 'i1', 'i2', 'i3', 'i4', 'i5', 'i6', 'i7', 'i8', 'j1')
+ assert_deleted(repository, "j0")
+ assert_not_deleted(
+ repository, "i1", "i2", "i3", "i4", "i5", "i6", "i7", "i8", "j1"
+ )
- # Delete tag master. Should delete the rest of the images.
- delete_tag(repository, 'master')
+ # Delete tag master. Should delete the rest of the images.
+ delete_tag(repository, "master")
- assert_deleted(repository, 'i1', 'i2', 'i3', 'i4', 'i5', 'i6', 'i7', 'i8', 'j1')
+ assert_deleted(repository, "i1", "i2", "i3", "i4", "i5", "i6", "i7", "i8", "j1")
def test_multiple_shared_images(default_tag_policy, initialized_db):
- """ Repository has multiple tags with shared images. Selectively deleting the tags, and
+ """ Repository has multiple tags with shared images. Selectively deleting the tags, and
verifying at each step.
"""
- with assert_gc_integrity():
- repository = create_repository(latest=['i1', 'i2', 'i3'], other=['i1', 'f1', 'f2'],
- third=['t1', 't2', 't3'], fourth=['i1', 'f1'])
+ with assert_gc_integrity():
+ repository = create_repository(
+ latest=["i1", "i2", "i3"],
+ other=["i1", "f1", "f2"],
+ third=["t1", "t2", "t3"],
+ fourth=["i1", "f1"],
+ )
- # Current state:
- # latest -> i3->i2->i1
- # other -> f2->f1->i1
- # third -> t3->t2->t1
- # fourth -> f1->i1
+ # Current state:
+ # latest -> i3->i2->i1
+ # other -> f2->f1->i1
+ # third -> t3->t2->t1
+ # fourth -> f1->i1
- # Delete tag other. Should delete f2, since it is not shared.
- delete_tag(repository, 'other')
- assert_deleted(repository, 'f2')
- assert_not_deleted(repository, 'i1', 'i2', 'i3', 't1', 't2', 't3', 'f1')
+ # Delete tag other. Should delete f2, since it is not shared.
+ delete_tag(repository, "other")
+ assert_deleted(repository, "f2")
+ assert_not_deleted(repository, "i1", "i2", "i3", "t1", "t2", "t3", "f1")
- # Current state:
- # latest -> i3->i2->i1
- # third -> t3->t2->t1
- # fourth -> f1->i1
+ # Current state:
+ # latest -> i3->i2->i1
+ # third -> t3->t2->t1
+ # fourth -> f1->i1
- # Move tag fourth to i3. This should remove f1 since it is no longer referenced.
- move_tag(repository, 'fourth', 'i3')
- assert_deleted(repository, 'f1')
- assert_not_deleted(repository, 'i1', 'i2', 'i3', 't1', 't2', 't3')
+ # Move tag fourth to i3. This should remove f1 since it is no longer referenced.
+ move_tag(repository, "fourth", "i3")
+ assert_deleted(repository, "f1")
+ assert_not_deleted(repository, "i1", "i2", "i3", "t1", "t2", "t3")
- # Current state:
- # latest -> i3->i2->i1
- # third -> t3->t2->t1
- # fourth -> i3->i2->i1
+ # Current state:
+ # latest -> i3->i2->i1
+ # third -> t3->t2->t1
+ # fourth -> i3->i2->i1
- # Delete tag 'latest'. This should do nothing since fourth is on the same branch.
- delete_tag(repository, 'latest')
- assert_not_deleted(repository, 'i1', 'i2', 'i3', 't1', 't2', 't3')
+ # Delete tag 'latest'. This should do nothing since fourth is on the same branch.
+ delete_tag(repository, "latest")
+ assert_not_deleted(repository, "i1", "i2", "i3", "t1", "t2", "t3")
- # Current state:
- # third -> t3->t2->t1
- # fourth -> i3->i2->i1
+ # Current state:
+ # third -> t3->t2->t1
+ # fourth -> i3->i2->i1
- # Delete tag 'third'. This should remove t1->t3.
- delete_tag(repository, 'third')
- assert_deleted(repository, 't1', 't2', 't3')
- assert_not_deleted(repository, 'i1', 'i2', 'i3')
+ # Delete tag 'third'. This should remove t1->t3.
+ delete_tag(repository, "third")
+ assert_deleted(repository, "t1", "t2", "t3")
+ assert_not_deleted(repository, "i1", "i2", "i3")
- # Current state:
- # fourth -> i3->i2->i1
+ # Current state:
+ # fourth -> i3->i2->i1
- # Add tag to i1.
- move_tag(repository, 'newtag', 'i1', expect_gc=False)
- assert_not_deleted(repository, 'i1', 'i2', 'i3')
+ # Add tag to i1.
+ move_tag(repository, "newtag", "i1", expect_gc=False)
+ assert_not_deleted(repository, "i1", "i2", "i3")
- # Current state:
- # fourth -> i3->i2->i1
- # newtag -> i1
+ # Current state:
+ # fourth -> i3->i2->i1
+ # newtag -> i1
- # Delete tag 'fourth'. This should remove i2 and i3.
- delete_tag(repository, 'fourth')
- assert_deleted(repository, 'i2', 'i3')
- assert_not_deleted(repository, 'i1')
+ # Delete tag 'fourth'. This should remove i2 and i3.
+ delete_tag(repository, "fourth")
+ assert_deleted(repository, "i2", "i3")
+ assert_not_deleted(repository, "i1")
- # Current state:
- # newtag -> i1
+ # Current state:
+ # newtag -> i1
- # Delete tag 'newtag'. This should remove the remaining image.
- delete_tag(repository, 'newtag')
- assert_deleted(repository, 'i1')
+ # Delete tag 'newtag'. This should remove the remaining image.
+ delete_tag(repository, "newtag")
+ assert_deleted(repository, "i1")
- # Current state:
- # (Empty)
+ # Current state:
+ # (Empty)
def test_empty_gc(default_tag_policy, initialized_db):
- with assert_gc_integrity(expect_storage_removed=False):
- repository = create_repository(latest=['i1', 'i2', 'i3'], other=['i1', 'f1', 'f2'],
- third=['t1', 't2', 't3'], fourth=['i1', 'f1'])
+ with assert_gc_integrity(expect_storage_removed=False):
+ repository = create_repository(
+ latest=["i1", "i2", "i3"],
+ other=["i1", "f1", "f2"],
+ third=["t1", "t2", "t3"],
+ fourth=["i1", "f1"],
+ )
- assert not model.gc.garbage_collect_repo(repository)
- assert_not_deleted(repository, 'i1', 'i2', 'i3', 't1', 't2', 't3', 'f1', 'f2')
+ assert not model.gc.garbage_collect_repo(repository)
+ assert_not_deleted(repository, "i1", "i2", "i3", "t1", "t2", "t3", "f1", "f2")
def test_time_machine_no_gc(default_tag_policy, initialized_db):
- """ Repository has two tags with shared images. Deleting the tag should not remove any images
+ """ Repository has two tags with shared images. Deleting the tag should not remove any images
"""
- with assert_gc_integrity(expect_storage_removed=False):
- repository = create_repository(latest=['i1', 'i2', 'i3'], other=['i1', 'f1'])
- _set_tag_expiration_policy(repository.namespace_user.username, 60*60*24)
+ with assert_gc_integrity(expect_storage_removed=False):
+ repository = create_repository(latest=["i1", "i2", "i3"], other=["i1", "f1"])
+ _set_tag_expiration_policy(repository.namespace_user.username, 60 * 60 * 24)
- delete_tag(repository, 'latest', expect_gc=False)
- assert_not_deleted(repository, 'i2', 'i3')
- assert_not_deleted(repository, 'i1', 'f1')
+ delete_tag(repository, "latest", expect_gc=False)
+ assert_not_deleted(repository, "i2", "i3")
+ assert_not_deleted(repository, "i1", "f1")
def test_time_machine_gc(default_tag_policy, initialized_db):
- """ Repository has two tags with shared images. Deleting the second tag should cause the images
+ """ Repository has two tags with shared images. Deleting the second tag should cause the images
for the first deleted tag to gc.
"""
- now = datetime.utcnow()
+ now = datetime.utcnow()
- with assert_gc_integrity():
- with freeze_time(now):
- repository = create_repository(latest=['i1', 'i2', 'i3'], other=['i1', 'f1'])
+ with assert_gc_integrity():
+ with freeze_time(now):
+ repository = create_repository(
+ latest=["i1", "i2", "i3"], other=["i1", "f1"]
+ )
- _set_tag_expiration_policy(repository.namespace_user.username, 1)
+ _set_tag_expiration_policy(repository.namespace_user.username, 1)
- delete_tag(repository, 'latest', expect_gc=False)
- assert_not_deleted(repository, 'i2', 'i3')
- assert_not_deleted(repository, 'i1', 'f1')
+ delete_tag(repository, "latest", expect_gc=False)
+ assert_not_deleted(repository, "i2", "i3")
+ assert_not_deleted(repository, "i1", "f1")
- with freeze_time(now + timedelta(seconds=2)):
- # This will cause the images associated with latest to gc
- delete_tag(repository, 'other')
- assert_deleted(repository, 'i2', 'i3')
- assert_not_deleted(repository, 'i1', 'f1')
+ with freeze_time(now + timedelta(seconds=2)):
+ # This will cause the images associated with latest to gc
+ delete_tag(repository, "other")
+ assert_deleted(repository, "i2", "i3")
+ assert_not_deleted(repository, "i1", "f1")
def test_images_shared_storage(default_tag_policy, initialized_db):
- """ Repository with two tags, both with the same shared storage. Deleting the first
+ """ Repository with two tags, both with the same shared storage. Deleting the first
tag should delete the first image, but *not* its storage.
"""
- with assert_gc_integrity(expect_storage_removed=False):
- repository = create_repository()
+ with assert_gc_integrity(expect_storage_removed=False):
+ repository = create_repository()
- # Add two tags, each with their own image, but with the same storage.
- image_storage = model.storage.create_v1_storage(storage.preferred_locations[0])
+ # Add two tags, each with their own image, but with the same storage.
+ image_storage = model.storage.create_v1_storage(storage.preferred_locations[0])
- first_image = Image.create(docker_image_id='i1',
- repository=repository, storage=image_storage,
- ancestors='/')
+ first_image = Image.create(
+ docker_image_id="i1",
+ repository=repository,
+ storage=image_storage,
+ ancestors="/",
+ )
- second_image = Image.create(docker_image_id='i2',
- repository=repository, storage=image_storage,
- ancestors='/')
+ second_image = Image.create(
+ docker_image_id="i2",
+ repository=repository,
+ storage=image_storage,
+ ancestors="/",
+ )
- store_tag_manifest(repository.namespace_user.username, repository.name,
- 'first', first_image.docker_image_id)
+ store_tag_manifest(
+ repository.namespace_user.username,
+ repository.name,
+ "first",
+ first_image.docker_image_id,
+ )
- store_tag_manifest(repository.namespace_user.username, repository.name,
- 'second', second_image.docker_image_id)
+ store_tag_manifest(
+ repository.namespace_user.username,
+ repository.name,
+ "second",
+ second_image.docker_image_id,
+ )
- # Delete the first tag.
- delete_tag(repository, 'first')
- assert_deleted(repository, 'i1')
- assert_not_deleted(repository, 'i2')
+ # Delete the first tag.
+ delete_tag(repository, "first")
+ assert_deleted(repository, "i1")
+ assert_not_deleted(repository, "i2")
def test_image_with_cas(default_tag_policy, initialized_db):
- """ A repository with a tag pointing to an image backed by CAS. Deleting and GCing the tag
+ """ A repository with a tag pointing to an image backed by CAS. Deleting and GCing the tag
should result in the storage and its CAS data being removed.
"""
- with assert_gc_integrity(expect_storage_removed=True):
- repository = create_repository()
+ with assert_gc_integrity(expect_storage_removed=True):
+ repository = create_repository()
- # Create an image storage record under CAS.
- content = 'hello world'
- digest = 'sha256:' + hashlib.sha256(content).hexdigest()
- preferred = storage.preferred_locations[0]
- storage.put_content({preferred}, storage.blob_path(digest), content)
+ # Create an image storage record under CAS.
+ content = "hello world"
+ digest = "sha256:" + hashlib.sha256(content).hexdigest()
+ preferred = storage.preferred_locations[0]
+ storage.put_content({preferred}, storage.blob_path(digest), content)
- image_storage = database.ImageStorage.create(content_checksum=digest, uploading=False)
- location = database.ImageStorageLocation.get(name=preferred)
- database.ImageStoragePlacement.create(location=location, storage=image_storage)
+ image_storage = database.ImageStorage.create(
+ content_checksum=digest, uploading=False
+ )
+ location = database.ImageStorageLocation.get(name=preferred)
+ database.ImageStoragePlacement.create(location=location, storage=image_storage)
- # Ensure the CAS path exists.
- assert storage.exists({preferred}, storage.blob_path(digest))
+ # Ensure the CAS path exists.
+ assert storage.exists({preferred}, storage.blob_path(digest))
- # Create the image and the tag.
- first_image = Image.create(docker_image_id='i1',
- repository=repository, storage=image_storage,
- ancestors='/')
+ # Create the image and the tag.
+ first_image = Image.create(
+ docker_image_id="i1",
+ repository=repository,
+ storage=image_storage,
+ ancestors="/",
+ )
- store_tag_manifest(repository.namespace_user.username, repository.name,
- 'first', first_image.docker_image_id)
+ store_tag_manifest(
+ repository.namespace_user.username,
+ repository.name,
+ "first",
+ first_image.docker_image_id,
+ )
- assert_not_deleted(repository, 'i1')
+ assert_not_deleted(repository, "i1")
- # Delete the tag.
- delete_tag(repository, 'first')
- assert_deleted(repository, 'i1')
+ # Delete the tag.
+ delete_tag(repository, "first")
+ assert_deleted(repository, "i1")
- # Ensure the CAS path is gone.
- assert not storage.exists({preferred}, storage.blob_path(digest))
+ # Ensure the CAS path is gone.
+ assert not storage.exists({preferred}, storage.blob_path(digest))
def test_images_shared_cas(default_tag_policy, initialized_db):
- """ A repository, each two tags, pointing to the same image, which has image storage
+ """ A repository, each two tags, pointing to the same image, which has image storage
with the same *CAS path*, but *distinct records*. Deleting the first tag should delete the
first image, and its storage, but not the file in storage, as it shares its CAS path.
"""
- with assert_gc_integrity(expect_storage_removed=True):
- repository = create_repository()
+ with assert_gc_integrity(expect_storage_removed=True):
+ repository = create_repository()
- # Create two image storage records with the same content checksum.
- content = 'hello world'
- digest = 'sha256:' + hashlib.sha256(content).hexdigest()
- preferred = storage.preferred_locations[0]
- storage.put_content({preferred}, storage.blob_path(digest), content)
+ # Create two image storage records with the same content checksum.
+ content = "hello world"
+ digest = "sha256:" + hashlib.sha256(content).hexdigest()
+ preferred = storage.preferred_locations[0]
+ storage.put_content({preferred}, storage.blob_path(digest), content)
- is1 = database.ImageStorage.create(content_checksum=digest, uploading=False)
- is2 = database.ImageStorage.create(content_checksum=digest, uploading=False)
+ is1 = database.ImageStorage.create(content_checksum=digest, uploading=False)
+ is2 = database.ImageStorage.create(content_checksum=digest, uploading=False)
- location = database.ImageStorageLocation.get(name=preferred)
+ location = database.ImageStorageLocation.get(name=preferred)
- database.ImageStoragePlacement.create(location=location, storage=is1)
- database.ImageStoragePlacement.create(location=location, storage=is2)
+ database.ImageStoragePlacement.create(location=location, storage=is1)
+ database.ImageStoragePlacement.create(location=location, storage=is2)
- # Ensure the CAS path exists.
- assert storage.exists({preferred}, storage.blob_path(digest))
+ # Ensure the CAS path exists.
+ assert storage.exists({preferred}, storage.blob_path(digest))
- # Create two images in the repository, and two tags, each pointing to one of the storages.
- first_image = Image.create(docker_image_id='i1',
- repository=repository, storage=is1,
- ancestors='/')
+ # Create two images in the repository, and two tags, each pointing to one of the storages.
+ first_image = Image.create(
+ docker_image_id="i1", repository=repository, storage=is1, ancestors="/"
+ )
- second_image = Image.create(docker_image_id='i2',
- repository=repository, storage=is2,
- ancestors='/')
+ second_image = Image.create(
+ docker_image_id="i2", repository=repository, storage=is2, ancestors="/"
+ )
- store_tag_manifest(repository.namespace_user.username, repository.name,
- 'first', first_image.docker_image_id)
+ store_tag_manifest(
+ repository.namespace_user.username,
+ repository.name,
+ "first",
+ first_image.docker_image_id,
+ )
- store_tag_manifest(repository.namespace_user.username, repository.name,
- 'second', second_image.docker_image_id)
+ store_tag_manifest(
+ repository.namespace_user.username,
+ repository.name,
+ "second",
+ second_image.docker_image_id,
+ )
- assert_not_deleted(repository, 'i1', 'i2')
+ assert_not_deleted(repository, "i1", "i2")
- # Delete the first tag.
- delete_tag(repository, 'first')
- assert_deleted(repository, 'i1')
- assert_not_deleted(repository, 'i2')
+ # Delete the first tag.
+ delete_tag(repository, "first")
+ assert_deleted(repository, "i1")
+ assert_not_deleted(repository, "i2")
- # Ensure the CAS path still exists.
- assert storage.exists({preferred}, storage.blob_path(digest))
+ # Ensure the CAS path still exists.
+ assert storage.exists({preferred}, storage.blob_path(digest))
def test_images_shared_cas_with_new_blob_table(default_tag_policy, initialized_db):
- """ A repository with a tag and image that shares its CAS path with a record in the new Blob
+ """ A repository with a tag and image that shares its CAS path with a record in the new Blob
table. Deleting the first tag should delete the first image, and its storage, but not the
file in storage, as it shares its CAS path with the blob row.
"""
- with assert_gc_integrity(expect_storage_removed=True):
- repository = create_repository()
+ with assert_gc_integrity(expect_storage_removed=True):
+ repository = create_repository()
- # Create two image storage records with the same content checksum.
- content = 'hello world'
- digest = 'sha256:' + hashlib.sha256(content).hexdigest()
- preferred = storage.preferred_locations[0]
- storage.put_content({preferred}, storage.blob_path(digest), content)
+ # Create two image storage records with the same content checksum.
+ content = "hello world"
+ digest = "sha256:" + hashlib.sha256(content).hexdigest()
+ preferred = storage.preferred_locations[0]
+ storage.put_content({preferred}, storage.blob_path(digest), content)
- media_type = database.MediaType.get(name='text/plain')
+ media_type = database.MediaType.get(name="text/plain")
- is1 = database.ImageStorage.create(content_checksum=digest, uploading=False)
- database.ApprBlob.create(digest=digest, size=0, media_type=media_type)
+ is1 = database.ImageStorage.create(content_checksum=digest, uploading=False)
+ database.ApprBlob.create(digest=digest, size=0, media_type=media_type)
- location = database.ImageStorageLocation.get(name=preferred)
- database.ImageStoragePlacement.create(location=location, storage=is1)
+ location = database.ImageStorageLocation.get(name=preferred)
+ database.ImageStoragePlacement.create(location=location, storage=is1)
- # Ensure the CAS path exists.
- assert storage.exists({preferred}, storage.blob_path(digest))
+ # Ensure the CAS path exists.
+ assert storage.exists({preferred}, storage.blob_path(digest))
- # Create the image in the repository, and the tag.
- first_image = Image.create(docker_image_id='i1',
- repository=repository, storage=is1,
- ancestors='/')
+ # Create the image in the repository, and the tag.
+ first_image = Image.create(
+ docker_image_id="i1", repository=repository, storage=is1, ancestors="/"
+ )
- store_tag_manifest(repository.namespace_user.username, repository.name,
- 'first', first_image.docker_image_id)
+ store_tag_manifest(
+ repository.namespace_user.username,
+ repository.name,
+ "first",
+ first_image.docker_image_id,
+ )
- assert_not_deleted(repository, 'i1')
+ assert_not_deleted(repository, "i1")
- # Delete the tag.
- delete_tag(repository, 'first')
- assert_deleted(repository, 'i1')
+ # Delete the tag.
+ delete_tag(repository, "first")
+ assert_deleted(repository, "i1")
- # Ensure the CAS path still exists, as it is referenced by the Blob table
- assert storage.exists({preferred}, storage.blob_path(digest))
+ # Ensure the CAS path still exists, as it is referenced by the Blob table
+ assert storage.exists({preferred}, storage.blob_path(digest))
def test_purge_repo(app):
- """ Test that app registers delete_metadata function on repository deletions """
- with assert_gc_integrity():
- with patch('app.tuf_metadata_api') as mock_tuf:
- model.gc.purge_repository("ns", "repo")
- assert mock_tuf.delete_metadata.called_with("ns", "repo")
+ """ Test that app registers delete_metadata function on repository deletions """
+ with assert_gc_integrity():
+ with patch("app.tuf_metadata_api") as mock_tuf:
+ model.gc.purge_repository("ns", "repo")
+ mock_tuf.delete_metadata.assert_called_with("ns", "repo")
def test_super_long_image_chain_gc(app, default_tag_policy):
- """ Test that a super long chain of images all gets properly GCed. """
- with assert_gc_integrity():
- images = ['i%s' % i for i in range(0, 100)]
- repository = create_repository(latest=images)
- delete_tag(repository, 'latest')
+ """ Test that a super long chain of images all gets properly GCed. """
+ with assert_gc_integrity():
+ images = ["i%s" % i for i in range(0, 100)]
+ repository = create_repository(latest=images)
+ delete_tag(repository, "latest")
- # Ensure the repository is now empty.
- assert_deleted(repository, *images)
+ # Ensure the repository is now empty.
+ assert_deleted(repository, *images)
def test_manifest_v2_shared_config_and_blobs(app, default_tag_policy):
- """ Test that GCing a tag that refers to a V2 manifest with the same config and some shared
+ """ Test that GCing a tag that refers to a V2 manifest with the same config and some shared
blobs as another manifest ensures that the config blob and shared blob are NOT GCed.
"""
- repo = model.repository.create_repository('devtable', 'newrepo', None)
- manifest1, built1 = create_manifest_for_testing(repo, differentiation_field='1',
- include_shared_blob=True)
- manifest2, built2 = create_manifest_for_testing(repo, differentiation_field='2',
- include_shared_blob=True)
+ repo = model.repository.create_repository("devtable", "newrepo", None)
+ manifest1, built1 = create_manifest_for_testing(
+ repo, differentiation_field="1", include_shared_blob=True
+ )
+ manifest2, built2 = create_manifest_for_testing(
+ repo, differentiation_field="2", include_shared_blob=True
+ )
- assert set(built1.local_blob_digests).intersection(built2.local_blob_digests)
- assert built1.config.digest == built2.config.digest
+ assert set(built1.local_blob_digests).intersection(built2.local_blob_digests)
+ assert built1.config.digest == built2.config.digest
- # Create tags pointing to the manifests.
- model.oci.tag.retarget_tag('tag1', manifest1)
- model.oci.tag.retarget_tag('tag2', manifest2)
+ # Create tags pointing to the manifests.
+ model.oci.tag.retarget_tag("tag1", manifest1)
+ model.oci.tag.retarget_tag("tag2", manifest2)
- with assert_gc_integrity(expect_storage_removed=True, check_oci_tags=False):
- # Delete tag2.
- model.oci.tag.delete_tag(repo, 'tag2')
- assert model.gc.garbage_collect_repo(repo)
+ with assert_gc_integrity(expect_storage_removed=True, check_oci_tags=False):
+ # Delete tag2.
+ model.oci.tag.delete_tag(repo, "tag2")
+ assert model.gc.garbage_collect_repo(repo)
- # Ensure the blobs for manifest1 still all exist.
- preferred = storage.preferred_locations[0]
- for blob_digest in built1.local_blob_digests:
- storage_row = ImageStorage.get(content_checksum=blob_digest)
+ # Ensure the blobs for manifest1 still all exist.
+ preferred = storage.preferred_locations[0]
+ for blob_digest in built1.local_blob_digests:
+ storage_row = ImageStorage.get(content_checksum=blob_digest)
- assert storage_row.cas_path
- storage.get_content({preferred}, storage.blob_path(storage_row.content_checksum))
+ assert storage_row.cas_path
+ storage.get_content(
+ {preferred}, storage.blob_path(storage_row.content_checksum)
+ )
diff --git a/data/model/test/test_image.py b/data/model/test/test_image.py
index 9442a23eb..e7c3f79e1 100644
--- a/data/model/test/test_image.py
+++ b/data/model/test/test_image.py
@@ -6,99 +6,110 @@ from playhouse.test_utils import assert_query_count
from test.fixtures import *
+
@pytest.fixture()
def images(initialized_db):
- images = image.get_repository_images('devtable', 'simple')
- assert len(images)
- return images
+ images = image.get_repository_images("devtable", "simple")
+ assert len(images)
+ return images
def test_get_image_with_storage(images, initialized_db):
- for current in images:
- storage_uuid = current.storage.uuid
+ for current in images:
+ storage_uuid = current.storage.uuid
- with assert_query_count(1):
- retrieved = image.get_image_with_storage(current.docker_image_id, storage_uuid)
- assert retrieved.id == current.id
- assert retrieved.storage.uuid == storage_uuid
+ with assert_query_count(1):
+ retrieved = image.get_image_with_storage(
+ current.docker_image_id, storage_uuid
+ )
+ assert retrieved.id == current.id
+ assert retrieved.storage.uuid == storage_uuid
def test_get_parent_images(images, initialized_db):
- for current in images:
- if not len(current.ancestor_id_list()):
- continue
+ for current in images:
+ if not len(current.ancestor_id_list()):
+ continue
- with assert_query_count(1):
- parent_images = list(image.get_parent_images('devtable', 'simple', current))
+ with assert_query_count(1):
+ parent_images = list(image.get_parent_images("devtable", "simple", current))
- assert len(parent_images) == len(current.ancestor_id_list())
- assert set(current.ancestor_id_list()) == {i.id for i in parent_images}
+ assert len(parent_images) == len(current.ancestor_id_list())
+ assert set(current.ancestor_id_list()) == {i.id for i in parent_images}
- for parent in parent_images:
- with assert_query_count(0):
- assert parent.storage.id
+ for parent in parent_images:
+ with assert_query_count(0):
+ assert parent.storage.id
def test_get_image(images, initialized_db):
- for current in images:
- repo = current.repository
+ for current in images:
+ repo = current.repository
- with assert_query_count(1):
- found = image.get_image(repo, current.docker_image_id)
+ with assert_query_count(1):
+ found = image.get_image(repo, current.docker_image_id)
- assert found.id == current.id
+ assert found.id == current.id
def test_placements(images, initialized_db):
- with assert_query_count(1):
- placements_map = image.get_placements_for_images(images)
+ with assert_query_count(1):
+ placements_map = image.get_placements_for_images(images)
- for current in images:
- assert current.storage.id in placements_map
+ for current in images:
+ assert current.storage.id in placements_map
- with assert_query_count(2):
- expected_image, expected_placements = image.get_image_and_placements('devtable', 'simple',
- current.docker_image_id)
+ with assert_query_count(2):
+ expected_image, expected_placements = image.get_image_and_placements(
+ "devtable", "simple", current.docker_image_id
+ )
- assert expected_image.id == current.id
- assert len(expected_placements) == len(placements_map.get(current.storage.id))
- assert ({p.id for p in expected_placements} ==
- {p.id for p in placements_map.get(current.storage.id)})
+ assert expected_image.id == current.id
+ assert len(expected_placements) == len(placements_map.get(current.storage.id))
+ assert {p.id for p in expected_placements} == {
+ p.id for p in placements_map.get(current.storage.id)
+ }
def test_get_repo_image(images, initialized_db):
- for current in images:
- with assert_query_count(1):
- found = image.get_repo_image('devtable', 'simple', current.docker_image_id)
+ for current in images:
+ with assert_query_count(1):
+ found = image.get_repo_image("devtable", "simple", current.docker_image_id)
- assert found.id == current.id
- with assert_query_count(1):
- assert found.storage.id
+ assert found.id == current.id
+ with assert_query_count(1):
+ assert found.storage.id
def test_get_repo_image_and_storage(images, initialized_db):
- for current in images:
- with assert_query_count(1):
- found = image.get_repo_image_and_storage('devtable', 'simple', current.docker_image_id)
+ for current in images:
+ with assert_query_count(1):
+ found = image.get_repo_image_and_storage(
+ "devtable", "simple", current.docker_image_id
+ )
- assert found.id == current.id
- with assert_query_count(0):
- assert found.storage.id
+ assert found.id == current.id
+ with assert_query_count(0):
+ assert found.storage.id
def test_get_repository_images_without_placements(images, initialized_db):
- ancestors_map = defaultdict(list)
- for img in images:
- current = img.parent
- while current is not None:
- ancestors_map[current.id].append(img.id)
- current = current.parent
+ ancestors_map = defaultdict(list)
+ for img in images:
+ current = img.parent
+ while current is not None:
+ ancestors_map[current.id].append(img.id)
+ current = current.parent
- for current in images:
- repo = current.repository
+ for current in images:
+ repo = current.repository
- with assert_query_count(1):
- found = list(image.get_repository_images_without_placements(repo, with_ancestor=current))
+ with assert_query_count(1):
+ found = list(
+ image.get_repository_images_without_placements(
+ repo, with_ancestor=current
+ )
+ )
- assert len(found) == len(ancestors_map[current.id]) + 1
- assert {i.id for i in found} == set(ancestors_map[current.id] + [current.id])
+ assert len(found) == len(ancestors_map[current.id]) + 1
+ assert {i.id for i in found} == set(ancestors_map[current.id] + [current.id])
diff --git a/data/model/test/test_image_sharing.py b/data/model/test/test_image_sharing.py
index 239500b10..2490fe7cc 100644
--- a/data/model/test/test_image_sharing.py
+++ b/data/model/test/test_image_sharing.py
@@ -6,210 +6,315 @@ from storage.distributedstorage import DistributedStorage
from storage.fakestorage import FakeStorage
from test.fixtures import *
-NO_ACCESS_USER = 'freshuser'
-READ_ACCESS_USER = 'reader'
-ADMIN_ACCESS_USER = 'devtable'
-PUBLIC_USER = 'public'
-RANDOM_USER = 'randomuser'
-OUTSIDE_ORG_USER = 'outsideorg'
+NO_ACCESS_USER = "freshuser"
+READ_ACCESS_USER = "reader"
+ADMIN_ACCESS_USER = "devtable"
+PUBLIC_USER = "public"
+RANDOM_USER = "randomuser"
+OUTSIDE_ORG_USER = "outsideorg"
-ADMIN_ROBOT_USER = 'devtable+dtrobot'
+ADMIN_ROBOT_USER = "devtable+dtrobot"
-ORGANIZATION = 'buynlarge'
+ORGANIZATION = "buynlarge"
-REPO = 'devtable/simple'
-PUBLIC_REPO = 'public/publicrepo'
-RANDOM_REPO = 'randomuser/randomrepo'
+REPO = "devtable/simple"
+PUBLIC_REPO = "public/publicrepo"
+RANDOM_REPO = "randomuser/randomrepo"
-OUTSIDE_ORG_REPO = 'outsideorg/coolrepo'
+OUTSIDE_ORG_REPO = "outsideorg/coolrepo"
-ORG_REPO = 'buynlarge/orgrepo'
-ANOTHER_ORG_REPO = 'buynlarge/anotherorgrepo'
+ORG_REPO = "buynlarge/orgrepo"
+ANOTHER_ORG_REPO = "buynlarge/anotherorgrepo"
# Note: The shared repo has devtable as admin, public as a writer and reader as a reader.
-SHARED_REPO = 'devtable/shared'
+SHARED_REPO = "devtable/shared"
@pytest.fixture()
def storage(app):
- return DistributedStorage({'local_us': FakeStorage(None)}, preferred_locations=['local_us'])
+ return DistributedStorage(
+ {"local_us": FakeStorage(None)}, preferred_locations=["local_us"]
+ )
-def createStorage(storage, docker_image_id, repository=REPO, username=ADMIN_ACCESS_USER):
- repository_obj = model.repository.get_repository(repository.split('/')[0],
- repository.split('/')[1])
- preferred = storage.preferred_locations[0]
- image = model.image.find_create_or_link_image(docker_image_id, repository_obj, username, {},
- preferred)
- image.storage.uploading = False
- image.storage.save()
- return image.storage
+def createStorage(
+ storage, docker_image_id, repository=REPO, username=ADMIN_ACCESS_USER
+):
+ repository_obj = model.repository.get_repository(
+ repository.split("/")[0], repository.split("/")[1]
+ )
+ preferred = storage.preferred_locations[0]
+ image = model.image.find_create_or_link_image(
+ docker_image_id, repository_obj, username, {}, preferred
+ )
+ image.storage.uploading = False
+ image.storage.save()
+ return image.storage
-def assertSameStorage(storage, docker_image_id, existing_storage, repository=REPO,
- username=ADMIN_ACCESS_USER):
- new_storage = createStorage(storage, docker_image_id, repository, username)
- assert existing_storage.id == new_storage.id
+def assertSameStorage(
+ storage,
+ docker_image_id,
+ existing_storage,
+ repository=REPO,
+ username=ADMIN_ACCESS_USER,
+):
+ new_storage = createStorage(storage, docker_image_id, repository, username)
+ assert existing_storage.id == new_storage.id
-def assertDifferentStorage(storage, docker_image_id, existing_storage, repository=REPO,
- username=ADMIN_ACCESS_USER):
- new_storage = createStorage(storage, docker_image_id, repository, username)
- assert existing_storage.id != new_storage.id
+def assertDifferentStorage(
+ storage,
+ docker_image_id,
+ existing_storage,
+ repository=REPO,
+ username=ADMIN_ACCESS_USER,
+):
+ new_storage = createStorage(storage, docker_image_id, repository, username)
+ assert existing_storage.id != new_storage.id
def test_same_user(storage, initialized_db):
- """ The same user creates two images, each which should be shared in the same repo. This is a
+ """ The same user creates two images, each which should be shared in the same repo. This is a
sanity check. """
- # Create a reference to a new docker ID => new image.
- first_storage_id = createStorage(storage, 'first-image')
+ # Create a reference to a new docker ID => new image.
+ first_storage_id = createStorage(storage, "first-image")
- # Create a reference to the same docker ID => same image.
- assertSameStorage(storage, 'first-image', first_storage_id)
+ # Create a reference to the same docker ID => same image.
+ assertSameStorage(storage, "first-image", first_storage_id)
- # Create a reference to another new docker ID => new image.
- second_storage_id = createStorage(storage, 'second-image')
+ # Create a reference to another new docker ID => new image.
+ second_storage_id = createStorage(storage, "second-image")
- # Create a reference to that same docker ID => same image.
- assertSameStorage(storage, 'second-image', second_storage_id)
+ # Create a reference to that same docker ID => same image.
+ assertSameStorage(storage, "second-image", second_storage_id)
- # Make sure the images are different.
- assert first_storage_id != second_storage_id
+ # Make sure the images are different.
+ assert first_storage_id != second_storage_id
def test_no_user_private_repo(storage, initialized_db):
- """ If no user is specified (token case usually), then no sharing can occur on a private repo. """
- # Create a reference to a new docker ID => new image.
- first_storage = createStorage(storage, 'the-image', username=None, repository=SHARED_REPO)
+ """ If no user is specified (token case usually), then no sharing can occur on a private repo. """
+ # Create a reference to a new docker ID => new image.
+ first_storage = createStorage(
+ storage, "the-image", username=None, repository=SHARED_REPO
+ )
- # Create a areference to the same docker ID, but since no username => new image.
- assertDifferentStorage(storage, 'the-image', first_storage, username=None, repository=RANDOM_REPO)
+ # Create a reference to the same docker ID, but since no username => new image.
+ assertDifferentStorage(
+ storage, "the-image", first_storage, username=None, repository=RANDOM_REPO
+ )
def test_no_user_public_repo(storage, initialized_db):
- """ If no user is specified (token case usually), then no sharing can occur on a private repo except when the image is first public. """
- # Create a reference to a new docker ID => new image.
- first_storage = createStorage(storage, 'the-image', username=None, repository=PUBLIC_REPO)
+ """ If no user is specified (token case usually), then no sharing can occur on a private repo except when the image is first public. """
+ # Create a reference to a new docker ID => new image.
+ first_storage = createStorage(
+ storage, "the-image", username=None, repository=PUBLIC_REPO
+ )
- # Create a areference to the same docker ID. Since no username, we'd expect different but the first image is public so => shaed image.
- assertSameStorage(storage, 'the-image', first_storage, username=None, repository=RANDOM_REPO)
+ # Create a reference to the same docker ID. Since no username, we'd expect a different image, but the first image is public, so => shared image.
+ assertSameStorage(
+ storage, "the-image", first_storage, username=None, repository=RANDOM_REPO
+ )
def test_different_user_same_repo(storage, initialized_db):
- """ Two different users create the same image in the same repo. """
+ """ Two different users create the same image in the same repo. """
- # Create a reference to a new docker ID under the first user => new image.
- first_storage = createStorage(storage, 'the-image', username=PUBLIC_USER, repository=SHARED_REPO)
+ # Create a reference to a new docker ID under the first user => new image.
+ first_storage = createStorage(
+ storage, "the-image", username=PUBLIC_USER, repository=SHARED_REPO
+ )
- # Create a reference to the *same* docker ID under the second user => same image.
- assertSameStorage(storage, 'the-image', first_storage, username=ADMIN_ACCESS_USER, repository=SHARED_REPO)
+ # Create a reference to the *same* docker ID under the second user => same image.
+ assertSameStorage(
+ storage,
+ "the-image",
+ first_storage,
+ username=ADMIN_ACCESS_USER,
+ repository=SHARED_REPO,
+ )
def test_different_repo_no_shared_access(storage, initialized_db):
- """ Neither user has access to the other user's repository. """
+ """ Neither user has access to the other user's repository. """
- # Create a reference to a new docker ID under the first user => new image.
- first_storage_id = createStorage(storage, 'the-image', username=RANDOM_USER, repository=RANDOM_REPO)
+ # Create a reference to a new docker ID under the first user => new image.
+ first_storage_id = createStorage(
+ storage, "the-image", username=RANDOM_USER, repository=RANDOM_REPO
+ )
- # Create a reference to the *same* docker ID under the second user => new image.
- second_storage_id = createStorage(storage, 'the-image', username=ADMIN_ACCESS_USER, repository=REPO)
+ # Create a reference to the *same* docker ID under the second user => new image.
+ second_storage_id = createStorage(
+ storage, "the-image", username=ADMIN_ACCESS_USER, repository=REPO
+ )
- # Verify that the users do not share storage.
- assert first_storage_id != second_storage_id
+ # Verify that the users do not share storage.
+ assert first_storage_id != second_storage_id
def test_public_than_private(storage, initialized_db):
- """ An image is created publicly then used privately, so it should be shared. """
+ """ An image is created publicly then used privately, so it should be shared. """
- # Create a reference to a new docker ID under the first user => new image.
- first_storage = createStorage(storage, 'the-image', username=PUBLIC_USER, repository=PUBLIC_REPO)
+ # Create a reference to a new docker ID under the first user => new image.
+ first_storage = createStorage(
+ storage, "the-image", username=PUBLIC_USER, repository=PUBLIC_REPO
+ )
- # Create a reference to the *same* docker ID under the second user => same image, since the first was public.
- assertSameStorage(storage, 'the-image', first_storage, username=ADMIN_ACCESS_USER, repository=REPO)
+ # Create a reference to the *same* docker ID under the second user => same image, since the first was public.
+ assertSameStorage(
+ storage, "the-image", first_storage, username=ADMIN_ACCESS_USER, repository=REPO
+ )
def test_private_than_public(storage, initialized_db):
- """ An image is created privately then used publicly, so it should *not* be shared. """
+ """ An image is created privately then used publicly, so it should *not* be shared. """
- # Create a reference to a new docker ID under the first user => new image.
- first_storage = createStorage(storage, 'the-image', username=ADMIN_ACCESS_USER, repository=REPO)
+ # Create a reference to a new docker ID under the first user => new image.
+ first_storage = createStorage(
+ storage, "the-image", username=ADMIN_ACCESS_USER, repository=REPO
+ )
- # Create a reference to the *same* docker ID under the second user => new image, since the first was private.
- assertDifferentStorage(storage, 'the-image', first_storage, username=PUBLIC_USER, repository=PUBLIC_REPO)
+ # Create a reference to the *same* docker ID under the second user => new image, since the first was private.
+ assertDifferentStorage(
+ storage,
+ "the-image",
+ first_storage,
+ username=PUBLIC_USER,
+ repository=PUBLIC_REPO,
+ )
def test_different_repo_with_access(storage, initialized_db):
- """ An image is created in one repo (SHARED_REPO) which the user (PUBLIC_USER) has access to. Later, the
+ """ An image is created in one repo (SHARED_REPO) which the user (PUBLIC_USER) has access to. Later, the
image is created in another repo (PUBLIC_REPO) that the user also has access to. The image should
be shared since the user has access.
"""
- # Create the image in the shared repo => new image.
- first_storage = createStorage(storage, 'the-image', username=ADMIN_ACCESS_USER, repository=SHARED_REPO)
+ # Create the image in the shared repo => new image.
+ first_storage = createStorage(
+ storage, "the-image", username=ADMIN_ACCESS_USER, repository=SHARED_REPO
+ )
- # Create the image in the other user's repo, but since the user (PUBLIC) still has access to the shared
- # repository, they should reuse the storage.
- assertSameStorage(storage, 'the-image', first_storage, username=PUBLIC_USER, repository=PUBLIC_REPO)
+ # Create the image in the other user's repo, but since the user (PUBLIC) still has access to the shared
+ # repository, they should reuse the storage.
+ assertSameStorage(
+ storage,
+ "the-image",
+ first_storage,
+ username=PUBLIC_USER,
+ repository=PUBLIC_REPO,
+ )
def test_org_access(storage, initialized_db):
- """ An image is accessible by being a member of the organization. """
+ """ An image is accessible by being a member of the organization. """
- # Create the new image under the org's repo => new image.
- first_storage = createStorage(storage, 'the-image', username=ADMIN_ACCESS_USER, repository=ORG_REPO)
+ # Create the new image under the org's repo => new image.
+ first_storage = createStorage(
+ storage, "the-image", username=ADMIN_ACCESS_USER, repository=ORG_REPO
+ )
- # Create an image under the user's repo, but since the user has access to the organization => shared image.
- assertSameStorage(storage, 'the-image', first_storage, username=ADMIN_ACCESS_USER, repository=REPO)
+ # Create an image under the user's repo, but since the user has access to the organization => shared image.
+ assertSameStorage(
+ storage, "the-image", first_storage, username=ADMIN_ACCESS_USER, repository=REPO
+ )
- # Ensure that the user's robot does not have access, since it is not on the permissions list for the repo.
- assertDifferentStorage(storage, 'the-image', first_storage, username=ADMIN_ROBOT_USER, repository=SHARED_REPO)
+ # Ensure that the user's robot does not have access, since it is not on the permissions list for the repo.
+ assertDifferentStorage(
+ storage,
+ "the-image",
+ first_storage,
+ username=ADMIN_ROBOT_USER,
+ repository=SHARED_REPO,
+ )
def test_org_access_different_user(storage, initialized_db):
- """ An image is accessible by being a member of the organization. """
+ """ An image is accessible by being a member of the organization. """
- # Create the new image under the org's repo => new image.
- first_storage = createStorage(storage, 'the-image', username=ADMIN_ACCESS_USER, repository=ORG_REPO)
+ # Create the new image under the org's repo => new image.
+ first_storage = createStorage(
+ storage, "the-image", username=ADMIN_ACCESS_USER, repository=ORG_REPO
+ )
- # Create an image under a user's repo, but since the user has access to the organization => shared image.
- assertSameStorage(storage, 'the-image', first_storage, username=PUBLIC_USER, repository=PUBLIC_REPO)
+ # Create an image under a user's repo, but since the user has access to the organization => shared image.
+ assertSameStorage(
+ storage,
+ "the-image",
+ first_storage,
+ username=PUBLIC_USER,
+ repository=PUBLIC_REPO,
+ )
- # Also verify for reader.
- assertSameStorage(storage, 'the-image', first_storage, username=READ_ACCESS_USER, repository=PUBLIC_REPO)
+ # Also verify for reader.
+ assertSameStorage(
+ storage,
+ "the-image",
+ first_storage,
+ username=READ_ACCESS_USER,
+ repository=PUBLIC_REPO,
+ )
def test_org_no_access(storage, initialized_db):
- """ An image is not accessible if not a member of the organization. """
+ """ An image is not accessible if not a member of the organization. """
- # Create the new image under the org's repo => new image.
- first_storage = createStorage(storage, 'the-image', username=ADMIN_ACCESS_USER, repository=ORG_REPO)
+ # Create the new image under the org's repo => new image.
+ first_storage = createStorage(
+ storage, "the-image", username=ADMIN_ACCESS_USER, repository=ORG_REPO
+ )
- # Create an image under a user's repo. Since the user is not a member of the organization => new image.
- assertDifferentStorage(storage, 'the-image', first_storage, username=RANDOM_USER, repository=RANDOM_REPO)
+ # Create an image under a user's repo. Since the user is not a member of the organization => new image.
+ assertDifferentStorage(
+ storage,
+ "the-image",
+ first_storage,
+ username=RANDOM_USER,
+ repository=RANDOM_REPO,
+ )
def test_org_not_team_member_with_access(storage, initialized_db):
- """ An image is accessible to a user specifically listed as having permission on the org repo. """
+ """ An image is accessible to a user specifically listed as having permission on the org repo. """
- # Create the new image under the org's repo => new image.
- first_storage = createStorage(storage, 'the-image', username=ADMIN_ACCESS_USER, repository=ORG_REPO)
+ # Create the new image under the org's repo => new image.
+ first_storage = createStorage(
+ storage, "the-image", username=ADMIN_ACCESS_USER, repository=ORG_REPO
+ )
- # Create an image under a user's repo. Since the user has read access on that repo, they can see the image => shared image.
- assertSameStorage(storage, 'the-image', first_storage, username=OUTSIDE_ORG_USER, repository=OUTSIDE_ORG_REPO)
+ # Create an image under a user's repo. Since the user has read access on that repo, they can see the image => shared image.
+ assertSameStorage(
+ storage,
+ "the-image",
+ first_storage,
+ username=OUTSIDE_ORG_USER,
+ repository=OUTSIDE_ORG_REPO,
+ )
def test_org_not_team_member_with_no_access(storage, initialized_db):
- """ A user that has access to one org repo but not another and is not a team member. """
+ """ A user that has access to one org repo but not another and is not a team member. """
- # Create the new image under the org's repo => new image.
- first_storage = createStorage(storage, 'the-image', username=ADMIN_ACCESS_USER, repository=ANOTHER_ORG_REPO)
+ # Create the new image under the org's repo => new image.
+ first_storage = createStorage(
+ storage, "the-image", username=ADMIN_ACCESS_USER, repository=ANOTHER_ORG_REPO
+ )
+
+ # Create an image under a user's repo. The user doesn't have access to the repo (ANOTHER_ORG_REPO) so => new image.
+ assertDifferentStorage(
+ storage,
+ "the-image",
+ first_storage,
+ username=OUTSIDE_ORG_USER,
+ repository=OUTSIDE_ORG_REPO,
+ )
- # Create an image under a user's repo. The user doesn't have access to the repo (ANOTHER_ORG_REPO) so => new image.
- assertDifferentStorage(storage, 'the-image', first_storage, username=OUTSIDE_ORG_USER, repository=OUTSIDE_ORG_REPO)
def test_no_link_to_uploading(storage, initialized_db):
- still_uploading = createStorage(storage, 'an-image', repository=PUBLIC_REPO)
- still_uploading.uploading = True
- still_uploading.save()
+ still_uploading = createStorage(storage, "an-image", repository=PUBLIC_REPO)
+ still_uploading.uploading = True
+ still_uploading.save()
- assertDifferentStorage(storage, 'an-image', still_uploading)
+ assertDifferentStorage(storage, "an-image", still_uploading)
diff --git a/data/model/test/test_log.py b/data/model/test/test_log.py
index 7ced0bb91..21c84e655 100644
--- a/data/model/test/test_log.py
+++ b/data/model/test/test_log.py
@@ -8,73 +8,101 @@ from mock import patch, Mock, DEFAULT, sentinel
from peewee import PeeweeException
-@pytest.fixture(scope='function')
+@pytest.fixture(scope="function")
def app_config():
- with patch.dict(_config.app_config, {}, clear=True):
- yield _config.app_config
+ with patch.dict(_config.app_config, {}, clear=True):
+ yield _config.app_config
+
@pytest.fixture()
def logentry_kind():
- kinds = {'pull_repo': 'pull_repo_kind', 'push_repo': 'push_repo_kind'}
- with patch('data.model.log.get_log_entry_kinds', return_value=kinds, spec=True):
- yield kinds
+ kinds = {"pull_repo": "pull_repo_kind", "push_repo": "push_repo_kind"}
+ with patch("data.model.log.get_log_entry_kinds", return_value=kinds, spec=True):
+ yield kinds
+
@pytest.fixture()
def logentry(logentry_kind):
- with patch('data.database.LogEntry3.create', spec=True):
- yield LogEntry3
+ with patch("data.database.LogEntry3.create", spec=True):
+ yield LogEntry3
+
@pytest.fixture()
def user():
- with patch.multiple('data.database.User', username=DEFAULT, get=DEFAULT, select=DEFAULT) as user:
- user['get'].return_value = Mock(id='mock_user_id')
- user['select'].return_value.tuples.return_value.get.return_value = ['default_user_id']
- yield User
+ with patch.multiple(
+ "data.database.User", username=DEFAULT, get=DEFAULT, select=DEFAULT
+ ) as user:
+ user["get"].return_value = Mock(id="mock_user_id")
+ user["select"].return_value.tuples.return_value.get.return_value = [
+ "default_user_id"
+ ]
+ yield User
-@pytest.mark.parametrize('action_kind', [('pull'), ('oops')])
+
+@pytest.mark.parametrize("action_kind", [("pull"), ("oops")])
def test_log_action_unknown_action(action_kind):
- ''' test unknown action types throw an exception when logged '''
- with pytest.raises(Exception):
- log_action(action_kind, None)
+ """ test unknown action types throw an exception when logged """
+ with pytest.raises(Exception):
+ log_action(action_kind, None)
-@pytest.mark.parametrize('user_or_org_name,account_id,account', [
- ('my_test_org', 'N/A', 'mock_user_id' ),
- (None, 'test_account_id', 'test_account_id'),
- (None, None, 'default_user_id')
-])
-@pytest.mark.parametrize('unlogged_pulls_ok,action_kind,db_exception,throws', [
- (False, 'pull_repo', None, False),
- (False, 'push_repo', None, False),
- (False, 'pull_repo', PeeweeException, True ),
- (False, 'push_repo', PeeweeException, True ),
+@pytest.mark.parametrize(
+ "user_or_org_name,account_id,account",
+ [
+ ("my_test_org", "N/A", "mock_user_id"),
+ (None, "test_account_id", "test_account_id"),
+ (None, None, "default_user_id"),
+ ],
+)
+@pytest.mark.parametrize(
+ "unlogged_pulls_ok,action_kind,db_exception,throws",
+ [
+ (False, "pull_repo", None, False),
+ (False, "push_repo", None, False),
+ (False, "pull_repo", PeeweeException, True),
+ (False, "push_repo", PeeweeException, True),
+ (True, "pull_repo", PeeweeException, False),
+ (True, "push_repo", PeeweeException, True),
+ (True, "pull_repo", Exception, True),
+ (True, "push_repo", Exception, True),
+ ],
+)
+def test_log_action(
+ user_or_org_name,
+ account_id,
+ account,
+ unlogged_pulls_ok,
+ action_kind,
+ db_exception,
+ throws,
+ app_config,
+ logentry,
+ user,
+):
+ log_args = {
+ "performer": Mock(id="TEST_PERFORMER_ID"),
+ "repository": Mock(id="TEST_REPO"),
+ "ip": "TEST_IP",
+ "metadata": {"test_key": "test_value"},
+ "timestamp": "TEST_TIMESTAMP",
+ }
+ app_config["SERVICE_LOG_ACCOUNT_ID"] = account_id
+ app_config["ALLOW_PULLS_WITHOUT_STRICT_LOGGING"] = unlogged_pulls_ok
- (True, 'pull_repo', PeeweeException, False),
- (True, 'push_repo', PeeweeException, True ),
- (True, 'pull_repo', Exception, True ),
- (True, 'push_repo', Exception, True )
-])
-def test_log_action(user_or_org_name, account_id, account, unlogged_pulls_ok, action_kind,
- db_exception, throws, app_config, logentry, user):
- log_args = {
- 'performer' : Mock(id='TEST_PERFORMER_ID'),
- 'repository' : Mock(id='TEST_REPO'),
- 'ip' : 'TEST_IP',
- 'metadata' : { 'test_key' : 'test_value' },
- 'timestamp' : 'TEST_TIMESTAMP'
- }
- app_config['SERVICE_LOG_ACCOUNT_ID'] = account_id
- app_config['ALLOW_PULLS_WITHOUT_STRICT_LOGGING'] = unlogged_pulls_ok
+ logentry.create.side_effect = db_exception
- logentry.create.side_effect = db_exception
+ if throws:
+ with pytest.raises(db_exception):
+ log_action(action_kind, user_or_org_name, **log_args)
+ else:
+ log_action(action_kind, user_or_org_name, **log_args)
- if throws:
- with pytest.raises(db_exception):
- log_action(action_kind, user_or_org_name, **log_args)
- else:
- log_action(action_kind, user_or_org_name, **log_args)
-
- logentry.create.assert_called_once_with(kind=action_kind+'_kind', account=account,
- performer='TEST_PERFORMER_ID', repository='TEST_REPO',
- ip='TEST_IP', metadata_json='{"test_key": "test_value"}',
- datetime='TEST_TIMESTAMP')
+ logentry.create.assert_called_once_with(
+ kind=action_kind + "_kind",
+ account=account,
+ performer="TEST_PERFORMER_ID",
+ repository="TEST_REPO",
+ ip="TEST_IP",
+ metadata_json='{"test_key": "test_value"}',
+ datetime="TEST_TIMESTAMP",
+ )
diff --git a/data/model/test/test_model_blob.py b/data/model/test/test_model_blob.py
index b6053b353..13ed37d51 100644
--- a/data/model/test/test_model_blob.py
+++ b/data/model/test/test_model_blob.py
@@ -3,49 +3,58 @@ from data import model, database
from test.fixtures import *
-ADMIN_ACCESS_USER = 'devtable'
-REPO = 'simple'
+ADMIN_ACCESS_USER = "devtable"
+REPO = "simple"
+
def test_store_blob(initialized_db):
- location = database.ImageStorageLocation.select().get()
+ location = database.ImageStorageLocation.select().get()
- # Create a new blob at a unique digest.
- digest = 'somecooldigest'
- blob_storage = model.blob.store_blob_record_and_temp_link(ADMIN_ACCESS_USER, REPO, digest,
- location, 1024, 0, 5000)
- assert blob_storage.content_checksum == digest
- assert blob_storage.image_size == 1024
- assert blob_storage.uncompressed_size == 5000
+ # Create a new blob at a unique digest.
+ digest = "somecooldigest"
+ blob_storage = model.blob.store_blob_record_and_temp_link(
+ ADMIN_ACCESS_USER, REPO, digest, location, 1024, 0, 5000
+ )
+ assert blob_storage.content_checksum == digest
+ assert blob_storage.image_size == 1024
+ assert blob_storage.uncompressed_size == 5000
- # Link to the same digest.
- blob_storage2 = model.blob.store_blob_record_and_temp_link(ADMIN_ACCESS_USER, REPO, digest,
- location, 2048, 0, 6000)
- assert blob_storage2.id == blob_storage.id
+ # Link to the same digest.
+ blob_storage2 = model.blob.store_blob_record_and_temp_link(
+ ADMIN_ACCESS_USER, REPO, digest, location, 2048, 0, 6000
+ )
+ assert blob_storage2.id == blob_storage.id
- # The sizes should be unchanged.
- assert blob_storage2.image_size == 1024
- assert blob_storage2.uncompressed_size == 5000
+ # The sizes should be unchanged.
+ assert blob_storage2.image_size == 1024
+ assert blob_storage2.uncompressed_size == 5000
- # Add a new digest, ensure it has a new record.
- otherdigest = 'anotherdigest'
- blob_storage3 = model.blob.store_blob_record_and_temp_link(ADMIN_ACCESS_USER, REPO, otherdigest,
- location, 1234, 0, 5678)
- assert blob_storage3.id != blob_storage.id
- assert blob_storage3.image_size == 1234
- assert blob_storage3.uncompressed_size == 5678
+ # Add a new digest, ensure it has a new record.
+ otherdigest = "anotherdigest"
+ blob_storage3 = model.blob.store_blob_record_and_temp_link(
+ ADMIN_ACCESS_USER, REPO, otherdigest, location, 1234, 0, 5678
+ )
+ assert blob_storage3.id != blob_storage.id
+ assert blob_storage3.image_size == 1234
+ assert blob_storage3.uncompressed_size == 5678
def test_get_or_create_shared_blob(initialized_db):
- shared = model.blob.get_or_create_shared_blob('sha256:abcdef', 'somecontent', storage)
- assert shared.content_checksum == 'sha256:abcdef'
+ shared = model.blob.get_or_create_shared_blob(
+ "sha256:abcdef", "somecontent", storage
+ )
+ assert shared.content_checksum == "sha256:abcdef"
- again = model.blob.get_or_create_shared_blob('sha256:abcdef', 'somecontent', storage)
- assert shared == again
+ again = model.blob.get_or_create_shared_blob(
+ "sha256:abcdef", "somecontent", storage
+ )
+ assert shared == again
def test_lookup_repo_storages_by_content_checksum(initialized_db):
- for image in database.Image.select():
- found = model.storage.lookup_repo_storages_by_content_checksum(image.repository,
- [image.storage.content_checksum])
- assert len(found) == 1
- assert found[0].content_checksum == image.storage.content_checksum
+ for image in database.Image.select():
+ found = model.storage.lookup_repo_storages_by_content_checksum(
+ image.repository, [image.storage.content_checksum]
+ )
+ assert len(found) == 1
+ assert found[0].content_checksum == image.storage.content_checksum
diff --git a/data/model/test/test_modelutil.py b/data/model/test/test_modelutil.py
index 5da72be4a..1f2bb1447 100644
--- a/data/model/test/test_modelutil.py
+++ b/data/model/test/test_modelutil.py
@@ -4,47 +4,38 @@ from data.database import Role
from data.model.modelutil import paginate
from test.fixtures import *
-@pytest.mark.parametrize('page_size', [
- 10,
- 20,
- 50,
- 100,
- 200,
- 500,
- 1000,
-])
-@pytest.mark.parametrize('descending', [
- False,
- True,
-])
+
+@pytest.mark.parametrize("page_size", [10, 20, 50, 100, 200, 500, 1000])
+@pytest.mark.parametrize("descending", [False, True])
def test_paginate(page_size, descending, initialized_db):
- # Add a bunch of rows into a test table (`Role`).
- for i in range(0, 522):
- Role.create(name='testrole%s' % i)
+ # Add a bunch of rows into a test table (`Role`).
+ for i in range(0, 522):
+ Role.create(name="testrole%s" % i)
- query = Role.select().where(Role.name ** 'testrole%')
- all_matching_roles = list(query)
- assert len(all_matching_roles) == 522
+ query = Role.select().where(Role.name ** "testrole%")
+ all_matching_roles = list(query)
+ assert len(all_matching_roles) == 522
- # Paginate a query to lookup roles.
- collected = []
- page_token = None
- while True:
- results, page_token = paginate(query, Role, limit=page_size, descending=descending,
- page_token=page_token)
- assert len(results) <= page_size
- collected.extend(results)
+    # Paginate a query to look up roles.
+ collected = []
+ page_token = None
+ while True:
+ results, page_token = paginate(
+ query, Role, limit=page_size, descending=descending, page_token=page_token
+ )
+ assert len(results) <= page_size
+ collected.extend(results)
- if page_token is None:
- break
+ if page_token is None:
+ break
- assert len(results) == page_size
+ assert len(results) == page_size
- for index, result in enumerate(results[1:]):
- if descending:
- assert result.id < results[index].id
- else:
- assert result.id > results[index].id
+ for index, result in enumerate(results[1:]):
+ if descending:
+ assert result.id < results[index].id
+ else:
+ assert result.id > results[index].id
- assert len(collected) == len(all_matching_roles)
- assert {c.id for c in collected} == {a.id for a in all_matching_roles}
+ assert len(collected) == len(all_matching_roles)
+ assert {c.id for c in collected} == {a.id for a in all_matching_roles}
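
The paging loop in test_paginate above is the intended consumption pattern for paginate: keep feeding the returned page_token back in until it comes back as None. A condensed sketch of that pattern against the same Role query; the wrapper name is illustrative.

from data.database import Role
from data.model.modelutil import paginate


def collect_all_roles(page_size=50):
    # Walk every page, passing the opaque page_token back in until the model
    # layer signals the last page by returning None.
    query = Role.select().where(Role.name ** "testrole%")
    collected, page_token = [], None
    while True:
        results, page_token = paginate(
            query, Role, limit=page_size, descending=False, page_token=page_token
        )
        collected.extend(results)
        if page_token is None:
            return collected
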
diff --git a/data/model/test/test_organization.py b/data/model/test/test_organization.py
index 153814765..3c0a11c75 100644
--- a/data/model/test/test_organization.py
+++ b/data/model/test/test_organization.py
@@ -5,18 +5,16 @@ from data.model.user import mark_namespace_for_deletion
from data.queue import WorkQueue
from test.fixtures import *
-@pytest.mark.parametrize('deleted', [
- (True),
- (False),
-])
+
+@pytest.mark.parametrize("deleted", [(True), (False)])
def test_get_organizations(deleted, initialized_db):
- # Delete an org.
- deleted_org = get_organization('sellnsmall')
- queue = WorkQueue('testgcnamespace', lambda db: db.transaction())
- mark_namespace_for_deletion(deleted_org, [], queue)
+ # Delete an org.
+ deleted_org = get_organization("sellnsmall")
+ queue = WorkQueue("testgcnamespace", lambda db: db.transaction())
+ mark_namespace_for_deletion(deleted_org, [], queue)
- orgs = get_organizations(deleted=deleted)
- assert orgs
+ orgs = get_organizations(deleted=deleted)
+ assert orgs
- deleted_found = [org for org in orgs if org.id == deleted_org.id]
- assert bool(deleted_found) == deleted
+ deleted_found = [org for org in orgs if org.id == deleted_org.id]
+ assert bool(deleted_found) == deleted
diff --git a/data/model/test/test_repo_mirroring.py b/data/model/test/test_repo_mirroring.py
index 6a3f808e3..a05696a3a 100644
--- a/data/model/test/test_repo_mirroring.py
+++ b/data/model/test/test_repo_mirroring.py
@@ -4,232 +4,265 @@ from jsonschema import ValidationError
from data.database import RepoMirrorConfig, RepoMirrorStatus, User
from data import model
-from data.model.repo_mirror import (create_mirroring_rule, get_eligible_mirrors, update_sync_status_to_cancel,
- MAX_SYNC_RETRIES, release_mirror)
+from data.model.repo_mirror import (
+ create_mirroring_rule,
+ get_eligible_mirrors,
+ update_sync_status_to_cancel,
+ MAX_SYNC_RETRIES,
+ release_mirror,
+)
from test.fixtures import *
def create_mirror_repo_robot(rules, repo_name="repo"):
- try:
- user = User.get(User.username == "mirror")
- except User.DoesNotExist:
- user = create_user_noverify("mirror", "mirror@example.com", email_required=False)
+ try:
+ user = User.get(User.username == "mirror")
+ except User.DoesNotExist:
+ user = create_user_noverify(
+ "mirror", "mirror@example.com", email_required=False
+ )
- try:
- robot = lookup_robot("mirror+robot")
- except model.InvalidRobotException:
- robot, _ = create_robot("robot", user)
+ try:
+ robot = lookup_robot("mirror+robot")
+ except model.InvalidRobotException:
+ robot, _ = create_robot("robot", user)
- repo = create_repository("mirror", repo_name, None, repo_kind="image", visibility="public")
- repo.save()
+ repo = create_repository(
+ "mirror", repo_name, None, repo_kind="image", visibility="public"
+ )
+ repo.save()
- rule = model.repo_mirror.create_mirroring_rule(repo, rules)
+ rule = model.repo_mirror.create_mirroring_rule(repo, rules)
- mirror_kwargs = {
- "repository": repo,
- "root_rule": rule,
- "internal_robot": robot,
- "external_reference": "registry.example.com/namespace/repository",
- "sync_interval": timedelta(days=1).total_seconds()
- }
- mirror = enable_mirroring_for_repository(**mirror_kwargs)
- mirror.sync_status = RepoMirrorStatus.NEVER_RUN
- mirror.sync_start_date = datetime.utcnow() - timedelta(days=1)
- mirror.sync_retries_remaining = 3
- mirror.save()
+ mirror_kwargs = {
+ "repository": repo,
+ "root_rule": rule,
+ "internal_robot": robot,
+ "external_reference": "registry.example.com/namespace/repository",
+ "sync_interval": timedelta(days=1).total_seconds(),
+ }
+ mirror = enable_mirroring_for_repository(**mirror_kwargs)
+ mirror.sync_status = RepoMirrorStatus.NEVER_RUN
+ mirror.sync_start_date = datetime.utcnow() - timedelta(days=1)
+ mirror.sync_retries_remaining = 3
+ mirror.save()
- return (mirror, repo)
+ return (mirror, repo)
def disable_existing_mirrors():
- mirrors = RepoMirrorConfig.select().execute()
- for mirror in mirrors:
- mirror.is_enabled = False
- mirror.save()
+ mirrors = RepoMirrorConfig.select().execute()
+ for mirror in mirrors:
+ mirror.is_enabled = False
+ mirror.save()
def test_eligible_oldest_first(initialized_db):
- """
+ """
Eligible mirror candidates should be returned with the oldest (earliest created) first.
"""
- disable_existing_mirrors()
- mirror_first, repo_first = create_mirror_repo_robot(["updated", "created"], repo_name="first")
- mirror_second, repo_second = create_mirror_repo_robot(["updated", "created"], repo_name="second")
- mirror_third, repo_third = create_mirror_repo_robot(["updated", "created"], repo_name="third")
+ disable_existing_mirrors()
+ mirror_first, repo_first = create_mirror_repo_robot(
+ ["updated", "created"], repo_name="first"
+ )
+ mirror_second, repo_second = create_mirror_repo_robot(
+ ["updated", "created"], repo_name="second"
+ )
+ mirror_third, repo_third = create_mirror_repo_robot(
+ ["updated", "created"], repo_name="third"
+ )
- candidates = get_eligible_mirrors()
+ candidates = get_eligible_mirrors()
- assert len(candidates) == 3
- assert candidates[0] == mirror_first
- assert candidates[1] == mirror_second
- assert candidates[2] == mirror_third
+ assert len(candidates) == 3
+ assert candidates[0] == mirror_first
+ assert candidates[1] == mirror_second
+ assert candidates[2] == mirror_third
def test_eligible_includes_expired_syncing(initialized_db):
- """
+ """
Mirrors that have an end time in the past are eligible even if their state indicates still syncing.
"""
- disable_existing_mirrors()
- mirror_first, repo_first = create_mirror_repo_robot(["updated", "created"], repo_name="first")
- mirror_second, repo_second = create_mirror_repo_robot(["updated", "created"], repo_name="second")
- mirror_third, repo_third = create_mirror_repo_robot(["updated", "created"], repo_name="third")
- mirror_fourth, repo_third = create_mirror_repo_robot(["updated", "created"], repo_name="fourth")
+ disable_existing_mirrors()
+ mirror_first, repo_first = create_mirror_repo_robot(
+ ["updated", "created"], repo_name="first"
+ )
+ mirror_second, repo_second = create_mirror_repo_robot(
+ ["updated", "created"], repo_name="second"
+ )
+ mirror_third, repo_third = create_mirror_repo_robot(
+ ["updated", "created"], repo_name="third"
+ )
+ mirror_fourth, repo_third = create_mirror_repo_robot(
+ ["updated", "created"], repo_name="fourth"
+ )
- mirror_second.sync_expiration_date = datetime.utcnow() - timedelta(hours=1)
- mirror_second.sync_status = RepoMirrorStatus.SYNCING
- mirror_second.save()
+ mirror_second.sync_expiration_date = datetime.utcnow() - timedelta(hours=1)
+ mirror_second.sync_status = RepoMirrorStatus.SYNCING
+ mirror_second.save()
- mirror_fourth.sync_expiration_date = datetime.utcnow() + timedelta(hours=1)
- mirror_fourth.sync_status = RepoMirrorStatus.SYNCING
- mirror_fourth.save()
+ mirror_fourth.sync_expiration_date = datetime.utcnow() + timedelta(hours=1)
+ mirror_fourth.sync_status = RepoMirrorStatus.SYNCING
+ mirror_fourth.save()
- candidates = get_eligible_mirrors()
+ candidates = get_eligible_mirrors()
- assert len(candidates) == 3
- assert candidates[0] == mirror_first
- assert candidates[1] == mirror_second
- assert candidates[2] == mirror_third
+ assert len(candidates) == 3
+ assert candidates[0] == mirror_first
+ assert candidates[1] == mirror_second
+ assert candidates[2] == mirror_third
def test_eligible_includes_immediate(initialized_db):
- """
+ """
    Mirrors marked SYNC_NOW are eligible regardless of their scheduled start time.
"""
- disable_existing_mirrors()
- mirror_first, repo_first = create_mirror_repo_robot(["updated", "created"], repo_name="first")
- mirror_second, repo_second = create_mirror_repo_robot(["updated", "created"], repo_name="second")
- mirror_third, repo_third = create_mirror_repo_robot(["updated", "created"], repo_name="third")
- mirror_fourth, repo_third = create_mirror_repo_robot(["updated", "created"], repo_name="fourth")
- mirror_future, _ = create_mirror_repo_robot(["updated", "created"], repo_name="future")
- mirror_past, _ = create_mirror_repo_robot(["updated", "created"], repo_name="past")
+ disable_existing_mirrors()
+ mirror_first, repo_first = create_mirror_repo_robot(
+ ["updated", "created"], repo_name="first"
+ )
+ mirror_second, repo_second = create_mirror_repo_robot(
+ ["updated", "created"], repo_name="second"
+ )
+ mirror_third, repo_third = create_mirror_repo_robot(
+ ["updated", "created"], repo_name="third"
+ )
+ mirror_fourth, repo_third = create_mirror_repo_robot(
+ ["updated", "created"], repo_name="fourth"
+ )
+ mirror_future, _ = create_mirror_repo_robot(
+ ["updated", "created"], repo_name="future"
+ )
+ mirror_past, _ = create_mirror_repo_robot(["updated", "created"], repo_name="past")
- mirror_future.sync_start_date = datetime.utcnow() + timedelta(hours=6)
- mirror_future.sync_status = RepoMirrorStatus.SYNC_NOW
- mirror_future.save()
+ mirror_future.sync_start_date = datetime.utcnow() + timedelta(hours=6)
+ mirror_future.sync_status = RepoMirrorStatus.SYNC_NOW
+ mirror_future.save()
- mirror_past.sync_start_date = datetime.utcnow() - timedelta(hours=6)
- mirror_past.sync_status = RepoMirrorStatus.SYNC_NOW
- mirror_past.save()
+ mirror_past.sync_start_date = datetime.utcnow() - timedelta(hours=6)
+ mirror_past.sync_status = RepoMirrorStatus.SYNC_NOW
+ mirror_past.save()
- mirror_fourth.sync_expiration_date = datetime.utcnow() + timedelta(hours=1)
- mirror_fourth.sync_status = RepoMirrorStatus.SYNCING
- mirror_fourth.save()
+ mirror_fourth.sync_expiration_date = datetime.utcnow() + timedelta(hours=1)
+ mirror_fourth.sync_status = RepoMirrorStatus.SYNCING
+ mirror_fourth.save()
- candidates = get_eligible_mirrors()
+ candidates = get_eligible_mirrors()
- assert len(candidates) == 5
- assert candidates[0] == mirror_first
- assert candidates[1] == mirror_second
- assert candidates[2] == mirror_third
- assert candidates[3] == mirror_past
- assert candidates[4] == mirror_future
+ assert len(candidates) == 5
+ assert candidates[0] == mirror_first
+ assert candidates[1] == mirror_second
+ assert candidates[2] == mirror_third
+ assert candidates[3] == mirror_past
+ assert candidates[4] == mirror_future
def test_create_rule_validations(initialized_db):
- mirror, repo = create_mirror_repo_robot(["updated", "created"], repo_name="first")
+ mirror, repo = create_mirror_repo_robot(["updated", "created"], repo_name="first")
- with pytest.raises(ValidationError):
- create_mirroring_rule(repo, None)
+ with pytest.raises(ValidationError):
+ create_mirroring_rule(repo, None)
- with pytest.raises(ValidationError):
- create_mirroring_rule(repo, "['tag1', 'tag2']")
+ with pytest.raises(ValidationError):
+ create_mirroring_rule(repo, "['tag1', 'tag2']")
- with pytest.raises(ValidationError):
- create_mirroring_rule(repo, ['tag1', 'tag2'], rule_type=None)
+ with pytest.raises(ValidationError):
+ create_mirroring_rule(repo, ["tag1", "tag2"], rule_type=None)
def test_long_registry_passwords(initialized_db):
- """
+ """
Verify that long passwords, such as Base64 JWT used by Redhat's Registry, work as expected.
"""
- MAX_PASSWORD_LENGTH = 1024
+ MAX_PASSWORD_LENGTH = 1024
- username = ''.join('a' for _ in range(MAX_PASSWORD_LENGTH))
- password = ''.join('b' for _ in range(MAX_PASSWORD_LENGTH))
- assert len(username) == MAX_PASSWORD_LENGTH
- assert len(password) == MAX_PASSWORD_LENGTH
+ username = "".join("a" for _ in range(MAX_PASSWORD_LENGTH))
+ password = "".join("b" for _ in range(MAX_PASSWORD_LENGTH))
+ assert len(username) == MAX_PASSWORD_LENGTH
+ assert len(password) == MAX_PASSWORD_LENGTH
- repo = model.repository.get_repository('devtable', 'mirrored')
- assert repo
+ repo = model.repository.get_repository("devtable", "mirrored")
+ assert repo
- existing_mirror_conf = model.repo_mirror.get_mirror(repo)
- assert existing_mirror_conf
+ existing_mirror_conf = model.repo_mirror.get_mirror(repo)
+ assert existing_mirror_conf
- assert model.repo_mirror.change_credentials(repo, username, password)
+ assert model.repo_mirror.change_credentials(repo, username, password)
- updated_mirror_conf = model.repo_mirror.get_mirror(repo)
- assert updated_mirror_conf
+ updated_mirror_conf = model.repo_mirror.get_mirror(repo)
+ assert updated_mirror_conf
- assert updated_mirror_conf.external_registry_username.decrypt() == username
- assert updated_mirror_conf.external_registry_password.decrypt() == password
+ assert updated_mirror_conf.external_registry_username.decrypt() == username
+ assert updated_mirror_conf.external_registry_password.decrypt() == password
def test_sync_status_to_cancel(initialized_db):
- """
+ """
SYNCING and SYNC_NOW mirrors may be canceled, ending in NEVER_RUN
"""
- disable_existing_mirrors()
- mirror, repo = create_mirror_repo_robot(["updated", "created"], repo_name="cancel")
+ disable_existing_mirrors()
+ mirror, repo = create_mirror_repo_robot(["updated", "created"], repo_name="cancel")
- mirror.sync_status = RepoMirrorStatus.SYNCING
- mirror.save()
- updated = update_sync_status_to_cancel(mirror)
- assert updated is not None
- assert updated.sync_status == RepoMirrorStatus.NEVER_RUN
+ mirror.sync_status = RepoMirrorStatus.SYNCING
+ mirror.save()
+ updated = update_sync_status_to_cancel(mirror)
+ assert updated is not None
+ assert updated.sync_status == RepoMirrorStatus.NEVER_RUN
- mirror.sync_status = RepoMirrorStatus.SYNC_NOW
- mirror.save()
- updated = update_sync_status_to_cancel(mirror)
- assert updated is not None
- assert updated.sync_status == RepoMirrorStatus.NEVER_RUN
+ mirror.sync_status = RepoMirrorStatus.SYNC_NOW
+ mirror.save()
+ updated = update_sync_status_to_cancel(mirror)
+ assert updated is not None
+ assert updated.sync_status == RepoMirrorStatus.NEVER_RUN
- mirror.sync_status = RepoMirrorStatus.FAIL
- mirror.save()
- updated = update_sync_status_to_cancel(mirror)
- assert updated is None
+ mirror.sync_status = RepoMirrorStatus.FAIL
+ mirror.save()
+ updated = update_sync_status_to_cancel(mirror)
+ assert updated is None
- mirror.sync_status = RepoMirrorStatus.NEVER_RUN
- mirror.save()
- updated = update_sync_status_to_cancel(mirror)
- assert updated is None
+ mirror.sync_status = RepoMirrorStatus.NEVER_RUN
+ mirror.save()
+ updated = update_sync_status_to_cancel(mirror)
+ assert updated is None
- mirror.sync_status = RepoMirrorStatus.SUCCESS
- mirror.save()
- updated = update_sync_status_to_cancel(mirror)
- assert updated is None
+ mirror.sync_status = RepoMirrorStatus.SUCCESS
+ mirror.save()
+ updated = update_sync_status_to_cancel(mirror)
+ assert updated is None
def test_release_mirror(initialized_db):
- """
+ """
    Releasing a mirror after a failed sync decrements its remaining retries; once the last retry is consumed, the retry count resets and the next sync is rescheduled to a later start date.
"""
- disable_existing_mirrors()
- mirror, repo = create_mirror_repo_robot(["updated", "created"], repo_name="first")
+ disable_existing_mirrors()
+ mirror, repo = create_mirror_repo_robot(["updated", "created"], repo_name="first")
- # mysql rounds the milliseconds on update so force that to happen now
- query = (RepoMirrorConfig
- .update(sync_start_date=mirror.sync_start_date)
- .where(RepoMirrorConfig.id == mirror.id))
- query.execute()
- mirror = RepoMirrorConfig.get_by_id(mirror.id)
- original_sync_start_date = mirror.sync_start_date
+    # MySQL rounds the milliseconds on update, so force that to happen now.
+ query = RepoMirrorConfig.update(sync_start_date=mirror.sync_start_date).where(
+ RepoMirrorConfig.id == mirror.id
+ )
+ query.execute()
+ mirror = RepoMirrorConfig.get_by_id(mirror.id)
+ original_sync_start_date = mirror.sync_start_date
- assert mirror.sync_retries_remaining == 3
+ assert mirror.sync_retries_remaining == 3
- mirror = release_mirror(mirror, RepoMirrorStatus.FAIL)
- assert mirror.sync_retries_remaining == 2
- assert mirror.sync_start_date == original_sync_start_date
+ mirror = release_mirror(mirror, RepoMirrorStatus.FAIL)
+ assert mirror.sync_retries_remaining == 2
+ assert mirror.sync_start_date == original_sync_start_date
- mirror = release_mirror(mirror, RepoMirrorStatus.FAIL)
- assert mirror.sync_retries_remaining == 1
- assert mirror.sync_start_date == original_sync_start_date
+ mirror = release_mirror(mirror, RepoMirrorStatus.FAIL)
+ assert mirror.sync_retries_remaining == 1
+ assert mirror.sync_start_date == original_sync_start_date
- mirror = release_mirror(mirror, RepoMirrorStatus.FAIL)
- assert mirror.sync_retries_remaining == 3
- assert mirror.sync_start_date > original_sync_start_date
+ mirror = release_mirror(mirror, RepoMirrorStatus.FAIL)
+ assert mirror.sync_retries_remaining == 3
+ assert mirror.sync_start_date > original_sync_start_date
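
A brief sketch of the cancellation contract established by test_sync_status_to_cancel above: update_sync_status_to_cancel only acts on mirrors that are SYNCING or queued as SYNC_NOW and returns None for every other state. Here mirror stands for any RepoMirrorConfig row, such as those produced by the create_mirror_repo_robot helper; the wrapper name is illustrative.

from data.database import RepoMirrorStatus
from data.model.repo_mirror import update_sync_status_to_cancel


def cancel_if_running(mirror):
    # Only SYNCING and SYNC_NOW mirrors are reset to NEVER_RUN; for FAIL,
    # SUCCESS, or NEVER_RUN the call is a no-op and returns None.
    updated = update_sync_status_to_cancel(mirror)
    if updated is not None:
        assert updated.sync_status == RepoMirrorStatus.NEVER_RUN
    return updated
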
diff --git a/data/model/test/test_repository.py b/data/model/test/test_repository.py
index 25e8b7cf2..4ef87f287 100644
--- a/data/model/test/test_repository.py
+++ b/data/model/test/test_repository.py
@@ -11,39 +11,46 @@ from test.fixtures import *
def test_duplicate_repository_different_kinds(initialized_db):
- # Create an image repo.
- create_repository('devtable', 'somenewrepo', None, repo_kind='image')
+ # Create an image repo.
+ create_repository("devtable", "somenewrepo", None, repo_kind="image")
- # Try to create an app repo with the same name, which should fail.
- with pytest.raises(IntegrityError):
- create_repository('devtable', 'somenewrepo', None, repo_kind='application')
+ # Try to create an app repo with the same name, which should fail.
+ with pytest.raises(IntegrityError):
+ create_repository("devtable", "somenewrepo", None, repo_kind="application")
def test_is_empty(initialized_db):
- create_repository('devtable', 'somenewrepo', None, repo_kind='image')
+ create_repository("devtable", "somenewrepo", None, repo_kind="image")
- assert is_empty('devtable', 'somenewrepo')
- assert not is_empty('devtable', 'simple')
+ assert is_empty("devtable", "somenewrepo")
+ assert not is_empty("devtable", "simple")
-@pytest.mark.skipif(os.environ.get('TEST_DATABASE_URI', '').find('mysql') >= 0,
- reason='MySQL requires specialized indexing of newly created repos')
-@pytest.mark.parametrize('query', [
- (''),
- ('e'),
-])
-@pytest.mark.parametrize('authed_username', [
- (None),
- ('devtable'),
-])
+
+@pytest.mark.skipif(
+ os.environ.get("TEST_DATABASE_URI", "").find("mysql") >= 0,
+ reason="MySQL requires specialized indexing of newly created repos",
+)
+@pytest.mark.parametrize("query", [(""), ("e")])
+@pytest.mark.parametrize("authed_username", [(None), ("devtable")])
def test_search_pagination(query, authed_username, initialized_db):
- # Create some public repos.
- repo1 = create_repository('devtable', 'somenewrepo', None, repo_kind='image', visibility='public')
- repo2 = create_repository('devtable', 'somenewrepo2', None, repo_kind='image', visibility='public')
- repo3 = create_repository('devtable', 'somenewrepo3', None, repo_kind='image', visibility='public')
+ # Create some public repos.
+ repo1 = create_repository(
+ "devtable", "somenewrepo", None, repo_kind="image", visibility="public"
+ )
+ repo2 = create_repository(
+ "devtable", "somenewrepo2", None, repo_kind="image", visibility="public"
+ )
+ repo3 = create_repository(
+ "devtable", "somenewrepo3", None, repo_kind="image", visibility="public"
+ )
- repositories = get_filtered_matching_repositories(query, filter_username=authed_username)
- assert len(repositories) > 3
+ repositories = get_filtered_matching_repositories(
+ query, filter_username=authed_username
+ )
+ assert len(repositories) > 3
- next_repos = get_filtered_matching_repositories(query, filter_username=authed_username, offset=1)
- assert repositories[0].id != next_repos[0].id
- assert repositories[1].id == next_repos[0].id
+ next_repos = get_filtered_matching_repositories(
+ query, filter_username=authed_username, offset=1
+ )
+ assert repositories[0].id != next_repos[0].id
+ assert repositories[1].id == next_repos[0].id
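
test_search_pagination above relies on offset-based paging of search results. A minimal sketch of that usage pattern; the import path, the page and page_size parameters, and the wrapper name are assumptions for illustration, while the keyword arguments mirror the calls in the test.

# Import path assumed; the import block for this test file sits outside the hunk above.
from data.model.repository import get_filtered_matching_repositories


def search_page(lookup, username, page, page_size=25):
    # Consecutive pages are produced by shifting the offset; the test asserts
    # that offset=1 yields a window shifted by exactly one repository.
    return get_filtered_matching_repositories(
        lookup, filter_username=username, offset=page * page_size
    )
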
diff --git a/data/model/test/test_repositoryactioncount.py b/data/model/test/test_repositoryactioncount.py
index bdad4e315..7433109d4 100644
--- a/data/model/test/test_repositoryactioncount.py
+++ b/data/model/test/test_repositoryactioncount.py
@@ -7,32 +7,37 @@ from data.model.repository import create_repository
from data.model.repositoryactioncount import update_repository_score, SEARCH_BUCKETS
from test.fixtures import *
-@pytest.mark.parametrize('bucket_sums,expected_score', [
- ((0, 0, 0, 0), 0),
- ((1, 6, 24, 152), 100),
- ((2, 6, 24, 152), 101),
- ((1, 6, 24, 304), 171),
-
- ((100, 480, 24, 152), 703),
- ((1, 6, 24, 15200), 7131),
-
- ((300, 500, 1000, 0), 1733),
- ((5000, 0, 0, 0), 5434),
-])
+@pytest.mark.parametrize(
+ "bucket_sums,expected_score",
+ [
+ ((0, 0, 0, 0), 0),
+ ((1, 6, 24, 152), 100),
+ ((2, 6, 24, 152), 101),
+ ((1, 6, 24, 304), 171),
+ ((100, 480, 24, 152), 703),
+ ((1, 6, 24, 15200), 7131),
+ ((300, 500, 1000, 0), 1733),
+ ((5000, 0, 0, 0), 5434),
+ ],
+)
def test_update_repository_score(bucket_sums, expected_score, initialized_db):
- # Create a new repository.
- repo = create_repository('devtable', 'somenewrepo', None, repo_kind='image')
+ # Create a new repository.
+ repo = create_repository("devtable", "somenewrepo", None, repo_kind="image")
- # Delete the RAC created in create_repository.
- RepositoryActionCount.delete().where(RepositoryActionCount.repository == repo).execute()
+ # Delete the RAC created in create_repository.
+ RepositoryActionCount.delete().where(
+ RepositoryActionCount.repository == repo
+ ).execute()
- # Add RAC rows for each of the buckets.
- for index, bucket in enumerate(SEARCH_BUCKETS):
- for day in range(0, bucket.days):
- RepositoryActionCount.create(repository=repo,
- count=(bucket_sums[index] / bucket.days * 1.0),
- date=date.today() - bucket.delta + timedelta(days=day))
+ # Add RAC rows for each of the buckets.
+ for index, bucket in enumerate(SEARCH_BUCKETS):
+ for day in range(0, bucket.days):
+ RepositoryActionCount.create(
+ repository=repo,
+ count=(bucket_sums[index] / bucket.days * 1.0),
+ date=date.today() - bucket.delta + timedelta(days=day),
+ )
- assert update_repository_score(repo)
- assert RepositorySearchScore.get(repository=repo).score == expected_score
+ assert update_repository_score(repo)
+ assert RepositorySearchScore.get(repository=repo).score == expected_score
diff --git a/data/model/test/test_tag.py b/data/model/test/test_tag.py
index 2f5adf773..75f53023e 100644
--- a/data/model/test/test_tag.py
+++ b/data/model/test/test_tag.py
@@ -8,14 +8,34 @@ import pytest
from mock import patch
from app import docker_v2_signing_key
-from data.database import (Image, RepositoryTag, ImageStorage, Repository, Manifest, ManifestBlob,
- ManifestLegacyImage, TagManifestToManifest, Tag, TagToRepositoryTag)
+from data.database import (
+ Image,
+ RepositoryTag,
+ ImageStorage,
+ Repository,
+ Manifest,
+ ManifestBlob,
+ ManifestLegacyImage,
+ TagManifestToManifest,
+ Tag,
+ TagToRepositoryTag,
+)
from data.model.repository import create_repository
-from data.model.tag import (list_active_repo_tags, create_or_update_tag, delete_tag,
- get_matching_tags, _tag_alive, get_matching_tags_for_images,
- change_tag_expiration, get_active_tag, store_tag_manifest_for_testing,
- get_most_recent_tag, get_active_tag_for_repo,
- create_or_update_tag_for_repo, set_tag_end_ts)
+from data.model.tag import (
+ list_active_repo_tags,
+ create_or_update_tag,
+ delete_tag,
+ get_matching_tags,
+ _tag_alive,
+ get_matching_tags_for_images,
+ change_tag_expiration,
+ get_active_tag,
+ store_tag_manifest_for_testing,
+ get_most_recent_tag,
+ get_active_tag_for_repo,
+ create_or_update_tag_for_repo,
+ set_tag_end_ts,
+)
from data.model.image import find_create_or_link_image
from image.docker.schema1 import DockerSchema1ManifestBuilder
from util.timedeltastring import convert_to_timedelta
@@ -24,333 +44,369 @@ from test.fixtures import *
def _get_expected_tags(image):
- expected_query = (RepositoryTag
- .select()
- .join(Image)
- .where(RepositoryTag.hidden == False)
- .where((Image.id == image.id) | (Image.ancestors ** ('%%/%s/%%' % image.id))))
- return set([tag.id for tag in _tag_alive(expected_query)])
+ expected_query = (
+ RepositoryTag.select()
+ .join(Image)
+ .where(RepositoryTag.hidden == False)
+ .where((Image.id == image.id) | (Image.ancestors ** ("%%/%s/%%" % image.id)))
+ )
+ return set([tag.id for tag in _tag_alive(expected_query)])
-@pytest.mark.parametrize('max_subqueries,max_image_lookup_count', [
- (1, 1),
- (10, 10),
- (100, 500),
-])
+@pytest.mark.parametrize(
+ "max_subqueries,max_image_lookup_count", [(1, 1), (10, 10), (100, 500)]
+)
def test_get_matching_tags(max_subqueries, max_image_lookup_count, initialized_db):
- with patch('data.model.tag._MAX_SUB_QUERIES', max_subqueries):
- with patch('data.model.tag._MAX_IMAGE_LOOKUP_COUNT', max_image_lookup_count):
- # Test for every image in the test database.
- for image in Image.select(Image, ImageStorage).join(ImageStorage):
- matching_query = get_matching_tags(image.docker_image_id, image.storage.uuid)
- matching_tags = set([tag.id for tag in matching_query])
- expected_tags = _get_expected_tags(image)
- assert matching_tags == expected_tags, "mismatch for image %s" % image.id
+ with patch("data.model.tag._MAX_SUB_QUERIES", max_subqueries):
+ with patch("data.model.tag._MAX_IMAGE_LOOKUP_COUNT", max_image_lookup_count):
+ # Test for every image in the test database.
+ for image in Image.select(Image, ImageStorage).join(ImageStorage):
+ matching_query = get_matching_tags(
+ image.docker_image_id, image.storage.uuid
+ )
+ matching_tags = set([tag.id for tag in matching_query])
+ expected_tags = _get_expected_tags(image)
+ assert matching_tags == expected_tags, (
+ "mismatch for image %s" % image.id
+ )
- oci_tags = list(Tag
- .select()
- .join(TagToRepositoryTag)
- .where(TagToRepositoryTag.repository_tag << expected_tags))
- assert len(oci_tags) == len(expected_tags)
+ oci_tags = list(
+ Tag.select()
+ .join(TagToRepositoryTag)
+ .where(TagToRepositoryTag.repository_tag << expected_tags)
+ )
+ assert len(oci_tags) == len(expected_tags)
-@pytest.mark.parametrize('max_subqueries,max_image_lookup_count', [
- (1, 1),
- (10, 10),
- (100, 500),
-])
-def test_get_matching_tag_ids_for_images(max_subqueries, max_image_lookup_count, initialized_db):
- with patch('data.model.tag._MAX_SUB_QUERIES', max_subqueries):
- with patch('data.model.tag._MAX_IMAGE_LOOKUP_COUNT', max_image_lookup_count):
- # Try for various sets of the first N images.
- for count in [5, 10, 15]:
- pairs = []
- expected_tags_ids = set()
- for image in Image.select(Image, ImageStorage).join(ImageStorage):
- if len(pairs) >= count:
- break
+@pytest.mark.parametrize(
+ "max_subqueries,max_image_lookup_count", [(1, 1), (10, 10), (100, 500)]
+)
+def test_get_matching_tag_ids_for_images(
+ max_subqueries, max_image_lookup_count, initialized_db
+):
+ with patch("data.model.tag._MAX_SUB_QUERIES", max_subqueries):
+ with patch("data.model.tag._MAX_IMAGE_LOOKUP_COUNT", max_image_lookup_count):
+ # Try for various sets of the first N images.
+ for count in [5, 10, 15]:
+ pairs = []
+ expected_tags_ids = set()
+ for image in Image.select(Image, ImageStorage).join(ImageStorage):
+ if len(pairs) >= count:
+ break
- pairs.append((image.docker_image_id, image.storage.uuid))
- expected_tags_ids.update(_get_expected_tags(image))
+ pairs.append((image.docker_image_id, image.storage.uuid))
+ expected_tags_ids.update(_get_expected_tags(image))
- matching_tags_ids = set([tag.id for tag in get_matching_tags_for_images(pairs)])
- assert matching_tags_ids == expected_tags_ids
+ matching_tags_ids = set(
+ [tag.id for tag in get_matching_tags_for_images(pairs)]
+ )
+ assert matching_tags_ids == expected_tags_ids
-@pytest.mark.parametrize('max_subqueries,max_image_lookup_count', [
- (1, 1),
- (10, 10),
- (100, 500),
-])
-def test_get_matching_tag_ids_for_all_images(max_subqueries, max_image_lookup_count, initialized_db):
- with patch('data.model.tag._MAX_SUB_QUERIES', max_subqueries):
- with patch('data.model.tag._MAX_IMAGE_LOOKUP_COUNT', max_image_lookup_count):
- pairs = []
- for image in Image.select(Image, ImageStorage).join(ImageStorage):
- pairs.append((image.docker_image_id, image.storage.uuid))
+@pytest.mark.parametrize(
+ "max_subqueries,max_image_lookup_count", [(1, 1), (10, 10), (100, 500)]
+)
+def test_get_matching_tag_ids_for_all_images(
+ max_subqueries, max_image_lookup_count, initialized_db
+):
+ with patch("data.model.tag._MAX_SUB_QUERIES", max_subqueries):
+ with patch("data.model.tag._MAX_IMAGE_LOOKUP_COUNT", max_image_lookup_count):
+ pairs = []
+ for image in Image.select(Image, ImageStorage).join(ImageStorage):
+ pairs.append((image.docker_image_id, image.storage.uuid))
- expected_tags_ids = set([tag.id for tag in _tag_alive(RepositoryTag.select())])
- matching_tags_ids = set([tag.id for tag in get_matching_tags_for_images(pairs)])
+ expected_tags_ids = set(
+ [tag.id for tag in _tag_alive(RepositoryTag.select())]
+ )
+ matching_tags_ids = set(
+ [tag.id for tag in get_matching_tags_for_images(pairs)]
+ )
- # Ensure every alive tag was found.
- assert matching_tags_ids == expected_tags_ids
+ # Ensure every alive tag was found.
+ assert matching_tags_ids == expected_tags_ids
def test_get_matching_tag_ids_images_filtered(initialized_db):
- def filter_query(query):
- return query.join(Repository).where(Repository.name == 'simple')
+ def filter_query(query):
+ return query.join(Repository).where(Repository.name == "simple")
- filtered_images = filter_query(Image
- .select(Image, ImageStorage)
- .join(RepositoryTag)
- .switch(Image)
- .join(ImageStorage)
- .switch(Image))
+ filtered_images = filter_query(
+ Image.select(Image, ImageStorage)
+ .join(RepositoryTag)
+ .switch(Image)
+ .join(ImageStorage)
+ .switch(Image)
+ )
- expected_tags_query = _tag_alive(filter_query(RepositoryTag
- .select()))
+ expected_tags_query = _tag_alive(filter_query(RepositoryTag.select()))
- pairs = []
- for image in filtered_images:
- pairs.append((image.docker_image_id, image.storage.uuid))
+ pairs = []
+ for image in filtered_images:
+ pairs.append((image.docker_image_id, image.storage.uuid))
- matching_tags = get_matching_tags_for_images(pairs, filter_images=filter_query,
- filter_tags=filter_query)
+ matching_tags = get_matching_tags_for_images(
+ pairs, filter_images=filter_query, filter_tags=filter_query
+ )
- expected_tag_ids = set([tag.id for tag in expected_tags_query])
- matching_tags_ids = set([tag.id for tag in matching_tags])
+ expected_tag_ids = set([tag.id for tag in expected_tags_query])
+ matching_tags_ids = set([tag.id for tag in matching_tags])
- # Ensure every alive tag was found.
- assert matching_tags_ids == expected_tag_ids
+ # Ensure every alive tag was found.
+ assert matching_tags_ids == expected_tag_ids
def _get_oci_tag(tag):
- return (Tag
- .select()
- .join(TagToRepositoryTag)
- .where(TagToRepositoryTag.repository_tag == tag)).get()
+ return (
+ Tag.select()
+ .join(TagToRepositoryTag)
+ .where(TagToRepositoryTag.repository_tag == tag)
+ ).get()
def assert_tags(repository, *args):
- tags = list(list_active_repo_tags(repository))
- assert len(tags) == len(args)
+ tags = list(list_active_repo_tags(repository))
+ assert len(tags) == len(args)
- tags_dict = {}
- for tag in tags:
- assert not tag.name in tags_dict
- assert not tag.hidden
- assert not tag.lifetime_end_ts or tag.lifetime_end_ts > time()
+ tags_dict = {}
+ for tag in tags:
+ assert not tag.name in tags_dict
+ assert not tag.hidden
+ assert not tag.lifetime_end_ts or tag.lifetime_end_ts > time()
- tags_dict[tag.name] = tag
+ tags_dict[tag.name] = tag
- oci_tag = _get_oci_tag(tag)
- assert oci_tag.name == tag.name
- assert not oci_tag.hidden
- assert oci_tag.reversion == tag.reversion
+ oci_tag = _get_oci_tag(tag)
+ assert oci_tag.name == tag.name
+ assert not oci_tag.hidden
+ assert oci_tag.reversion == tag.reversion
- if tag.lifetime_end_ts:
- assert oci_tag.lifetime_end_ms == (tag.lifetime_end_ts * 1000)
- else:
- assert oci_tag.lifetime_end_ms is None
+ if tag.lifetime_end_ts:
+ assert oci_tag.lifetime_end_ms == (tag.lifetime_end_ts * 1000)
+ else:
+ assert oci_tag.lifetime_end_ms is None
- for expected in args:
- assert expected in tags_dict
+ for expected in args:
+ assert expected in tags_dict
def test_create_reversion_tag(initialized_db):
- repository = create_repository('devtable', 'somenewrepo', None)
- manifest = Manifest.get()
- image1 = find_create_or_link_image('foobarimage1', repository, None, {}, 'local_us')
+ repository = create_repository("devtable", "somenewrepo", None)
+ manifest = Manifest.get()
+ image1 = find_create_or_link_image("foobarimage1", repository, None, {}, "local_us")
- footag = create_or_update_tag_for_repo(repository, 'foo', image1.docker_image_id,
- oci_manifest=manifest, reversion=True)
- assert footag.reversion
+ footag = create_or_update_tag_for_repo(
+ repository, "foo", image1.docker_image_id, oci_manifest=manifest, reversion=True
+ )
+ assert footag.reversion
- oci_tag = _get_oci_tag(footag)
- assert oci_tag.name == footag.name
- assert not oci_tag.hidden
- assert oci_tag.reversion == footag.reversion
+ oci_tag = _get_oci_tag(footag)
+ assert oci_tag.name == footag.name
+ assert not oci_tag.hidden
+ assert oci_tag.reversion == footag.reversion
def test_list_active_tags(initialized_db):
- # Create a new repository.
- repository = create_repository('devtable', 'somenewrepo', None)
- manifest = Manifest.get()
+ # Create a new repository.
+ repository = create_repository("devtable", "somenewrepo", None)
+ manifest = Manifest.get()
- # Create some images.
- image1 = find_create_or_link_image('foobarimage1', repository, None, {}, 'local_us')
- image2 = find_create_or_link_image('foobarimage2', repository, None, {}, 'local_us')
+ # Create some images.
+ image1 = find_create_or_link_image("foobarimage1", repository, None, {}, "local_us")
+ image2 = find_create_or_link_image("foobarimage2", repository, None, {}, "local_us")
- # Make sure its tags list is empty.
- assert_tags(repository)
+ # Make sure its tags list is empty.
+ assert_tags(repository)
- # Add some new tags.
- footag = create_or_update_tag_for_repo(repository, 'foo', image1.docker_image_id,
- oci_manifest=manifest)
- bartag = create_or_update_tag_for_repo(repository, 'bar', image1.docker_image_id,
- oci_manifest=manifest)
+ # Add some new tags.
+ footag = create_or_update_tag_for_repo(
+ repository, "foo", image1.docker_image_id, oci_manifest=manifest
+ )
+ bartag = create_or_update_tag_for_repo(
+ repository, "bar", image1.docker_image_id, oci_manifest=manifest
+ )
- # Since timestamps are stored on a second-granularity, we need to make the tags "start"
- # before "now", so when we recreate them below, they don't conflict.
- footag.lifetime_start_ts -= 5
- footag.save()
+    # Since timestamps are stored with second granularity, we need to make the tags "start"
+ # before "now", so when we recreate them below, they don't conflict.
+ footag.lifetime_start_ts -= 5
+ footag.save()
- bartag.lifetime_start_ts -= 5
- bartag.save()
+ bartag.lifetime_start_ts -= 5
+ bartag.save()
- footag_oci = _get_oci_tag(footag)
- footag_oci.lifetime_start_ms -= 5000
- footag_oci.save()
+ footag_oci = _get_oci_tag(footag)
+ footag_oci.lifetime_start_ms -= 5000
+ footag_oci.save()
- bartag_oci = _get_oci_tag(bartag)
- bartag_oci.lifetime_start_ms -= 5000
- bartag_oci.save()
+ bartag_oci = _get_oci_tag(bartag)
+ bartag_oci.lifetime_start_ms -= 5000
+ bartag_oci.save()
- # Make sure they are returned.
- assert_tags(repository, 'foo', 'bar')
+ # Make sure they are returned.
+ assert_tags(repository, "foo", "bar")
- # Set the expirations to be explicitly empty.
- set_tag_end_ts(footag, None)
- set_tag_end_ts(bartag, None)
+ # Set the expirations to be explicitly empty.
+ set_tag_end_ts(footag, None)
+ set_tag_end_ts(bartag, None)
- # Make sure they are returned.
- assert_tags(repository, 'foo', 'bar')
+ # Make sure they are returned.
+ assert_tags(repository, "foo", "bar")
- # Mark as a tag as expiring in the far future, and make sure it is still returned.
- set_tag_end_ts(footag, footag.lifetime_start_ts + 10000000)
+    # Mark a tag as expiring in the far future, and make sure it is still returned.
+ set_tag_end_ts(footag, footag.lifetime_start_ts + 10000000)
- # Make sure they are returned.
- assert_tags(repository, 'foo', 'bar')
+ # Make sure they are returned.
+ assert_tags(repository, "foo", "bar")
- # Delete a tag and make sure it isn't returned.
- footag = delete_tag('devtable', 'somenewrepo', 'foo')
- set_tag_end_ts(footag, footag.lifetime_end_ts - 4)
+ # Delete a tag and make sure it isn't returned.
+ footag = delete_tag("devtable", "somenewrepo", "foo")
+ set_tag_end_ts(footag, footag.lifetime_end_ts - 4)
- assert_tags(repository, 'bar')
+ assert_tags(repository, "bar")
- # Add a new foo again.
- footag = create_or_update_tag_for_repo(repository, 'foo', image1.docker_image_id,
- oci_manifest=manifest)
- footag.lifetime_start_ts -= 3
- footag.save()
+ # Add a new foo again.
+ footag = create_or_update_tag_for_repo(
+ repository, "foo", image1.docker_image_id, oci_manifest=manifest
+ )
+ footag.lifetime_start_ts -= 3
+ footag.save()
- footag_oci = _get_oci_tag(footag)
- footag_oci.lifetime_start_ms -= 3000
- footag_oci.save()
+ footag_oci = _get_oci_tag(footag)
+ footag_oci.lifetime_start_ms -= 3000
+ footag_oci.save()
- assert_tags(repository, 'foo', 'bar')
+ assert_tags(repository, "foo", "bar")
- # Mark as a tag as expiring in the far future, and make sure it is still returned.
- set_tag_end_ts(footag, footag.lifetime_start_ts + 10000000)
+    # Mark a tag as expiring in the far future, and make sure it is still returned.
+ set_tag_end_ts(footag, footag.lifetime_start_ts + 10000000)
- # Make sure they are returned.
- assert_tags(repository, 'foo', 'bar')
+ # Make sure they are returned.
+ assert_tags(repository, "foo", "bar")
- # "Move" foo by updating it and make sure we don't get duplicates.
- create_or_update_tag_for_repo(repository, 'foo', image2.docker_image_id, oci_manifest=manifest)
- assert_tags(repository, 'foo', 'bar')
+ # "Move" foo by updating it and make sure we don't get duplicates.
+ create_or_update_tag_for_repo(
+ repository, "foo", image2.docker_image_id, oci_manifest=manifest
+ )
+ assert_tags(repository, "foo", "bar")
-@pytest.mark.parametrize('expiration_offset, expected_offset', [
- (None, None),
- ('0s', '1h'),
- ('30m', '1h'),
- ('2h', '2h'),
- ('2w', '2w'),
- ('200w', '104w'),
-])
+@pytest.mark.parametrize(
+ "expiration_offset, expected_offset",
+ [
+ (None, None),
+ ("0s", "1h"),
+ ("30m", "1h"),
+ ("2h", "2h"),
+ ("2w", "2w"),
+ ("200w", "104w"),
+ ],
+)
def test_change_tag_expiration(expiration_offset, expected_offset, initialized_db):
- repository = create_repository('devtable', 'somenewrepo', None)
- image1 = find_create_or_link_image('foobarimage1', repository, None, {}, 'local_us')
+ repository = create_repository("devtable", "somenewrepo", None)
+ image1 = find_create_or_link_image("foobarimage1", repository, None, {}, "local_us")
- manifest = Manifest.get()
- footag = create_or_update_tag_for_repo(repository, 'foo', image1.docker_image_id,
- oci_manifest=manifest)
+ manifest = Manifest.get()
+ footag = create_or_update_tag_for_repo(
+ repository, "foo", image1.docker_image_id, oci_manifest=manifest
+ )
- expiration_date = None
- if expiration_offset is not None:
- expiration_date = datetime.utcnow() + convert_to_timedelta(expiration_offset)
+ expiration_date = None
+ if expiration_offset is not None:
+ expiration_date = datetime.utcnow() + convert_to_timedelta(expiration_offset)
- assert change_tag_expiration(footag, expiration_date)
+ assert change_tag_expiration(footag, expiration_date)
- # Lookup the tag again.
- footag_updated = get_active_tag('devtable', 'somenewrepo', 'foo')
- oci_tag = _get_oci_tag(footag_updated)
+    # Look up the tag again.
+ footag_updated = get_active_tag("devtable", "somenewrepo", "foo")
+ oci_tag = _get_oci_tag(footag_updated)
- if expected_offset is None:
- assert footag_updated.lifetime_end_ts is None
- assert oci_tag.lifetime_end_ms is None
- else:
- start_date = datetime.utcfromtimestamp(footag_updated.lifetime_start_ts)
- end_date = datetime.utcfromtimestamp(footag_updated.lifetime_end_ts)
- expected_end_date = start_date + convert_to_timedelta(expected_offset)
- assert (expected_end_date - end_date).total_seconds() < 5 # variance in test
+ if expected_offset is None:
+ assert footag_updated.lifetime_end_ts is None
+ assert oci_tag.lifetime_end_ms is None
+ else:
+ start_date = datetime.utcfromtimestamp(footag_updated.lifetime_start_ts)
+ end_date = datetime.utcfromtimestamp(footag_updated.lifetime_end_ts)
+ expected_end_date = start_date + convert_to_timedelta(expected_offset)
+ assert (expected_end_date - end_date).total_seconds() < 5 # variance in test
- assert oci_tag.lifetime_end_ms == (footag_updated.lifetime_end_ts * 1000)
+ assert oci_tag.lifetime_end_ms == (footag_updated.lifetime_end_ts * 1000)
def random_storages():
- return list(ImageStorage.select().where(~(ImageStorage.content_checksum >> None)).limit(10))
+ return list(
+ ImageStorage.select().where(~(ImageStorage.content_checksum >> None)).limit(10)
+ )
def repeated_storages():
- storages = list(ImageStorage.select().where(~(ImageStorage.content_checksum >> None)).limit(5))
- return storages + storages
+ storages = list(
+ ImageStorage.select().where(~(ImageStorage.content_checksum >> None)).limit(5)
+ )
+ return storages + storages
-@pytest.mark.parametrize('get_storages', [
- random_storages,
- repeated_storages,
-])
+@pytest.mark.parametrize("get_storages", [random_storages, repeated_storages])
def test_store_tag_manifest(get_storages, initialized_db):
- # Create a manifest with some layers.
- builder = DockerSchema1ManifestBuilder('devtable', 'simple', 'sometag')
+ # Create a manifest with some layers.
+ builder = DockerSchema1ManifestBuilder("devtable", "simple", "sometag")
- storages = get_storages()
- assert storages
+ storages = get_storages()
+ assert storages
- repo = model.repository.get_repository('devtable', 'simple')
- storage_id_map = {}
- for index, storage in enumerate(storages):
- image_id = 'someimage%s' % index
- builder.add_layer(storage.content_checksum, json.dumps({'id': image_id}))
- find_create_or_link_image(image_id, repo, 'devtable', {}, 'local_us')
- storage_id_map[storage.content_checksum] = storage.id
+ repo = model.repository.get_repository("devtable", "simple")
+ storage_id_map = {}
+ for index, storage in enumerate(storages):
+ image_id = "someimage%s" % index
+ builder.add_layer(storage.content_checksum, json.dumps({"id": image_id}))
+ find_create_or_link_image(image_id, repo, "devtable", {}, "local_us")
+ storage_id_map[storage.content_checksum] = storage.id
- manifest = builder.build(docker_v2_signing_key)
- tag_manifest, _ = store_tag_manifest_for_testing('devtable', 'simple', 'sometag', manifest,
- manifest.leaf_layer_v1_image_id, storage_id_map)
+ manifest = builder.build(docker_v2_signing_key)
+ tag_manifest, _ = store_tag_manifest_for_testing(
+ "devtable",
+ "simple",
+ "sometag",
+ manifest,
+ manifest.leaf_layer_v1_image_id,
+ storage_id_map,
+ )
- # Ensure we have the new-model expected rows.
- mapping_row = TagManifestToManifest.get(tag_manifest=tag_manifest)
+ # Ensure we have the new-model expected rows.
+ mapping_row = TagManifestToManifest.get(tag_manifest=tag_manifest)
- assert mapping_row.manifest is not None
- assert mapping_row.manifest.manifest_bytes == manifest.bytes.as_encoded_str()
- assert mapping_row.manifest.digest == str(manifest.digest)
+ assert mapping_row.manifest is not None
+ assert mapping_row.manifest.manifest_bytes == manifest.bytes.as_encoded_str()
+ assert mapping_row.manifest.digest == str(manifest.digest)
- blob_rows = {m.blob_id for m in
- ManifestBlob.select().where(ManifestBlob.manifest == mapping_row.manifest)}
- assert blob_rows == {s.id for s in storages}
+ blob_rows = {
+ m.blob_id
+ for m in ManifestBlob.select().where(
+ ManifestBlob.manifest == mapping_row.manifest
+ )
+ }
+ assert blob_rows == {s.id for s in storages}
- assert ManifestLegacyImage.get(manifest=mapping_row.manifest).image == tag_manifest.tag.image
+ assert (
+ ManifestLegacyImage.get(manifest=mapping_row.manifest).image
+ == tag_manifest.tag.image
+ )
def test_get_most_recent_tag(initialized_db):
- # Create a hidden tag that is the most recent.
- repo = model.repository.get_repository('devtable', 'simple')
- image = model.tag.get_tag_image('devtable', 'simple', 'latest')
- model.tag.create_temporary_hidden_tag(repo, image, 10000000)
+ # Create a hidden tag that is the most recent.
+ repo = model.repository.get_repository("devtable", "simple")
+ image = model.tag.get_tag_image("devtable", "simple", "latest")
+ model.tag.create_temporary_hidden_tag(repo, image, 10000000)
- # Ensure we find a non-hidden tag.
- found = model.tag.get_most_recent_tag(repo)
- assert not found.hidden
+ # Ensure we find a non-hidden tag.
+ found = model.tag.get_most_recent_tag(repo)
+ assert not found.hidden
def test_get_active_tag_for_repo(initialized_db):
- repo = model.repository.get_repository('devtable', 'simple')
- image = model.tag.get_tag_image('devtable', 'simple', 'latest')
- hidden_tag = model.tag.create_temporary_hidden_tag(repo, image, 10000000)
+ repo = model.repository.get_repository("devtable", "simple")
+ image = model.tag.get_tag_image("devtable", "simple", "latest")
+ hidden_tag = model.tag.create_temporary_hidden_tag(repo, image, 10000000)
- # Ensure get active tag for repo cannot find it.
- assert model.tag.get_active_tag_for_repo(repo, hidden_tag) is None
- assert model.tag.get_active_tag_for_repo(repo, 'latest') is not None
+    # Ensure get_active_tag_for_repo cannot find the hidden tag.
+ assert model.tag.get_active_tag_for_repo(repo, hidden_tag) is None
+ assert model.tag.get_active_tag_for_repo(repo, "latest") is not None
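
The parametrization of test_change_tag_expiration above encodes the clamping rule for tag expirations: requested offsets end up between roughly one hour and 104 weeks, and None clears the expiration entirely. A compact sketch of applying an expiration to an active tag; the namespace, repository, and tag arguments are placeholders for values like those used in the tests.

from datetime import datetime, timedelta

from data.model.tag import change_tag_expiration, get_active_tag


def expire_tag_in(namespace, repo_name, tag_name, offset=timedelta(weeks=2)):
    # The model layer clamps the requested offset (per the test expectations,
    # to at least 1h and at most 104w); passing None clears the expiration.
    tag = get_active_tag(namespace, repo_name, tag_name)
    change_tag_expiration(tag, datetime.utcnow() + offset)
    return get_active_tag(namespace, repo_name, tag_name).lifetime_end_ts
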
diff --git a/data/model/test/test_team.py b/data/model/test/test_team.py
index 88b08855c..ea2a62f25 100644
--- a/data/model/test/test_team.py
+++ b/data/model/test/test_team.py
@@ -1,61 +1,69 @@
import pytest
-from data.model.team import (add_or_invite_to_team, create_team, confirm_team_invite,
- list_team_users, validate_team_name)
+from data.model.team import (
+ add_or_invite_to_team,
+ create_team,
+ confirm_team_invite,
+ list_team_users,
+ validate_team_name,
+)
from data.model.organization import create_organization
from data.model.user import get_user, create_user_noverify
from test.fixtures import *
-@pytest.mark.parametrize('name, is_valid', [
- ('', False),
- ('f', False),
- ('fo', True),
- ('f' * 255, True),
- ('f' * 256, False),
- (' ', False),
- ('helloworld', True),
- ('hello_world', True),
- ('hello-world', True),
- ('hello world', False),
- ('HelloWorld', False),
-])
+@pytest.mark.parametrize(
+ "name, is_valid",
+ [
+ ("", False),
+ ("f", False),
+ ("fo", True),
+ ("f" * 255, True),
+ ("f" * 256, False),
+ (" ", False),
+ ("helloworld", True),
+ ("hello_world", True),
+ ("hello-world", True),
+ ("hello world", False),
+ ("HelloWorld", False),
+ ],
+)
def test_validate_team_name(name, is_valid):
- result, _ = validate_team_name(name)
- assert result == is_valid
+ result, _ = validate_team_name(name)
+ assert result == is_valid
def is_in_team(team, user):
- return user.username in {u.username for u in list_team_users(team)}
+ return user.username in {u.username for u in list_team_users(team)}
def test_invite_to_team(initialized_db):
- first_user = get_user('devtable')
- second_user = create_user_noverify('newuser', 'foo@example.com')
+ first_user = get_user("devtable")
+ second_user = create_user_noverify("newuser", "foo@example.com")
- def run_invite_flow(orgname):
- # Create an org owned by `devtable`.
- org = create_organization(orgname, orgname + '@example.com', first_user)
+ def run_invite_flow(orgname):
+ # Create an org owned by `devtable`.
+ org = create_organization(orgname, orgname + "@example.com", first_user)
- # Create another team and add `devtable` to it. Since `devtable` is already
- # in the org, it should be done directly.
- other_team = create_team('otherteam', org, 'admin')
- invite = add_or_invite_to_team(first_user, other_team, user_obj=first_user)
- assert invite is None
- assert is_in_team(other_team, first_user)
+ # Create another team and add `devtable` to it. Since `devtable` is already
+ # in the org, it should be done directly.
+ other_team = create_team("otherteam", org, "admin")
+ invite = add_or_invite_to_team(first_user, other_team, user_obj=first_user)
+ assert invite is None
+ assert is_in_team(other_team, first_user)
- # Try to add `newuser` to the team, which should require an invite.
- invite = add_or_invite_to_team(first_user, other_team, user_obj=second_user)
- assert invite is not None
- assert not is_in_team(other_team, second_user)
+ # Try to add `newuser` to the team, which should require an invite.
+ invite = add_or_invite_to_team(first_user, other_team, user_obj=second_user)
+ assert invite is not None
+ assert not is_in_team(other_team, second_user)
- # Accept the invite.
- confirm_team_invite(invite.invite_token, second_user)
- assert is_in_team(other_team, second_user)
+ # Accept the invite.
+ confirm_team_invite(invite.invite_token, second_user)
+ assert is_in_team(other_team, second_user)
- # Run for a new org.
- run_invite_flow('firstorg')
+ # Run for a new org.
+ run_invite_flow("firstorg")
- # Create another org and repeat, ensuring the same operations perform the same way.
- run_invite_flow('secondorg')
+ # Create another org and repeat, ensuring the same operations perform the same way.
+ run_invite_flow("secondorg")
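
A short sketch of the invite contract exercised by test_invite_to_team above: add_or_invite_to_team returns None when the target user can be added directly (already a member of the organization) and an invite otherwise, which is then redeemed with confirm_team_invite. The wrapper name is illustrative; the call shapes follow the test.

from data.model.team import add_or_invite_to_team, confirm_team_invite


def ensure_team_member(inviter, team, user):
    # Existing org members are added directly (invite is None); anyone else
    # receives an invite token that must be confirmed before membership shows up.
    invite = add_or_invite_to_team(inviter, team, user_obj=user)
    if invite is not None:
        confirm_team_invite(invite.invite_token, user)
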
diff --git a/data/model/test/test_user.py b/data/model/test/test_user.py
index 4f124b7f3..6b878a47e 100644
--- a/data/model/test/test_user.py
+++ b/data/model/test/test_user.py
@@ -21,185 +21,188 @@ from util.timedeltastring import convert_to_timedelta
from util.security.token import encode_public_private_token
from test.fixtures import *
+
def test_create_user_with_expiration(initialized_db):
- with patch('data.model.config.app_config', {'DEFAULT_TAG_EXPIRATION': '1h'}):
- user = create_user_noverify('foobar', 'foo@example.com', email_required=False)
- assert user.removed_tag_expiration_s == 60 * 60
+ with patch("data.model.config.app_config", {"DEFAULT_TAG_EXPIRATION": "1h"}):
+ user = create_user_noverify("foobar", "foo@example.com", email_required=False)
+ assert user.removed_tag_expiration_s == 60 * 60
-@pytest.mark.parametrize('token_lifetime, time_since', [
- ('1m', '2m'),
- ('2m', '1m'),
- ('1h', '1m'),
-])
+
+@pytest.mark.parametrize(
+ "token_lifetime, time_since", [("1m", "2m"), ("2m", "1m"), ("1h", "1m")]
+)
def test_validation_code(token_lifetime, time_since, initialized_db):
- user = create_user_noverify('foobar', 'foo@example.com', email_required=False)
- created = datetime.now() - convert_to_timedelta(time_since)
- verification_code, unhashed = Credential.generate()
- confirmation = EmailConfirmation.create(user=user, pw_reset=True,
- created=created, verification_code=verification_code)
- encoded = encode_public_private_token(confirmation.code, unhashed)
+ user = create_user_noverify("foobar", "foo@example.com", email_required=False)
+ created = datetime.now() - convert_to_timedelta(time_since)
+ verification_code, unhashed = Credential.generate()
+ confirmation = EmailConfirmation.create(
+ user=user, pw_reset=True, created=created, verification_code=verification_code
+ )
+ encoded = encode_public_private_token(confirmation.code, unhashed)
- with patch('data.model.config.app_config', {'USER_RECOVERY_TOKEN_LIFETIME': token_lifetime}):
- result = validate_reset_code(encoded)
- expect_success = convert_to_timedelta(token_lifetime) >= convert_to_timedelta(time_since)
- assert expect_success == (result is not None)
+ with patch(
+ "data.model.config.app_config", {"USER_RECOVERY_TOKEN_LIFETIME": token_lifetime}
+ ):
+ result = validate_reset_code(encoded)
+ expect_success = convert_to_timedelta(token_lifetime) >= convert_to_timedelta(
+ time_since
+ )
+ assert expect_success == (result is not None)
-@pytest.mark.parametrize('disabled', [
- (True),
- (False),
-])
-@pytest.mark.parametrize('deleted', [
- (True),
- (False),
-])
+@pytest.mark.parametrize("disabled", [(True), (False)])
+@pytest.mark.parametrize("deleted", [(True), (False)])
def test_get_active_users(disabled, deleted, initialized_db):
- # Delete a user.
- deleted_user = model.user.get_user('public')
- queue = WorkQueue('testgcnamespace', lambda db: db.transaction())
- mark_namespace_for_deletion(deleted_user, [], queue)
+ # Delete a user.
+ deleted_user = model.user.get_user("public")
+ queue = WorkQueue("testgcnamespace", lambda db: db.transaction())
+ mark_namespace_for_deletion(deleted_user, [], queue)
- users = get_active_users(disabled=disabled, deleted=deleted)
- deleted_found = [user for user in users if user.id == deleted_user.id]
- assert bool(deleted_found) == (deleted and disabled)
+ users = get_active_users(disabled=disabled, deleted=deleted)
+ deleted_found = [user for user in users if user.id == deleted_user.id]
+ assert bool(deleted_found) == (deleted and disabled)
- for user in users:
- if not disabled:
- assert user.enabled
+ for user in users:
+ if not disabled:
+ assert user.enabled
def test_mark_namespace_for_deletion(initialized_db):
- def create_transaction(db):
- return db.transaction()
+ def create_transaction(db):
+ return db.transaction()
- # Create a user and then mark it for deletion.
- user = create_user_noverify('foobar', 'foo@example.com', email_required=False)
+ # Create a user and then mark it for deletion.
+ user = create_user_noverify("foobar", "foo@example.com", email_required=False)
- # Add some robots.
- create_robot('foo', user)
- create_robot('bar', user)
+ # Add some robots.
+ create_robot("foo", user)
+ create_robot("bar", user)
- assert lookup_robot('foobar+foo') is not None
- assert lookup_robot('foobar+bar') is not None
- assert len(list(list_namespace_robots('foobar'))) == 2
+ assert lookup_robot("foobar+foo") is not None
+ assert lookup_robot("foobar+bar") is not None
+ assert len(list(list_namespace_robots("foobar"))) == 2
- # Mark the user for deletion.
- queue = WorkQueue('testgcnamespace', create_transaction)
- mark_namespace_for_deletion(user, [], queue)
+ # Mark the user for deletion.
+ queue = WorkQueue("testgcnamespace", create_transaction)
+ mark_namespace_for_deletion(user, [], queue)
- # Ensure the older user is still in the DB.
- older_user = User.get(id=user.id)
- assert older_user.username != 'foobar'
+ # Ensure the older user is still in the DB.
+ older_user = User.get(id=user.id)
+ assert older_user.username != "foobar"
- # Ensure the robots are deleted.
- with pytest.raises(InvalidRobotException):
- assert lookup_robot('foobar+foo')
+ # Ensure the robots are deleted.
+ with pytest.raises(InvalidRobotException):
+ assert lookup_robot("foobar+foo")
- with pytest.raises(InvalidRobotException):
- assert lookup_robot('foobar+bar')
+ with pytest.raises(InvalidRobotException):
+ assert lookup_robot("foobar+bar")
- assert len(list(list_namespace_robots(older_user.username))) == 0
+ assert len(list(list_namespace_robots(older_user.username))) == 0
- # Ensure we can create a user with the same namespace again.
- new_user = create_user_noverify('foobar', 'foo@example.com', email_required=False)
- assert new_user.id != user.id
+ # Ensure we can create a user with the same namespace again.
+ new_user = create_user_noverify("foobar", "foo@example.com", email_required=False)
+ assert new_user.id != user.id
- # Ensure the older user is still in the DB.
- assert User.get(id=user.id).username != 'foobar'
+ # Ensure the older user is still in the DB.
+ assert User.get(id=user.id).username != "foobar"
def test_delete_namespace_via_marker(initialized_db):
- def create_transaction(db):
- return db.transaction()
+ def create_transaction(db):
+ return db.transaction()
- # Create a user and then mark it for deletion.
- user = create_user_noverify('foobar', 'foo@example.com', email_required=False)
+ # Create a user and then mark it for deletion.
+ user = create_user_noverify("foobar", "foo@example.com", email_required=False)
- # Add some repositories.
- create_repository('foobar', 'somerepo', user)
- create_repository('foobar', 'anotherrepo', user)
+ # Add some repositories.
+ create_repository("foobar", "somerepo", user)
+ create_repository("foobar", "anotherrepo", user)
- # Mark the user for deletion.
- queue = WorkQueue('testgcnamespace', create_transaction)
- marker_id = mark_namespace_for_deletion(user, [], queue)
+ # Mark the user for deletion.
+ queue = WorkQueue("testgcnamespace", create_transaction)
+ marker_id = mark_namespace_for_deletion(user, [], queue)
- # Delete the user.
- delete_namespace_via_marker(marker_id, [])
+ # Delete the user.
+ delete_namespace_via_marker(marker_id, [])
- # Ensure the user was actually deleted.
- with pytest.raises(User.DoesNotExist):
- User.get(id=user.id)
+ # Ensure the user was actually deleted.
+ with pytest.raises(User.DoesNotExist):
+ User.get(id=user.id)
- with pytest.raises(DeletedNamespace.DoesNotExist):
- DeletedNamespace.get(id=marker_id)
+ with pytest.raises(DeletedNamespace.DoesNotExist):
+ DeletedNamespace.get(id=marker_id)
def test_delete_robot(initialized_db):
- # Create a robot account.
- user = create_user_noverify('foobar', 'foo@example.com', email_required=False)
- robot, _ = create_robot('foo', user)
+ # Create a robot account.
+ user = create_user_noverify("foobar", "foo@example.com", email_required=False)
+ robot, _ = create_robot("foo", user)
- # Add some notifications and other rows pointing to the robot.
- create_notification('repo_push', robot)
+ # Add some notifications and other rows pointing to the robot.
+ create_notification("repo_push", robot)
- team = create_team('someteam', get_organization('buynlarge'), 'member')
- add_user_to_team(robot, team)
+ team = create_team("someteam", get_organization("buynlarge"), "member")
+ add_user_to_team(robot, team)
- # Ensure the robot exists.
- assert lookup_robot(robot.username).id == robot.id
+ # Ensure the robot exists.
+ assert lookup_robot(robot.username).id == robot.id
- # Delete the robot.
- delete_robot(robot.username)
+ # Delete the robot.
+ delete_robot(robot.username)
- # Ensure it is gone.
- with pytest.raises(InvalidRobotException):
- lookup_robot(robot.username)
+ # Ensure it is gone.
+ with pytest.raises(InvalidRobotException):
+ lookup_robot(robot.username)
def test_get_matching_users(initialized_db):
- # Exact match.
- for user in User.select().where(User.organization == False, User.robot == False):
- assert list(get_matching_users(user.username))[0].username == user.username
+ # Exact match.
+ for user in User.select().where(User.organization == False, User.robot == False):
+ assert list(get_matching_users(user.username))[0].username == user.username
- # Prefix matching.
- for user in User.select().where(User.organization == False, User.robot == False):
- assert user.username in [r.username for r in get_matching_users(user.username[:2])]
+ # Prefix matching.
+ for user in User.select().where(User.organization == False, User.robot == False):
+ assert user.username in [
+ r.username for r in get_matching_users(user.username[:2])
+ ]
def test_get_matching_users_with_same_prefix(initialized_db):
- # Create a bunch of users with the same prefix.
- for index in range(0, 20):
- create_user_noverify('foo%s' % index, 'foo%s@example.com' % index, email_required=False)
+ # Create a bunch of users with the same prefix.
+ for index in range(0, 20):
+ create_user_noverify(
+ "foo%s" % index, "foo%s@example.com" % index, email_required=False
+ )
- # For each user, ensure that lookup of the exact name is found first.
- for index in range(0, 20):
- username = 'foo%s' % index
- assert list(get_matching_users(username))[0].username == username
+ # For each user, ensure that lookup of the exact name is found first.
+ for index in range(0, 20):
+ username = "foo%s" % index
+ assert list(get_matching_users(username))[0].username == username
- # Prefix matching.
- found = list(get_matching_users('foo', limit=50))
- assert len(found) == 20
+ # Prefix matching.
+ found = list(get_matching_users("foo", limit=50))
+ assert len(found) == 20
def test_robot(initialized_db):
- # Create a robot account.
- user = create_user_noverify('foobar', 'foo@example.com', email_required=False)
- robot, token = create_robot('foo', user)
- assert retrieve_robot_token(robot) == token
+ # Create a robot account.
+ user = create_user_noverify("foobar", "foo@example.com", email_required=False)
+ robot, token = create_robot("foo", user)
+ assert retrieve_robot_token(robot) == token
- # Ensure we can retrieve its information.
- found = lookup_robot('foobar+foo')
- assert found == robot
+ # Ensure we can retrieve its information.
+ found = lookup_robot("foobar+foo")
+ assert found == robot
- creds = get_pull_credentials('foobar+foo')
- assert creds is not None
- assert creds['username'] == 'foobar+foo'
- assert creds['password'] == token
+ creds = get_pull_credentials("foobar+foo")
+ assert creds is not None
+ assert creds["username"] == "foobar+foo"
+ assert creds["password"] == token
- assert verify_robot('foobar+foo', token) == robot
+ assert verify_robot("foobar+foo", token) == robot
- with pytest.raises(InvalidRobotException):
- assert verify_robot('foobar+foo', 'someothertoken')
+ with pytest.raises(InvalidRobotException):
+ assert verify_robot("foobar+foo", "someothertoken")
- with pytest.raises(InvalidRobotException):
- assert verify_robot('foobar+unknownbot', token)
+ with pytest.raises(InvalidRobotException):
+ assert verify_robot("foobar+unknownbot", token)
diff --git a/data/model/test/test_visible_repos.py b/data/model/test/test_visible_repos.py
index 9e5e7cbf5..b037a2d6f 100644
--- a/data/model/test/test_visible_repos.py
+++ b/data/model/test/test_visible_repos.py
@@ -3,87 +3,89 @@ from data import model
from test.fixtures import *
-NO_ACCESS_USER = 'freshuser'
-READ_ACCESS_USER = 'reader'
-ADMIN_ACCESS_USER = 'devtable'
-PUBLIC_USER = 'public'
-RANDOM_USER = 'randomuser'
-OUTSIDE_ORG_USER = 'outsideorg'
+NO_ACCESS_USER = "freshuser"
+READ_ACCESS_USER = "reader"
+ADMIN_ACCESS_USER = "devtable"
+PUBLIC_USER = "public"
+RANDOM_USER = "randomuser"
+OUTSIDE_ORG_USER = "outsideorg"
-ADMIN_ROBOT_USER = 'devtable+dtrobot'
+ADMIN_ROBOT_USER = "devtable+dtrobot"
-ORGANIZATION = 'buynlarge'
+ORGANIZATION = "buynlarge"
-SIMPLE_REPO = 'simple'
-PUBLIC_REPO = 'publicrepo'
-RANDOM_REPO = 'randomrepo'
+SIMPLE_REPO = "simple"
+PUBLIC_REPO = "publicrepo"
+RANDOM_REPO = "randomrepo"
-OUTSIDE_ORG_REPO = 'coolrepo'
+OUTSIDE_ORG_REPO = "coolrepo"
-ORG_REPO = 'orgrepo'
-ANOTHER_ORG_REPO = 'anotherorgrepo'
+ORG_REPO = "orgrepo"
+ANOTHER_ORG_REPO = "anotherorgrepo"
# Note: The shared repo has devtable as admin, public as a writer and reader as a reader.
-SHARED_REPO = 'shared'
+SHARED_REPO = "shared"
def assertDoesNotHaveRepo(username, name):
- repos = list(model.repository.get_visible_repositories(username))
- names = [repo.name for repo in repos]
- assert not name in names
+ repos = list(model.repository.get_visible_repositories(username))
+ names = [repo.name for repo in repos]
+    assert name not in names
def assertHasRepo(username, name):
- repos = list(model.repository.get_visible_repositories(username))
- names = [repo.name for repo in repos]
- assert name in names
+ repos = list(model.repository.get_visible_repositories(username))
+ names = [repo.name for repo in repos]
+ assert name in names
def test_noaccess(initialized_db):
- repos = list(model.repository.get_visible_repositories(NO_ACCESS_USER))
- names = [repo.name for repo in repos]
- assert not names
+ repos = list(model.repository.get_visible_repositories(NO_ACCESS_USER))
+ names = [repo.name for repo in repos]
+ assert not names
- # Try retrieving public repos now.
- repos = list(model.repository.get_visible_repositories(NO_ACCESS_USER, include_public=True))
- names = [repo.name for repo in repos]
- assert PUBLIC_REPO in names
+ # Try retrieving public repos now.
+ repos = list(
+ model.repository.get_visible_repositories(NO_ACCESS_USER, include_public=True)
+ )
+ names = [repo.name for repo in repos]
+ assert PUBLIC_REPO in names
def test_public(initialized_db):
- assertHasRepo(PUBLIC_USER, PUBLIC_REPO)
- assertHasRepo(PUBLIC_USER, SHARED_REPO)
+ assertHasRepo(PUBLIC_USER, PUBLIC_REPO)
+ assertHasRepo(PUBLIC_USER, SHARED_REPO)
- assertDoesNotHaveRepo(PUBLIC_USER, SIMPLE_REPO)
- assertDoesNotHaveRepo(PUBLIC_USER, RANDOM_REPO)
- assertDoesNotHaveRepo(PUBLIC_USER, OUTSIDE_ORG_REPO)
+ assertDoesNotHaveRepo(PUBLIC_USER, SIMPLE_REPO)
+ assertDoesNotHaveRepo(PUBLIC_USER, RANDOM_REPO)
+ assertDoesNotHaveRepo(PUBLIC_USER, OUTSIDE_ORG_REPO)
def test_reader(initialized_db):
- assertHasRepo(READ_ACCESS_USER, SHARED_REPO)
- assertHasRepo(READ_ACCESS_USER, ORG_REPO)
+ assertHasRepo(READ_ACCESS_USER, SHARED_REPO)
+ assertHasRepo(READ_ACCESS_USER, ORG_REPO)
- assertDoesNotHaveRepo(READ_ACCESS_USER, SIMPLE_REPO)
- assertDoesNotHaveRepo(READ_ACCESS_USER, RANDOM_REPO)
- assertDoesNotHaveRepo(READ_ACCESS_USER, OUTSIDE_ORG_REPO)
- assertDoesNotHaveRepo(READ_ACCESS_USER, PUBLIC_REPO)
+ assertDoesNotHaveRepo(READ_ACCESS_USER, SIMPLE_REPO)
+ assertDoesNotHaveRepo(READ_ACCESS_USER, RANDOM_REPO)
+ assertDoesNotHaveRepo(READ_ACCESS_USER, OUTSIDE_ORG_REPO)
+ assertDoesNotHaveRepo(READ_ACCESS_USER, PUBLIC_REPO)
def test_random(initialized_db):
- assertHasRepo(RANDOM_USER, RANDOM_REPO)
+ assertHasRepo(RANDOM_USER, RANDOM_REPO)
- assertDoesNotHaveRepo(RANDOM_USER, SIMPLE_REPO)
- assertDoesNotHaveRepo(RANDOM_USER, SHARED_REPO)
- assertDoesNotHaveRepo(RANDOM_USER, ORG_REPO)
- assertDoesNotHaveRepo(RANDOM_USER, ANOTHER_ORG_REPO)
- assertDoesNotHaveRepo(RANDOM_USER, PUBLIC_REPO)
+ assertDoesNotHaveRepo(RANDOM_USER, SIMPLE_REPO)
+ assertDoesNotHaveRepo(RANDOM_USER, SHARED_REPO)
+ assertDoesNotHaveRepo(RANDOM_USER, ORG_REPO)
+ assertDoesNotHaveRepo(RANDOM_USER, ANOTHER_ORG_REPO)
+ assertDoesNotHaveRepo(RANDOM_USER, PUBLIC_REPO)
def test_admin(initialized_db):
- assertHasRepo(ADMIN_ACCESS_USER, SIMPLE_REPO)
- assertHasRepo(ADMIN_ACCESS_USER, SHARED_REPO)
+ assertHasRepo(ADMIN_ACCESS_USER, SIMPLE_REPO)
+ assertHasRepo(ADMIN_ACCESS_USER, SHARED_REPO)
- assertHasRepo(ADMIN_ACCESS_USER, ORG_REPO)
- assertHasRepo(ADMIN_ACCESS_USER, ANOTHER_ORG_REPO)
+ assertHasRepo(ADMIN_ACCESS_USER, ORG_REPO)
+ assertHasRepo(ADMIN_ACCESS_USER, ANOTHER_ORG_REPO)
- assertDoesNotHaveRepo(ADMIN_ACCESS_USER, OUTSIDE_ORG_REPO)
+ assertDoesNotHaveRepo(ADMIN_ACCESS_USER, OUTSIDE_ORG_REPO)
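As a reading aid, here is a small sketch of the visibility lookup that assertHasRepo and assertDoesNotHaveRepo wrap above; visible_repo_names is a hypothetical helper, and the include_public keyword mirrors the call made in test_noaccess.

# Illustrative sketch only -- not applied by this diff.
from data import model


def visible_repo_names(username, include_public=False):
    # get_visible_repositories yields repository rows; the tests compare names.
    repos = model.repository.get_visible_repositories(
        username, include_public=include_public
    )
    return {repo.name for repo in repos}

test_noaccess, for example, expects PUBLIC_REPO to show up for NO_ACCESS_USER only when include_public=True.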
diff --git a/data/model/token.py b/data/model/token.py
index 82661cdef..cb6f91ca8 100644
--- a/data/model/token.py
+++ b/data/model/token.py
@@ -3,8 +3,14 @@ import logging
from peewee import JOIN
from active_migration import ActiveDataMigration, ERTMigrationFlags
-from data.database import (AccessToken, AccessTokenKind, Repository, Namespace, Role,
- RepositoryBuildTrigger)
+from data.database import (
+ AccessToken,
+ AccessTokenKind,
+ Repository,
+ Namespace,
+ Role,
+ RepositoryBuildTrigger,
+)
from data.model import DataModelException, _basequery, InvalidTokenException
@@ -16,90 +22,97 @@ ACCESS_TOKEN_CODE_MINIMUM_LENGTH = 32
def create_access_token(repo, role, kind=None, friendly_name=None):
- role = Role.get(Role.name == role)
- kind_ref = None
- if kind is not None:
- kind_ref = AccessTokenKind.get(AccessTokenKind.name == kind)
+ role = Role.get(Role.name == role)
+ kind_ref = None
+ if kind is not None:
+ kind_ref = AccessTokenKind.get(AccessTokenKind.name == kind)
- new_token = AccessToken.create(repository=repo, temporary=True, role=role, kind=kind_ref,
- friendly_name=friendly_name)
+ new_token = AccessToken.create(
+ repository=repo,
+ temporary=True,
+ role=role,
+ kind=kind_ref,
+ friendly_name=friendly_name,
+ )
- if ActiveDataMigration.has_flag(ERTMigrationFlags.WRITE_OLD_FIELDS):
- new_token.code = new_token.token_name + new_token.token_code.decrypt()
- new_token.save()
+ if ActiveDataMigration.has_flag(ERTMigrationFlags.WRITE_OLD_FIELDS):
+ new_token.code = new_token.token_name + new_token.token_code.decrypt()
+ new_token.save()
- return new_token
+ return new_token
-def create_delegate_token(namespace_name, repository_name, friendly_name,
- role='read'):
- read_only = Role.get(name=role)
- repo = _basequery.get_existing_repository(namespace_name, repository_name)
- new_token = AccessToken.create(repository=repo, role=read_only,
- friendly_name=friendly_name, temporary=False)
+def create_delegate_token(namespace_name, repository_name, friendly_name, role="read"):
+ read_only = Role.get(name=role)
+ repo = _basequery.get_existing_repository(namespace_name, repository_name)
+ new_token = AccessToken.create(
+ repository=repo, role=read_only, friendly_name=friendly_name, temporary=False
+ )
- if ActiveDataMigration.has_flag(ERTMigrationFlags.WRITE_OLD_FIELDS):
- new_token.code = new_token.token_name + new_token.token_code.decrypt()
- new_token.save()
+ if ActiveDataMigration.has_flag(ERTMigrationFlags.WRITE_OLD_FIELDS):
+ new_token.code = new_token.token_name + new_token.token_code.decrypt()
+ new_token.save()
- return new_token
+ return new_token
def load_token_data(code):
- """ Load the permissions for any token by code. """
- token_name = code[:ACCESS_TOKEN_NAME_PREFIX_LENGTH]
- token_code = code[ACCESS_TOKEN_NAME_PREFIX_LENGTH:]
+ """ Load the permissions for any token by code. """
+ token_name = code[:ACCESS_TOKEN_NAME_PREFIX_LENGTH]
+ token_code = code[ACCESS_TOKEN_NAME_PREFIX_LENGTH:]
- if not token_name or not token_code:
- raise InvalidTokenException('Invalid delegate token code: %s' % code)
+ if not token_name or not token_code:
+ raise InvalidTokenException("Invalid delegate token code: %s" % code)
- # Try loading by name and then comparing the code.
- assert token_name
- try:
- found = (AccessToken
- .select(AccessToken, Repository, Namespace, Role)
- .join(Role)
- .switch(AccessToken)
- .join(Repository)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .where(AccessToken.token_name == token_name)
- .get())
-
- assert token_code
- if found.token_code is None or not found.token_code.matches(token_code):
- raise InvalidTokenException('Invalid delegate token code: %s' % code)
-
- assert len(token_code) >= ACCESS_TOKEN_CODE_MINIMUM_LENGTH
- return found
- except AccessToken.DoesNotExist:
- pass
-
- # Legacy: Try loading the full code directly.
- # TODO(remove-unenc): Remove this once migrated.
- if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
+ # Try loading by name and then comparing the code.
+ assert token_name
try:
- return (AccessToken
- .select(AccessToken, Repository, Namespace, Role)
- .join(Role)
- .switch(AccessToken)
- .join(Repository)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .where(AccessToken.code == code)
- .get())
- except AccessToken.DoesNotExist:
- raise InvalidTokenException('Invalid delegate token code: %s' % code)
+ found = (
+ AccessToken.select(AccessToken, Repository, Namespace, Role)
+ .join(Role)
+ .switch(AccessToken)
+ .join(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .where(AccessToken.token_name == token_name)
+ .get()
+ )
- raise InvalidTokenException('Invalid delegate token code: %s' % code)
+ assert token_code
+ if found.token_code is None or not found.token_code.matches(token_code):
+ raise InvalidTokenException("Invalid delegate token code: %s" % code)
+
+ assert len(token_code) >= ACCESS_TOKEN_CODE_MINIMUM_LENGTH
+ return found
+ except AccessToken.DoesNotExist:
+ pass
+
+ # Legacy: Try loading the full code directly.
+ # TODO(remove-unenc): Remove this once migrated.
+ if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
+ try:
+ return (
+ AccessToken.select(AccessToken, Repository, Namespace, Role)
+ .join(Role)
+ .switch(AccessToken)
+ .join(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .where(AccessToken.code == code)
+ .get()
+ )
+ except AccessToken.DoesNotExist:
+ raise InvalidTokenException("Invalid delegate token code: %s" % code)
+
+ raise InvalidTokenException("Invalid delegate token code: %s" % code)
def get_full_token_string(token):
- """ Returns the full string to use for this token to login. """
- if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
- if token.token_name is None:
- return token.code
+ """ Returns the full string to use for this token to login. """
+ if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
+ if token.token_name is None:
+ return token.code
- assert token.token_name
- token_code = token.token_code.decrypt()
- assert len(token.token_name) == ACCESS_TOKEN_NAME_PREFIX_LENGTH
- assert len(token_code) >= ACCESS_TOKEN_CODE_MINIMUM_LENGTH
- return '%s%s' % (token.token_name, token_code)
+ assert token.token_name
+ token_code = token.token_code.decrypt()
+ assert len(token.token_name) == ACCESS_TOKEN_NAME_PREFIX_LENGTH
+ assert len(token_code) >= ACCESS_TOKEN_CODE_MINIMUM_LENGTH
+ return "%s%s" % (token.token_name, token_code)
diff --git a/data/model/user.py b/data/model/user.py
index 7e9ed81b1..dd8c549cc 100644
--- a/data/model/user.py
+++ b/data/model/user.py
@@ -9,24 +9,60 @@ from uuid import uuid4
from datetime import datetime, timedelta
from active_migration import ActiveDataMigration, ERTMigrationFlags
-from data.database import (User, LoginService, FederatedLogin, RepositoryPermission, TeamMember,
- Team, Repository, TupleSelector, TeamRole, Namespace, Visibility,
- EmailConfirmation, Role, db_for_update, random_string_generator,
- UserRegion, ImageStorageLocation,
- ServiceKeyApproval, OAuthApplication, RepositoryBuildTrigger,
- UserPromptKind, UserPrompt, UserPromptTypes, DeletedNamespace,
- RobotAccountMetadata, NamespaceGeoRestriction, RepoMirrorConfig,
- RobotAccountToken)
+from data.database import (
+ User,
+ LoginService,
+ FederatedLogin,
+ RepositoryPermission,
+ TeamMember,
+ Team,
+ Repository,
+ TupleSelector,
+ TeamRole,
+ Namespace,
+ Visibility,
+ EmailConfirmation,
+ Role,
+ db_for_update,
+ random_string_generator,
+ UserRegion,
+ ImageStorageLocation,
+ ServiceKeyApproval,
+ OAuthApplication,
+ RepositoryBuildTrigger,
+ UserPromptKind,
+ UserPrompt,
+ UserPromptTypes,
+ DeletedNamespace,
+ RobotAccountMetadata,
+ NamespaceGeoRestriction,
+ RepoMirrorConfig,
+ RobotAccountToken,
+)
from data.readreplica import ReadOnlyModeException
-from data.model import (DataModelException, InvalidPasswordException, InvalidRobotException,
- InvalidUsernameException, InvalidEmailAddressException,
- TooManyLoginAttemptsException, db_transaction,
- notification, config, repository, _basequery, gc)
+from data.model import (
+ DataModelException,
+ InvalidPasswordException,
+ InvalidRobotException,
+ InvalidUsernameException,
+ InvalidEmailAddressException,
+ TooManyLoginAttemptsException,
+ db_transaction,
+ notification,
+ config,
+ repository,
+ _basequery,
+ gc,
+)
from data.fields import Credential
from data.text import prefix_search
from util.names import format_robot_username, parse_robot_username
-from util.validation import (validate_username, validate_email, validate_password,
- INVALID_PASSWORD_MESSAGE)
+from util.validation import (
+ validate_username,
+ validate_email,
+ validate_password,
+ INVALID_PASSWORD_MESSAGE,
+)
from util.backoff import exponential_backoff
from util.timedeltastring import convert_to_timedelta
from util.unicode import remove_unicode
@@ -38,1180 +74,1394 @@ logger = logging.getLogger(__name__)
EXPONENTIAL_BACKOFF_SCALE = timedelta(seconds=1)
+
def hash_password(password, salt=None):
- salt = salt or bcrypt.gensalt()
- return bcrypt.hashpw(password.encode('utf-8'), salt)
-
-def create_user(username, password, email, auto_verify=False, email_required=True, prompts=tuple(),
- is_possible_abuser=False):
- """ Creates a regular user, if allowed. """
- if not validate_password(password):
- raise InvalidPasswordException(INVALID_PASSWORD_MESSAGE)
-
- created = create_user_noverify(username, email, email_required=email_required, prompts=prompts,
- is_possible_abuser=is_possible_abuser)
- created.password_hash = hash_password(password)
- created.verified = auto_verify
- created.save()
-
- return created
+ salt = salt or bcrypt.gensalt()
+ return bcrypt.hashpw(password.encode("utf-8"), salt)
-def create_user_noverify(username, email, email_required=True, prompts=tuple(),
- is_possible_abuser=False):
- if email_required:
- if not validate_email(email):
- raise InvalidEmailAddressException('Invalid email address: %s' % email)
- else:
- # If email addresses are not required and none was specified, then we just use a unique
- # ID to ensure that the database consistency check remains intact.
- email = email or str(uuid.uuid4())
+def create_user(
+ username,
+ password,
+ email,
+ auto_verify=False,
+ email_required=True,
+ prompts=tuple(),
+ is_possible_abuser=False,
+):
+ """ Creates a regular user, if allowed. """
+ if not validate_password(password):
+ raise InvalidPasswordException(INVALID_PASSWORD_MESSAGE)
- (username_valid, username_issue) = validate_username(username)
- if not username_valid:
- raise InvalidUsernameException('Invalid namespace %s: %s' % (username, username_issue))
+ created = create_user_noverify(
+ username,
+ email,
+ email_required=email_required,
+ prompts=prompts,
+ is_possible_abuser=is_possible_abuser,
+ )
+ created.password_hash = hash_password(password)
+ created.verified = auto_verify
+ created.save()
- try:
- existing = User.get((User.username == username) | (User.email == email))
- logger.info('Existing user with same username or email.')
+ return created
- # A user already exists with either the same username or email
- if existing.username == username:
- assert not existing.robot
- msg = 'Username has already been taken by an organization and cannot be reused: %s' % username
- if not existing.organization:
- msg = 'Username has already been taken by user cannot be reused: %s' % username
+def create_user_noverify(
+ username, email, email_required=True, prompts=tuple(), is_possible_abuser=False
+):
+ if email_required:
+ if not validate_email(email):
+ raise InvalidEmailAddressException("Invalid email address: %s" % email)
+ else:
+ # If email addresses are not required and none was specified, then we just use a unique
+ # ID to ensure that the database consistency check remains intact.
+ email = email or str(uuid.uuid4())
- raise InvalidUsernameException(msg)
+ (username_valid, username_issue) = validate_username(username)
+ if not username_valid:
+ raise InvalidUsernameException(
+ "Invalid namespace %s: %s" % (username, username_issue)
+ )
- raise InvalidEmailAddressException('Email has already been used: %s' % email)
- except User.DoesNotExist:
- # This is actually the happy path
- logger.debug('Email and username are unique!')
+ try:
+ existing = User.get((User.username == username) | (User.email == email))
+ logger.info("Existing user with same username or email.")
- # Create the user.
- try:
- default_expr_s = _convert_to_s(config.app_config['DEFAULT_TAG_EXPIRATION'])
- default_max_builds = config.app_config.get('DEFAULT_NAMESPACE_MAXIMUM_BUILD_COUNT')
- threat_max_builds = config.app_config.get('THREAT_NAMESPACE_MAXIMUM_BUILD_COUNT')
+ # A user already exists with either the same username or email
+ if existing.username == username:
+ assert not existing.robot
- if is_possible_abuser and threat_max_builds is not None:
- default_max_builds = threat_max_builds
+ msg = (
+ "Username has already been taken by an organization and cannot be reused: %s"
+ % username
+ )
+ if not existing.organization:
+ msg = (
+ "Username has already been taken by user cannot be reused: %s"
+ % username
+ )
- new_user = User.create(username=username, email=email, removed_tag_expiration_s=default_expr_s,
- maximum_queued_builds_count=default_max_builds)
- for prompt in prompts:
- create_user_prompt(new_user, prompt)
+ raise InvalidUsernameException(msg)
+
+ raise InvalidEmailAddressException("Email has already been used: %s" % email)
+ except User.DoesNotExist:
+ # This is actually the happy path
+ logger.debug("Email and username are unique!")
+
+ # Create the user.
+ try:
+ default_expr_s = _convert_to_s(config.app_config["DEFAULT_TAG_EXPIRATION"])
+ default_max_builds = config.app_config.get(
+ "DEFAULT_NAMESPACE_MAXIMUM_BUILD_COUNT"
+ )
+ threat_max_builds = config.app_config.get(
+ "THREAT_NAMESPACE_MAXIMUM_BUILD_COUNT"
+ )
+
+ if is_possible_abuser and threat_max_builds is not None:
+ default_max_builds = threat_max_builds
+
+ new_user = User.create(
+ username=username,
+ email=email,
+ removed_tag_expiration_s=default_expr_s,
+ maximum_queued_builds_count=default_max_builds,
+ )
+ for prompt in prompts:
+ create_user_prompt(new_user, prompt)
+
+ return new_user
+ except Exception as ex:
+ raise DataModelException(ex.message)
- return new_user
- except Exception as ex:
- raise DataModelException(ex.message)
def increase_maximum_build_count(user, maximum_queued_builds_count):
- """ Increases the maximum number of allowed builds on the namespace, if greater than that
+ """ Increases the maximum number of allowed builds on the namespace, if greater than that
already present.
"""
- if (user.maximum_queued_builds_count is not None and
- maximum_queued_builds_count > user.maximum_queued_builds_count):
- user.maximum_queued_builds_count = maximum_queued_builds_count
- user.save()
+ if (
+ user.maximum_queued_builds_count is not None
+ and maximum_queued_builds_count > user.maximum_queued_builds_count
+ ):
+ user.maximum_queued_builds_count = maximum_queued_builds_count
+ user.save()
+
def is_username_unique(test_username):
- try:
- User.get((User.username == test_username))
- return False
- except User.DoesNotExist:
- return True
+ try:
+ User.get((User.username == test_username))
+ return False
+ except User.DoesNotExist:
+ return True
def change_password(user, new_password):
- if not validate_password(new_password):
- raise InvalidPasswordException(INVALID_PASSWORD_MESSAGE)
+ if not validate_password(new_password):
+ raise InvalidPasswordException(INVALID_PASSWORD_MESSAGE)
- pw_hash = hash_password(new_password)
- user.invalid_login_attempts = 0
- user.password_hash = pw_hash
- invalidate_all_sessions(user)
+ pw_hash = hash_password(new_password)
+ user.invalid_login_attempts = 0
+ user.password_hash = pw_hash
+ invalidate_all_sessions(user)
- # Remove any password required notifications for the user.
- notification.delete_notifications_by_kind(user, 'password_required')
+ # Remove any password required notifications for the user.
+ notification.delete_notifications_by_kind(user, "password_required")
def get_default_user_prompts(features):
- prompts = set()
- if features.USER_METADATA:
- prompts.add(UserPromptTypes.ENTER_NAME)
- prompts.add(UserPromptTypes.ENTER_COMPANY)
+ prompts = set()
+ if features.USER_METADATA:
+ prompts.add(UserPromptTypes.ENTER_NAME)
+ prompts.add(UserPromptTypes.ENTER_COMPANY)
- return prompts
+ return prompts
def has_user_prompts(user):
- try:
- UserPrompt.select().where(UserPrompt.user == user).get()
- return True
- except UserPrompt.DoesNotExist:
- return False
+ try:
+ UserPrompt.select().where(UserPrompt.user == user).get()
+ return True
+ except UserPrompt.DoesNotExist:
+ return False
def has_user_prompt(user, prompt_name):
- prompt_kind = UserPromptKind.get(name=prompt_name)
+ prompt_kind = UserPromptKind.get(name=prompt_name)
- try:
- UserPrompt.get(user=user, kind=prompt_kind)
- return True
- except UserPrompt.DoesNotExist:
- return False
+ try:
+ UserPrompt.get(user=user, kind=prompt_kind)
+ return True
+ except UserPrompt.DoesNotExist:
+ return False
def create_user_prompt(user, prompt_name):
- prompt_kind = UserPromptKind.get(name=prompt_name)
- return UserPrompt.create(user=user, kind=prompt_kind)
+ prompt_kind = UserPromptKind.get(name=prompt_name)
+ return UserPrompt.create(user=user, kind=prompt_kind)
def remove_user_prompt(user, prompt_name):
- prompt_kind = UserPromptKind.get(name=prompt_name)
- UserPrompt.delete().where(UserPrompt.user == user, UserPrompt.kind == prompt_kind).execute()
+ prompt_kind = UserPromptKind.get(name=prompt_name)
+ UserPrompt.delete().where(
+ UserPrompt.user == user, UserPrompt.kind == prompt_kind
+ ).execute()
def get_user_prompts(user):
- query = UserPrompt.select().where(UserPrompt.user == user).join(UserPromptKind)
- return [prompt.kind.name for prompt in query]
+ query = UserPrompt.select().where(UserPrompt.user == user).join(UserPromptKind)
+ return [prompt.kind.name for prompt in query]
def change_username(user_id, new_username):
- (username_valid, username_issue) = validate_username(new_username)
- if not username_valid:
- raise InvalidUsernameException('Invalid username %s: %s' % (new_username, username_issue))
+ (username_valid, username_issue) = validate_username(new_username)
+ if not username_valid:
+ raise InvalidUsernameException(
+ "Invalid username %s: %s" % (new_username, username_issue)
+ )
- with db_transaction():
- # Reload the user for update
- user = db_for_update(User.select().where(User.id == user_id)).get()
+ with db_transaction():
+ # Reload the user for update
+ user = db_for_update(User.select().where(User.id == user_id)).get()
- # Rename the robots
- for robot in db_for_update(_list_entity_robots(user.username, include_metadata=False,
- include_token=False)):
- _, robot_shortname = parse_robot_username(robot.username)
- new_robot_name = format_robot_username(new_username, robot_shortname)
- robot.username = new_robot_name
- robot.save()
+ # Rename the robots
+ for robot in db_for_update(
+ _list_entity_robots(
+ user.username, include_metadata=False, include_token=False
+ )
+ ):
+ _, robot_shortname = parse_robot_username(robot.username)
+ new_robot_name = format_robot_username(new_username, robot_shortname)
+ robot.username = new_robot_name
+ robot.save()
- # Rename the user
- user.username = new_username
+ # Rename the user
+ user.username = new_username
+ user.save()
+
+ # Remove any prompts for username.
+ remove_user_prompt(user, "confirm_username")
+
+ return user
+
+
+def change_invoice_email_address(user, invoice_email_address):
+ # Note: We null out the address if it is an empty string.
+ user.invoice_email_address = invoice_email_address or None
user.save()
- # Remove any prompts for username.
- remove_user_prompt(user, 'confirm_username')
+
+def change_send_invoice_email(user, invoice_email):
+ user.invoice_email = invoice_email
+ user.save()
+
+
+def _convert_to_s(timespan_string):
+ """ Returns the given timespan string (e.g. `2w` or `45s`) into seconds. """
+ return convert_to_timedelta(timespan_string).total_seconds()
+
+
+def change_user_tag_expiration(user, tag_expiration_s):
+ """ Changes the tag expiration on the given user/org. Note that the specified expiration must
+ be within the configured TAG_EXPIRATION_OPTIONS or this method will raise a
+ DataModelException.
+ """
+ allowed_options = [
+ _convert_to_s(o) for o in config.app_config["TAG_EXPIRATION_OPTIONS"]
+ ]
+ if tag_expiration_s not in allowed_options:
+ raise DataModelException("Invalid tag expiration option")
+
+ user.removed_tag_expiration_s = tag_expiration_s
+ user.save()
+
+
+def update_email(user, new_email, auto_verify=False):
+ try:
+ user.email = new_email
+ user.verified = auto_verify
+ user.save()
+ except IntegrityError:
+ raise DataModelException("E-mail address already used")
+
+
+def update_enabled(user, set_enabled):
+ user.enabled = set_enabled
+ user.save()
+
+
+def create_robot(robot_shortname, parent, description="", unstructured_metadata=None):
+ (username_valid, username_issue) = validate_username(robot_shortname)
+ if not username_valid:
+ raise InvalidRobotException(
+ "The name for the robot '%s' is invalid: %s"
+ % (robot_shortname, username_issue)
+ )
+
+ username = format_robot_username(parent.username, robot_shortname)
+
+ try:
+ User.get(User.username == username)
+
+ msg = "Existing robot with name: %s" % username
+ logger.info(msg)
+ raise InvalidRobotException(msg)
+ except User.DoesNotExist:
+ pass
+
+ service = LoginService.get(name="quayrobot")
+ try:
+ with db_transaction():
+ created = User.create(
+ username=username, email=str(uuid.uuid4()), robot=True
+ )
+ token = random_string_generator(length=64)()
+ RobotAccountToken.create(robot_account=created, token=token)
+ FederatedLogin.create(
+ user=created, service=service, service_ident="robot:%s" % created.id
+ )
+ RobotAccountMetadata.create(
+ robot_account=created,
+ description=description[0:255],
+ unstructured_json=unstructured_metadata or {},
+ )
+ return created, token
+ except Exception as ex:
+ raise DataModelException(ex.message)
+
+
+def get_or_create_robot_metadata(robot):
+ defaults = dict(description="", unstructured_json={})
+ metadata, _ = RobotAccountMetadata.get_or_create(
+ robot_account=robot, defaults=defaults
+ )
+ return metadata
+
+
+def update_robot_metadata(robot, description="", unstructured_json=None):
+ """ Updates the description and user-specified unstructured metadata associated
+ with a robot account to that specified. """
+ metadata = get_or_create_robot_metadata(robot)
+ metadata.description = description
+ metadata.unstructured_json = unstructured_json or metadata.unstructured_json or {}
+ metadata.save()
+
+
+def retrieve_robot_token(robot):
+ """ Returns the decrypted token for the given robot. """
+ try:
+ token = RobotAccountToken.get(robot_account=robot).token.decrypt()
+ except RobotAccountToken.DoesNotExist:
+ if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
+ # For legacy only.
+ token = robot.email
+ else:
+ raise
+
+ return token
+
+
+def get_robot_and_metadata(robot_shortname, parent):
+ """ Returns a tuple of the robot matching the given shortname, its token, and its metadata. """
+ robot_username = format_robot_username(parent.username, robot_shortname)
+ robot, metadata = lookup_robot_and_metadata(robot_username)
+ token = retrieve_robot_token(robot)
+ return robot, token, metadata
+
+
+def lookup_robot(robot_username):
+ try:
+ return User.get(username=robot_username, robot=True)
+ except User.DoesNotExist:
+ raise InvalidRobotException(
+ "Could not find robot with username: %s" % robot_username
+ )
+
+
+def lookup_robot_and_metadata(robot_username):
+ robot = lookup_robot(robot_username)
+ return robot, get_or_create_robot_metadata(robot)
+
+
+def get_matching_robots(name_prefix, username, limit=10):
+ admined_orgs = (
+ _basequery.get_user_organizations(username)
+ .switch(Team)
+ .join(TeamRole)
+ .where(TeamRole.name == "admin")
+ )
+
+ prefix_checks = False
+
+ for org in admined_orgs:
+ org_search = prefix_search(User.username, org.username + "+" + name_prefix)
+ prefix_checks = prefix_checks | org_search
+
+ user_search = prefix_search(User.username, username + "+" + name_prefix)
+ prefix_checks = prefix_checks | user_search
+
+ return User.select().where(prefix_checks).limit(limit)
+
+
+def verify_robot(robot_username, password):
+ try:
+ password = remove_unicode(password)
+ except UnicodeEncodeError:
+ msg = (
+ "Could not find robot with username: %s and supplied password."
+ % robot_username
+ )
+ raise InvalidRobotException(msg)
+
+ result = parse_robot_username(robot_username)
+ if result is None:
+ raise InvalidRobotException("%s is an invalid robot name" % robot_username)
+
+ robot = lookup_robot(robot_username)
+ assert robot.robot
+
+ # Lookup the token for the robot.
+ try:
+ token_data = RobotAccountToken.get(robot_account=robot)
+ if not token_data.token.matches(password):
+ msg = (
+ "Could not find robot with username: %s and supplied password."
+ % robot_username
+ )
+ raise InvalidRobotException(msg)
+    except RobotAccountToken.DoesNotExist:
+        # TODO(remove-unenc): Remove once migrated.
+        # The token row lookup failed before msg was assigned above, so build it here.
+        msg = (
+            "Could not find robot with username: %s and supplied password."
+            % robot_username
+        )
+        if not ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
+            raise InvalidRobotException(msg)
+
+ if password.find("robot:") >= 0:
+ # Just to be sure.
+ raise InvalidRobotException(msg)
+
+ query = (
+ User.select()
+ .join(FederatedLogin)
+ .join(LoginService)
+ .where(
+ FederatedLogin.service_ident == password,
+ LoginService.name == "quayrobot",
+ User.username == robot_username,
+ )
+ )
+
+ try:
+ robot = query.get()
+ except User.DoesNotExist:
+ msg = (
+ "Could not find robot with username: %s and supplied password."
+ % robot_username
+ )
+ raise InvalidRobotException(msg)
+
+ # Find the owner user and ensure it is not disabled.
+ try:
+ owner = User.get(User.username == result[0])
+ except User.DoesNotExist:
+ raise InvalidRobotException("Robot %s owner does not exist" % robot_username)
+
+ if not owner.enabled:
+ raise InvalidRobotException(
+ "This user has been disabled. Please contact your administrator."
+ )
+
+ # Mark that the robot was accessed.
+ _basequery.update_last_accessed(robot)
+
+ return robot
+
+
+def regenerate_robot_token(robot_shortname, parent):
+ robot_username = format_robot_username(parent.username, robot_shortname)
+
+ robot, metadata = lookup_robot_and_metadata(robot_username)
+ password = random_string_generator(length=64)()
+ robot.email = str(uuid4())
+ robot.uuid = str(uuid4())
+
+ service = LoginService.get(name="quayrobot")
+ login = FederatedLogin.get(
+ FederatedLogin.user == robot, FederatedLogin.service == service
+ )
+ login.service_ident = "robot:%s" % (robot.id)
+
+ try:
+ token_data = RobotAccountToken.get(robot_account=robot)
+ except RobotAccountToken.DoesNotExist:
+ token_data = RobotAccountToken.create(robot_account=robot)
+
+ token_data.token = password
+
+ with db_transaction():
+ token_data.save()
+ login.save()
+ robot.save()
+
+ return robot, password, metadata
+
+
+def delete_robot(robot_username):
+ try:
+ robot = User.get(username=robot_username, robot=True)
+ robot.delete_instance(recursive=True, delete_nullable=True)
+
+ except User.DoesNotExist:
+ raise InvalidRobotException(
+ "Could not find robot with username: %s" % robot_username
+ )
+
+
+def list_namespace_robots(namespace):
+ """ Returns all the robots found under the given namespace. """
+ return _list_entity_robots(namespace)
+
+
+def _list_entity_robots(entity_name, include_metadata=True, include_token=True):
+ """ Return the list of robots for the specified entity. This MUST return a query, not a
+ materialized list so that callers can use db_for_update.
+ """
+ # TODO(remove-unenc): Remove FederatedLogin and LEFT_OUTER on RobotAccountToken once migration
+ # is complete.
+ if include_metadata or include_token:
+ query = (
+ User.select(User, RobotAccountToken, FederatedLogin, RobotAccountMetadata)
+ .join(FederatedLogin)
+ .switch(User)
+ .join(RobotAccountMetadata, JOIN.LEFT_OUTER)
+ .switch(User)
+ .join(RobotAccountToken, JOIN.LEFT_OUTER)
+ .where(User.robot == True, User.username ** (entity_name + "+%"))
+ )
+ else:
+ query = User.select(User).where(
+ User.robot == True, User.username ** (entity_name + "+%")
+ )
+
+ return query
+
+
+def list_entity_robot_permission_teams(
+ entity_name, limit=None, include_permissions=False
+):
+ query = _list_entity_robots(entity_name)
+
+ # TODO(remove-unenc): Remove FederatedLogin once migration is complete.
+ fields = [
+ User.username,
+ User.creation_date,
+ User.last_accessed,
+ RobotAccountToken.token,
+ FederatedLogin.service_ident,
+ RobotAccountMetadata.description,
+ RobotAccountMetadata.unstructured_json,
+ ]
+ if include_permissions:
+ query = (
+ query.join(
+ RepositoryPermission,
+ JOIN.LEFT_OUTER,
+ on=(RepositoryPermission.user == FederatedLogin.user),
+ )
+ .join(Repository, JOIN.LEFT_OUTER)
+ .switch(User)
+ .join(TeamMember, JOIN.LEFT_OUTER)
+ .join(Team, JOIN.LEFT_OUTER)
+ )
+
+ fields.append(Repository.name)
+ fields.append(Team.name)
+
+ query = query.limit(limit).order_by(User.last_accessed.desc())
+ return TupleSelector(query, fields)
+
+
+def update_user_metadata(user, metadata=None):
+ """ Updates the metadata associated with the user, including his/her name and company. """
+ metadata = metadata if metadata is not None else {}
+
+ with db_transaction():
+ if "given_name" in metadata:
+ user.given_name = metadata["given_name"]
+
+ if "family_name" in metadata:
+ user.family_name = metadata["family_name"]
+
+ if "company" in metadata:
+ user.company = metadata["company"]
+
+ if "location" in metadata:
+ user.location = metadata["location"]
+
+ user.save()
+
+ # Remove any prompts associated with the user's metadata being needed.
+ remove_user_prompt(user, UserPromptTypes.ENTER_NAME)
+ remove_user_prompt(user, UserPromptTypes.ENTER_COMPANY)
+
+
+def _get_login_service(service_id):
+ try:
+ return LoginService.get(LoginService.name == service_id)
+ except LoginService.DoesNotExist:
+ return LoginService.create(name=service_id)
+
+
+def create_federated_user(
+ username,
+ email,
+ service_id,
+ service_ident,
+ set_password_notification,
+ metadata={},
+ email_required=True,
+ confirm_username=True,
+ prompts=tuple(),
+):
+ prompts = set(prompts)
+
+ if confirm_username:
+ prompts.add(UserPromptTypes.CONFIRM_USERNAME)
+
+ new_user = create_user_noverify(
+ username, email, email_required=email_required, prompts=prompts
+ )
+ new_user.verified = True
+ new_user.save()
+
+ FederatedLogin.create(
+ user=new_user,
+ service=_get_login_service(service_id),
+ service_ident=service_ident,
+ metadata_json=json.dumps(metadata),
+ )
+
+ if set_password_notification:
+ notification.create_notification("password_required", new_user)
+
+ return new_user
+
+
+def attach_federated_login(user, service_id, service_ident, metadata=None):
+ service = _get_login_service(service_id)
+ FederatedLogin.create(
+ user=user,
+ service=service,
+ service_ident=service_ident,
+ metadata_json=json.dumps(metadata or {}),
+ )
+ return user
+
+
+def verify_federated_login(service_id, service_ident):
+ try:
+ found = (
+ FederatedLogin.select(FederatedLogin, User)
+ .join(LoginService)
+ .switch(FederatedLogin)
+ .join(User)
+ .where(
+ FederatedLogin.service_ident == service_ident,
+ LoginService.name == service_id,
+ )
+ .get()
+ )
+
+ # Mark that the user was accessed.
+ _basequery.update_last_accessed(found.user)
+
+ return found.user
+ except FederatedLogin.DoesNotExist:
+ return None
+
+
+def list_federated_logins(user):
+ selected = FederatedLogin.select(
+ FederatedLogin.service_ident, LoginService.name, FederatedLogin.metadata_json
+ )
+ joined = selected.join(LoginService)
+ return joined.where(LoginService.name != "quayrobot", FederatedLogin.user == user)
+
+
+def lookup_federated_login(user, service_name):
+ try:
+ return (
+ list_federated_logins(user).where(LoginService.name == service_name).get()
+ )
+ except FederatedLogin.DoesNotExist:
+ return None
+
+
+def create_confirm_email_code(user, new_email=None):
+ if new_email:
+ if not validate_email(new_email):
+ raise InvalidEmailAddressException("Invalid email address: %s" % new_email)
+
+ verification_code, unhashed = Credential.generate()
+ code = EmailConfirmation.create(
+ user=user,
+ email_confirm=True,
+ new_email=new_email,
+ verification_code=verification_code,
+ )
+ return encode_public_private_token(code.code, unhashed)
+
+
+def confirm_user_email(token):
+ # TODO(remove-unenc): Remove allow_public_only once migrated.
+ allow_public_only = ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS)
+ result = decode_public_private_token(token, allow_public_only=allow_public_only)
+ if not result:
+ raise DataModelException("Invalid email confirmation code")
+
+ try:
+ code = EmailConfirmation.get(
+ EmailConfirmation.code == result.public_code,
+ EmailConfirmation.email_confirm == True,
+ )
+ except EmailConfirmation.DoesNotExist:
+ raise DataModelException("Invalid email confirmation code")
+
+ if result.private_token and not code.verification_code.matches(
+ result.private_token
+ ):
+ raise DataModelException("Invalid email confirmation code")
+
+ user = code.user
+ user.verified = True
+
+ old_email = None
+ new_email = code.new_email
+ if new_email and new_email != old_email:
+ if find_user_by_email(new_email):
+ raise DataModelException("E-mail address already used")
+
+ old_email = user.email
+ user.email = new_email
+
+ with db_transaction():
+ user.save()
+ code.delete_instance()
+
+ return user, new_email, old_email
+
+
+def create_reset_password_email_code(email):
+ try:
+ user = User.get(User.email == email)
+ except User.DoesNotExist:
+ raise InvalidEmailAddressException("Email address was not found")
+
+ if user.organization:
+ raise InvalidEmailAddressException("Organizations can not have passwords")
+
+ verification_code, unhashed = Credential.generate()
+ code = EmailConfirmation.create(
+ user=user, pw_reset=True, verification_code=verification_code
+ )
+ return encode_public_private_token(code.code, unhashed)
+
+
+def validate_reset_code(token):
+ # TODO(remove-unenc): Remove allow_public_only once migrated.
+ allow_public_only = ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS)
+ result = decode_public_private_token(token, allow_public_only=allow_public_only)
+ if not result:
+ return None
+
+ # Find the reset code.
+ try:
+ code = EmailConfirmation.get(
+ EmailConfirmation.code == result.public_code,
+ EmailConfirmation.pw_reset == True,
+ )
+ except EmailConfirmation.DoesNotExist:
+ return None
+
+ if result.private_token and not code.verification_code.matches(
+ result.private_token
+ ):
+ return None
+
+ # Make sure the code is not expired.
+ max_lifetime_duration = convert_to_timedelta(
+ config.app_config["USER_RECOVERY_TOKEN_LIFETIME"]
+ )
+ if code.created + max_lifetime_duration < datetime.now():
+ code.delete_instance()
+ return None
+
+ # Verify the user and return the code.
+ user = code.user
+
+ with db_transaction():
+ if not user.verified:
+ user.verified = True
+ user.save()
+
+ code.delete_instance()
return user
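Before the corresponding removals below, a brief sketch of how the password-reset helpers added in this hunk fit together; this is an illustration only, assuming the helpers are imported from data.model.user, and reset_password_with_emailed_code is a made-up wrapper.

# Illustrative sketch only -- not applied by this diff.
from data.model.user import (
    change_password,
    create_reset_password_email_code,
    validate_reset_code,
)


def reset_password_with_emailed_code(email, new_password):
    # The emailed token pairs the EmailConfirmation row's public code with an
    # unhashed verification secret.
    token = create_reset_password_email_code(email)

    # validate_reset_code checks the secret and the USER_RECOVERY_TOKEN_LIFETIME
    # window, returning the user on success and None otherwise.
    user = validate_reset_code(token)
    if user is None:
        raise ValueError("Expired or invalid reset code")

    change_password(user, new_password)
    return user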
-def change_invoice_email_address(user, invoice_email_address):
- # Note: We null out the address if it is an empty string.
- user.invoice_email_address = invoice_email_address or None
- user.save()
-
-
-def change_send_invoice_email(user, invoice_email):
- user.invoice_email = invoice_email
- user.save()
-
-
-def _convert_to_s(timespan_string):
- """ Returns the given timespan string (e.g. `2w` or `45s`) into seconds. """
- return convert_to_timedelta(timespan_string).total_seconds()
-
-
-def change_user_tag_expiration(user, tag_expiration_s):
- """ Changes the tag expiration on the given user/org. Note that the specified expiration must
- be within the configured TAG_EXPIRATION_OPTIONS or this method will raise a
- DataModelException.
- """
- allowed_options = [_convert_to_s(o) for o in config.app_config['TAG_EXPIRATION_OPTIONS']]
- if tag_expiration_s not in allowed_options:
- raise DataModelException('Invalid tag expiration option')
-
- user.removed_tag_expiration_s = tag_expiration_s
- user.save()
-
-
-def update_email(user, new_email, auto_verify=False):
- try:
- user.email = new_email
- user.verified = auto_verify
- user.save()
- except IntegrityError:
- raise DataModelException('E-mail address already used')
-
-
-def update_enabled(user, set_enabled):
- user.enabled = set_enabled
- user.save()
-
-
-def create_robot(robot_shortname, parent, description='', unstructured_metadata=None):
- (username_valid, username_issue) = validate_username(robot_shortname)
- if not username_valid:
- raise InvalidRobotException('The name for the robot \'%s\' is invalid: %s' %
- (robot_shortname, username_issue))
-
- username = format_robot_username(parent.username, robot_shortname)
-
- try:
- User.get(User.username == username)
-
- msg = 'Existing robot with name: %s' % username
- logger.info(msg)
- raise InvalidRobotException(msg)
- except User.DoesNotExist:
- pass
-
- service = LoginService.get(name='quayrobot')
- try:
- with db_transaction():
- created = User.create(username=username, email=str(uuid.uuid4()), robot=True)
- token = random_string_generator(length=64)()
- RobotAccountToken.create(robot_account=created, token=token)
- FederatedLogin.create(user=created, service=service, service_ident='robot:%s' % created.id)
- RobotAccountMetadata.create(robot_account=created, description=description[0:255],
- unstructured_json=unstructured_metadata or {})
- return created, token
- except Exception as ex:
- raise DataModelException(ex.message)
-
-
-def get_or_create_robot_metadata(robot):
- defaults = dict(description='', unstructured_json={})
- metadata, _ = RobotAccountMetadata.get_or_create(robot_account=robot, defaults=defaults)
- return metadata
-
-
-def update_robot_metadata(robot, description='', unstructured_json=None):
- """ Updates the description and user-specified unstructured metadata associated
- with a robot account to that specified. """
- metadata = get_or_create_robot_metadata(robot)
- metadata.description = description
- metadata.unstructured_json = unstructured_json or metadata.unstructured_json or {}
- metadata.save()
-
-
-def retrieve_robot_token(robot):
- """ Returns the decrypted token for the given robot. """
- try:
- token = RobotAccountToken.get(robot_account=robot).token.decrypt()
- except RobotAccountToken.DoesNotExist:
- if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
- # For legacy only.
- token = robot.email
- else:
- raise
-
- return token
-
-
-def get_robot_and_metadata(robot_shortname, parent):
- """ Returns a tuple of the robot matching the given shortname, its token, and its metadata. """
- robot_username = format_robot_username(parent.username, robot_shortname)
- robot, metadata = lookup_robot_and_metadata(robot_username)
- token = retrieve_robot_token(robot)
- return robot, token, metadata
-
-
-def lookup_robot(robot_username):
- try:
- return User.get(username=robot_username, robot=True)
- except User.DoesNotExist:
- raise InvalidRobotException('Could not find robot with username: %s' % robot_username)
-
-
-def lookup_robot_and_metadata(robot_username):
- robot = lookup_robot(robot_username)
- return robot, get_or_create_robot_metadata(robot)
-
-
-def get_matching_robots(name_prefix, username, limit=10):
- admined_orgs = (_basequery.get_user_organizations(username)
- .switch(Team)
- .join(TeamRole)
- .where(TeamRole.name == 'admin'))
-
- prefix_checks = False
-
- for org in admined_orgs:
- org_search = prefix_search(User.username, org.username + '+' + name_prefix)
- prefix_checks = prefix_checks | org_search
-
- user_search = prefix_search(User.username, username + '+' + name_prefix)
- prefix_checks = prefix_checks | user_search
-
- return User.select().where(prefix_checks).limit(limit)
-
-
-def verify_robot(robot_username, password):
- try:
- password = remove_unicode(password)
- except UnicodeEncodeError:
- msg = ('Could not find robot with username: %s and supplied password.' %
- robot_username)
- raise InvalidRobotException(msg)
-
- result = parse_robot_username(robot_username)
- if result is None:
- raise InvalidRobotException('%s is an invalid robot name' % robot_username)
-
- robot = lookup_robot(robot_username)
- assert robot.robot
-
- # Lookup the token for the robot.
- try:
- token_data = RobotAccountToken.get(robot_account=robot)
- if not token_data.token.matches(password):
- msg = ('Could not find robot with username: %s and supplied password.' %
- robot_username)
- raise InvalidRobotException(msg)
- except RobotAccountToken.DoesNotExist:
- # TODO(remove-unenc): Remove once migrated.
- if not ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
- raise InvalidRobotException(msg)
-
- if password.find('robot:') >= 0:
- # Just to be sure.
- raise InvalidRobotException(msg)
-
- query = (User
- .select()
- .join(FederatedLogin)
- .join(LoginService)
- .where(FederatedLogin.service_ident == password, LoginService.name == 'quayrobot',
- User.username == robot_username))
-
- try:
- robot = query.get()
- except User.DoesNotExist:
- msg = ('Could not find robot with username: %s and supplied password.' %
- robot_username)
- raise InvalidRobotException(msg)
-
- # Find the owner user and ensure it is not disabled.
- try:
- owner = User.get(User.username == result[0])
- except User.DoesNotExist:
- raise InvalidRobotException('Robot %s owner does not exist' % robot_username)
-
- if not owner.enabled:
- raise InvalidRobotException('This user has been disabled. Please contact your administrator.')
-
- # Mark that the robot was accessed.
- _basequery.update_last_accessed(robot)
-
- return robot
-
-def regenerate_robot_token(robot_shortname, parent):
- robot_username = format_robot_username(parent.username, robot_shortname)
-
- robot, metadata = lookup_robot_and_metadata(robot_username)
- password = random_string_generator(length=64)()
- robot.email = str(uuid4())
- robot.uuid = str(uuid4())
-
- service = LoginService.get(name='quayrobot')
- login = FederatedLogin.get(FederatedLogin.user == robot, FederatedLogin.service == service)
- login.service_ident = 'robot:%s' % (robot.id)
-
- try:
- token_data = RobotAccountToken.get(robot_account=robot)
- except RobotAccountToken.DoesNotExist:
- token_data = RobotAccountToken.create(robot_account=robot)
-
- token_data.token = password
-
- with db_transaction():
- token_data.save()
- login.save()
- robot.save()
-
- return robot, password, metadata
-
-
-def delete_robot(robot_username):
- try:
- robot = User.get(username=robot_username, robot=True)
- robot.delete_instance(recursive=True, delete_nullable=True)
-
- except User.DoesNotExist:
- raise InvalidRobotException('Could not find robot with username: %s' %
- robot_username)
-
-
-def list_namespace_robots(namespace):
- """ Returns all the robots found under the given namespace. """
- return _list_entity_robots(namespace)
-
-
-def _list_entity_robots(entity_name, include_metadata=True, include_token=True):
- """ Return the list of robots for the specified entity. This MUST return a query, not a
- materialized list so that callers can use db_for_update.
- """
- # TODO(remove-unenc): Remove FederatedLogin and LEFT_OUTER on RobotAccountToken once migration
- # is complete.
- if include_metadata or include_token:
- query = (User
- .select(User, RobotAccountToken, FederatedLogin, RobotAccountMetadata)
- .join(FederatedLogin)
- .switch(User)
- .join(RobotAccountMetadata, JOIN.LEFT_OUTER)
- .switch(User)
- .join(RobotAccountToken, JOIN.LEFT_OUTER)
- .where(User.robot == True, User.username ** (entity_name + '+%')))
- else:
- query = (User
- .select(User)
- .where(User.robot == True, User.username ** (entity_name + '+%')))
-
- return query
-
-
-def list_entity_robot_permission_teams(entity_name, limit=None, include_permissions=False):
- query = (_list_entity_robots(entity_name))
-
- # TODO(remove-unenc): Remove FederatedLogin once migration is complete.
- fields = [User.username, User.creation_date, User.last_accessed, RobotAccountToken.token,
- FederatedLogin.service_ident, RobotAccountMetadata.description,
- RobotAccountMetadata.unstructured_json]
- if include_permissions:
- query = (query
- .join(RepositoryPermission, JOIN.LEFT_OUTER,
- on=(RepositoryPermission.user == FederatedLogin.user))
- .join(Repository, JOIN.LEFT_OUTER)
- .switch(User)
- .join(TeamMember, JOIN.LEFT_OUTER)
- .join(Team, JOIN.LEFT_OUTER))
-
- fields.append(Repository.name)
- fields.append(Team.name)
-
- query = query.limit(limit).order_by(User.last_accessed.desc())
- return TupleSelector(query, fields)
-
-
-def update_user_metadata(user, metadata=None):
- """ Updates the metadata associated with the user, including his/her name and company. """
- metadata = metadata if metadata is not None else {}
-
- with db_transaction():
- if 'given_name' in metadata:
- user.given_name = metadata['given_name']
-
- if 'family_name' in metadata:
- user.family_name = metadata['family_name']
-
- if 'company' in metadata:
- user.company = metadata['company']
-
- if 'location' in metadata:
- user.location = metadata['location']
-
- user.save()
-
- # Remove any prompts associated with the user's metadata being needed.
- remove_user_prompt(user, UserPromptTypes.ENTER_NAME)
- remove_user_prompt(user, UserPromptTypes.ENTER_COMPANY)
-
-
-def _get_login_service(service_id):
- try:
- return LoginService.get(LoginService.name == service_id)
- except LoginService.DoesNotExist:
- return LoginService.create(name=service_id)
-
-
-def create_federated_user(username, email, service_id, service_ident,
- set_password_notification, metadata={},
- email_required=True, confirm_username=True,
- prompts=tuple()):
- prompts = set(prompts)
-
- if confirm_username:
- prompts.add(UserPromptTypes.CONFIRM_USERNAME)
-
- new_user = create_user_noverify(username, email, email_required=email_required, prompts=prompts)
- new_user.verified = True
- new_user.save()
-
- FederatedLogin.create(user=new_user, service=_get_login_service(service_id),
- service_ident=service_ident,
- metadata_json=json.dumps(metadata))
-
- if set_password_notification:
- notification.create_notification('password_required', new_user)
-
- return new_user
-
-
-def attach_federated_login(user, service_id, service_ident, metadata=None):
- service = _get_login_service(service_id)
- FederatedLogin.create(user=user, service=service, service_ident=service_ident,
- metadata_json=json.dumps(metadata or {}))
- return user
-
-
-def verify_federated_login(service_id, service_ident):
- try:
- found = (FederatedLogin
- .select(FederatedLogin, User)
- .join(LoginService)
- .switch(FederatedLogin).join(User)
- .where(FederatedLogin.service_ident == service_ident, LoginService.name == service_id)
- .get())
-
- # Mark that the user was accessed.
- _basequery.update_last_accessed(found.user)
-
- return found.user
- except FederatedLogin.DoesNotExist:
- return None
-
-
-def list_federated_logins(user):
- selected = FederatedLogin.select(FederatedLogin.service_ident,
- LoginService.name, FederatedLogin.metadata_json)
- joined = selected.join(LoginService)
- return joined.where(LoginService.name != 'quayrobot',
- FederatedLogin.user == user)
-
-
-def lookup_federated_login(user, service_name):
- try:
- return list_federated_logins(user).where(LoginService.name == service_name).get()
- except FederatedLogin.DoesNotExist:
- return None
-
-
-def create_confirm_email_code(user, new_email=None):
- if new_email:
- if not validate_email(new_email):
- raise InvalidEmailAddressException('Invalid email address: %s' %
- new_email)
-
- verification_code, unhashed = Credential.generate()
- code = EmailConfirmation.create(user=user,
- email_confirm=True,
- new_email=new_email,
- verification_code=verification_code)
- return encode_public_private_token(code.code, unhashed)
-
-
-def confirm_user_email(token):
- # TODO(remove-unenc): Remove allow_public_only once migrated.
- allow_public_only = ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS)
- result = decode_public_private_token(token, allow_public_only=allow_public_only)
- if not result:
- raise DataModelException('Invalid email confirmation code')
-
- try:
- code = EmailConfirmation.get(EmailConfirmation.code == result.public_code,
- EmailConfirmation.email_confirm == True)
- except EmailConfirmation.DoesNotExist:
- raise DataModelException('Invalid email confirmation code')
-
- if result.private_token and not code.verification_code.matches(result.private_token):
- raise DataModelException('Invalid email confirmation code')
-
- user = code.user
- user.verified = True
-
- old_email = None
- new_email = code.new_email
- if new_email and new_email != old_email:
- if find_user_by_email(new_email):
- raise DataModelException('E-mail address already used')
-
- old_email = user.email
- user.email = new_email
-
- with db_transaction():
- user.save()
- code.delete_instance()
-
- return user, new_email, old_email
-
-
-def create_reset_password_email_code(email):
- try:
- user = User.get(User.email == email)
- except User.DoesNotExist:
- raise InvalidEmailAddressException('Email address was not found')
-
- if user.organization:
- raise InvalidEmailAddressException('Organizations can not have passwords')
-
- verification_code, unhashed = Credential.generate()
- code = EmailConfirmation.create(user=user, pw_reset=True, verification_code=verification_code)
- return encode_public_private_token(code.code, unhashed)
-
-
-def validate_reset_code(token):
- # TODO(remove-unenc): Remove allow_public_only once migrated.
- allow_public_only = ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS)
- result = decode_public_private_token(token, allow_public_only=allow_public_only)
- if not result:
- return None
-
- # Find the reset code.
- try:
- code = EmailConfirmation.get(EmailConfirmation.code == result.public_code,
- EmailConfirmation.pw_reset == True)
- except EmailConfirmation.DoesNotExist:
- return None
-
- if result.private_token and not code.verification_code.matches(result.private_token):
- return None
-
- # Make sure the code is not expired.
- max_lifetime_duration = convert_to_timedelta(config.app_config['USER_RECOVERY_TOKEN_LIFETIME'])
- if code.created + max_lifetime_duration < datetime.now():
- code.delete_instance()
- return None
-
- # Verify the user and return the code.
- user = code.user
-
- with db_transaction():
- if not user.verified:
- user.verified = True
- user.save()
-
- code.delete_instance()
-
- return user
-
-
def find_user_by_email(email):
- try:
- return User.get(User.email == email)
- except User.DoesNotExist:
- return None
+ try:
+ return User.get(User.email == email)
+ except User.DoesNotExist:
+ return None
def get_nonrobot_user(username):
- try:
- return User.get(User.username == username, User.organization == False, User.robot == False)
- except User.DoesNotExist:
- return None
+ try:
+ return User.get(
+ User.username == username, User.organization == False, User.robot == False
+ )
+ except User.DoesNotExist:
+ return None
def get_user(username):
- try:
- return User.get(User.username == username, User.organization == False)
- except User.DoesNotExist:
- return None
+ try:
+ return User.get(User.username == username, User.organization == False)
+ except User.DoesNotExist:
+ return None
def get_namespace_user(username):
- try:
- return User.get(User.username == username)
- except User.DoesNotExist:
- return None
+ try:
+ return User.get(User.username == username)
+ except User.DoesNotExist:
+ return None
def get_user_or_org(username):
- try:
- return User.get(User.username == username, User.robot == False)
- except User.DoesNotExist:
- return None
+ try:
+ return User.get(User.username == username, User.robot == False)
+ except User.DoesNotExist:
+ return None
def get_user_by_id(user_db_id):
- try:
- return User.get(User.id == user_db_id, User.organization == False)
- except User.DoesNotExist:
- return None
+ try:
+ return User.get(User.id == user_db_id, User.organization == False)
+ except User.DoesNotExist:
+ return None
def get_user_map_by_ids(namespace_ids):
- id_user = {namespace_id: None for namespace_id in namespace_ids}
- users = User.select().where(User.id << namespace_ids, User.organization == False)
- for user in users:
- id_user[user.id] = user
+ id_user = {namespace_id: None for namespace_id in namespace_ids}
+ users = User.select().where(User.id << namespace_ids, User.organization == False)
+ for user in users:
+ id_user[user.id] = user
+
+ return id_user
- return id_user
def get_namespace_user_by_user_id(namespace_user_db_id):
- try:
- return User.get(User.id == namespace_user_db_id, User.robot == False)
- except User.DoesNotExist:
- raise InvalidUsernameException('User with id does not exist: %s' % namespace_user_db_id)
+ try:
+ return User.get(User.id == namespace_user_db_id, User.robot == False)
+ except User.DoesNotExist:
+ raise InvalidUsernameException(
+ "User with id does not exist: %s" % namespace_user_db_id
+ )
def get_namespace_by_user_id(namespace_user_db_id):
- try:
- return User.get(User.id == namespace_user_db_id, User.robot == False).username
- except User.DoesNotExist:
- raise InvalidUsernameException('User with id does not exist: %s' % namespace_user_db_id)
+ try:
+ return User.get(User.id == namespace_user_db_id, User.robot == False).username
+ except User.DoesNotExist:
+ raise InvalidUsernameException(
+ "User with id does not exist: %s" % namespace_user_db_id
+ )
def get_user_by_uuid(user_uuid):
- try:
- return User.get(User.uuid == user_uuid, User.organization == False)
- except User.DoesNotExist:
- return None
+ try:
+ return User.get(User.uuid == user_uuid, User.organization == False)
+ except User.DoesNotExist:
+ return None
def get_user_or_org_by_customer_id(customer_id):
- try:
- return User.get(User.stripe_id == customer_id)
- except User.DoesNotExist:
- return None
+ try:
+ return User.get(User.stripe_id == customer_id)
+ except User.DoesNotExist:
+ return None
+
def invalidate_all_sessions(user):
- """ Invalidates all existing user sessions by rotating the user's UUID. """
- if not user:
- return
+ """ Invalidates all existing user sessions by rotating the user's UUID. """
+ if not user:
+ return
+
+ user.uuid = str(uuid4())
+ user.save()
- user.uuid = str(uuid4())
- user.save()
def get_matching_user_namespaces(namespace_prefix, username, limit=10):
- namespace_user = get_namespace_user(username)
- namespace_user_id = namespace_user.id if namespace_user is not None else None
+ namespace_user = get_namespace_user(username)
+ namespace_user_id = namespace_user.id if namespace_user is not None else None
- namespace_search = prefix_search(Namespace.username, namespace_prefix)
- base_query = (Namespace
- .select()
- .distinct()
- .join(Repository, on=(Repository.namespace_user == Namespace.id))
- .join(RepositoryPermission, JOIN.LEFT_OUTER)
- .where(namespace_search))
+ namespace_search = prefix_search(Namespace.username, namespace_prefix)
+ base_query = (
+ Namespace.select()
+ .distinct()
+ .join(Repository, on=(Repository.namespace_user == Namespace.id))
+ .join(RepositoryPermission, JOIN.LEFT_OUTER)
+ .where(namespace_search)
+ )
- return _basequery.filter_to_repos_for_user(base_query, namespace_user_id).limit(limit)
+ return _basequery.filter_to_repos_for_user(base_query, namespace_user_id).limit(
+ limit
+ )
-def get_matching_users(username_prefix, robot_namespace=None, organization=None, limit=20,
- exact_matches_only=False):
- # Lookup the exact match first. This ensures that the exact match is not cut off by the list
- # limit.
- updated_limit = limit
- exact_match = list(_get_matching_users(username_prefix, robot_namespace, organization, limit=1,
- exact_matches_only=True))
- if exact_match:
- updated_limit -= 1
- yield exact_match[0]
- # Perform the remainder of the lookup.
- if updated_limit:
- for result in _get_matching_users(username_prefix, robot_namespace, organization, updated_limit,
- exact_matches_only):
- if exact_match and result.username == exact_match[0].username:
- continue
+def get_matching_users(
+ username_prefix,
+ robot_namespace=None,
+ organization=None,
+ limit=20,
+ exact_matches_only=False,
+):
+ # Lookup the exact match first. This ensures that the exact match is not cut off by the list
+ # limit.
+ updated_limit = limit
+ exact_match = list(
+ _get_matching_users(
+ username_prefix,
+ robot_namespace,
+ organization,
+ limit=1,
+ exact_matches_only=True,
+ )
+ )
+ if exact_match:
+ updated_limit -= 1
+ yield exact_match[0]
- yield result
+ # Perform the remainder of the lookup.
+ if updated_limit:
+ for result in _get_matching_users(
+ username_prefix,
+ robot_namespace,
+ organization,
+ updated_limit,
+ exact_matches_only,
+ ):
+ if exact_match and result.username == exact_match[0].username:
+ continue
-def _get_matching_users(username_prefix, robot_namespace=None, organization=None, limit=20,
- exact_matches_only=False):
- user_search = prefix_search(User.username, username_prefix)
- if exact_matches_only:
- user_search = (User.username == username_prefix)
+ yield result
- direct_user_query = (user_search & (User.organization == False) & (User.robot == False))
- if robot_namespace:
- robot_prefix = format_robot_username(robot_namespace, username_prefix)
- robot_search = prefix_search(User.username, robot_prefix)
- direct_user_query = ((robot_search & (User.robot == True)) | direct_user_query)
+def _get_matching_users(
+ username_prefix,
+ robot_namespace=None,
+ organization=None,
+ limit=20,
+ exact_matches_only=False,
+):
+ user_search = prefix_search(User.username, username_prefix)
+ if exact_matches_only:
+ user_search = User.username == username_prefix
- query = (User
- .select(User.id, User.username, User.email, User.robot)
- .group_by(User.id, User.username, User.email, User.robot)
- .where(direct_user_query))
+ direct_user_query = (
+ user_search & (User.organization == False) & (User.robot == False)
+ )
- if organization:
- query = (query
- .select(User.id, User.username, User.email, User.robot, fn.Sum(Team.id))
- .join(TeamMember, JOIN.LEFT_OUTER)
- .join(Team, JOIN.LEFT_OUTER, on=((Team.id == TeamMember.team) &
- (Team.organization == organization)))
- .order_by(User.robot.desc()))
+ if robot_namespace:
+ robot_prefix = format_robot_username(robot_namespace, username_prefix)
+ robot_search = prefix_search(User.username, robot_prefix)
+ direct_user_query = (robot_search & (User.robot == True)) | direct_user_query
- class MatchingUserResult(object):
- def __init__(self, *args):
- self.id = args[0]
- self.username = args[1]
- self.email = args[2]
- self.robot = args[3]
+ query = (
+ User.select(User.id, User.username, User.email, User.robot)
+ .group_by(User.id, User.username, User.email, User.robot)
+ .where(direct_user_query)
+ )
- if organization:
- self.is_org_member = (args[3] != None)
- else:
- self.is_org_member = None
+ if organization:
+ query = (
+ query.select(
+ User.id, User.username, User.email, User.robot, fn.Sum(Team.id)
+ )
+ .join(TeamMember, JOIN.LEFT_OUTER)
+ .join(
+ Team,
+ JOIN.LEFT_OUTER,
+ on=((Team.id == TeamMember.team) & (Team.organization == organization)),
+ )
+ .order_by(User.robot.desc())
+ )
- return (MatchingUserResult(*args) for args in query.tuples().limit(limit))
+ class MatchingUserResult(object):
+ def __init__(self, *args):
+ self.id = args[0]
+ self.username = args[1]
+ self.email = args[2]
+ self.robot = args[3]
+
+ if organization:
+ self.is_org_member = args[3] != None
+ else:
+ self.is_org_member = None
+
+ return (MatchingUserResult(*args) for args in query.tuples().limit(limit))
def verify_user(username_or_email, password):
- """ Verifies that the given username/email + password pair is valid. If the username or e-mail
+ """ Verifies that the given username/email + password pair is valid. If the username or e-mail
address is invalid, returns None. If the password specified does not match for the given user,
either returns None or raises TooManyLoginAttemptsException if there have been too many
invalid login attempts. Returns the user object if the login was valid.
"""
- # Make sure we didn't get any unicode for the username.
- try:
- str(username_or_email)
- except ValueError:
- return None
+ # Make sure we didn't get any unicode for the username.
+ try:
+ str(username_or_email)
+ except ValueError:
+ return None
- # Fetch the user with the matching username or e-mail address.
- try:
- fetched = User.get((User.username == username_or_email) | (User.email == username_or_email))
- except User.DoesNotExist:
- return None
+ # Fetch the user with the matching username or e-mail address.
+ try:
+ fetched = User.get(
+ (User.username == username_or_email) | (User.email == username_or_email)
+ )
+ except User.DoesNotExist:
+ return None
- # If the user has any invalid login attempts, check to see if we are within the exponential
- # backoff window for the user. If so, we raise an exception indicating that the user cannot
- # login.
- now = datetime.utcnow()
- if fetched.invalid_login_attempts > 0:
- can_retry_at = exponential_backoff(fetched.invalid_login_attempts, EXPONENTIAL_BACKOFF_SCALE,
- fetched.last_invalid_login)
-
- if can_retry_at > now:
- retry_after = can_retry_at - now
- raise TooManyLoginAttemptsException('Too many login attempts.', retry_after.total_seconds())
-
- # Hash the given password and compare it to the specified password.
- if (fetched.password_hash and
- hash_password(password, fetched.password_hash) == fetched.password_hash):
-
- # If the user previously had any invalid login attempts, clear them out now.
+ # If the user has any invalid login attempts, check to see if we are within the exponential
+ # backoff window for the user. If so, we raise an exception indicating that the user cannot
+ # log in.
+ now = datetime.utcnow()
if fetched.invalid_login_attempts > 0:
- try:
- (User
- .update(invalid_login_attempts=0)
- .where(User.id == fetched.id)
- .execute())
+ can_retry_at = exponential_backoff(
+ fetched.invalid_login_attempts,
+ EXPONENTIAL_BACKOFF_SCALE,
+ fetched.last_invalid_login,
+ )
- # Mark that the user was accessed.
- _basequery.update_last_accessed(fetched)
- except ReadOnlyModeException:
+ if can_retry_at > now:
+ retry_after = can_retry_at - now
+ raise TooManyLoginAttemptsException(
+ "Too many login attempts.", retry_after.total_seconds()
+ )
+
+ # Hash the given password and compare it to the specified password.
+ if (
+ fetched.password_hash
+ and hash_password(password, fetched.password_hash) == fetched.password_hash
+ ):
+
+ # If the user previously had any invalid login attempts, clear them out now.
+ if fetched.invalid_login_attempts > 0:
+ try:
+ (
+ User.update(invalid_login_attempts=0)
+ .where(User.id == fetched.id)
+ .execute()
+ )
+
+ # Mark that the user was accessed.
+ _basequery.update_last_accessed(fetched)
+ except ReadOnlyModeException:
+ pass
+
+ # Return the valid user.
+ return fetched
+
+ # Otherwise, update the user's invalid login attempts.
+ try:
+ (
+ User.update(
+ invalid_login_attempts=User.invalid_login_attempts + 1,
+ last_invalid_login=now,
+ )
+ .where(User.id == fetched.id)
+ .execute()
+ )
+ except ReadOnlyModeException:
pass
- # Return the valid user.
- return fetched
-
- # Otherwise, update the user's invalid login attempts.
- try:
- (User
- .update(invalid_login_attempts=User.invalid_login_attempts+1, last_invalid_login=now)
- .where(User.id == fetched.id)
- .execute())
- except ReadOnlyModeException:
- pass
-
- # We weren't able to authorize the user
- return None
+ # We weren't able to authorize the user
+ return None
def get_all_repo_users(namespace_name, repository_name):
- return (RepositoryPermission
- .select(User, Role, RepositoryPermission)
- .join(User)
- .switch(RepositoryPermission)
- .join(Role)
- .switch(RepositoryPermission)
- .join(Repository)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .where(Namespace.username == namespace_name, Repository.name == repository_name))
+ return (
+ RepositoryPermission.select(User, Role, RepositoryPermission)
+ .join(User)
+ .switch(RepositoryPermission)
+ .join(Role)
+ .switch(RepositoryPermission)
+ .join(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .where(Namespace.username == namespace_name, Repository.name == repository_name)
+ )
def get_all_repo_users_transitive_via_teams(namespace_name, repository_name):
- return (User
- .select()
- .distinct()
- .join(TeamMember)
- .join(Team)
- .join(RepositoryPermission)
- .join(Repository)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .where(Namespace.username == namespace_name, Repository.name == repository_name))
+ return (
+ User.select()
+ .distinct()
+ .join(TeamMember)
+ .join(Team)
+ .join(RepositoryPermission)
+ .join(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .where(Namespace.username == namespace_name, Repository.name == repository_name)
+ )
def get_all_repo_users_transitive(namespace_name, repository_name):
- # Load the users found via teams and directly via permissions.
- via_teams = get_all_repo_users_transitive_via_teams(namespace_name, repository_name)
- directly = [perm.user for perm in get_all_repo_users(namespace_name, repository_name)]
+ # Load the users found via teams and directly via permissions.
+ via_teams = get_all_repo_users_transitive_via_teams(namespace_name, repository_name)
+ directly = [
+ perm.user for perm in get_all_repo_users(namespace_name, repository_name)
+ ]
- # Filter duplicates.
- user_set = set()
+ # Filter duplicates.
+ user_set = set()
- def check_add(u):
- if u.username in user_set:
- return False
+ def check_add(u):
+ if u.username in user_set:
+ return False
- user_set.add(u.username)
- return True
+ user_set.add(u.username)
+ return True
- return [user for user in list(directly) + list(via_teams) if check_add(user)]
+ return [user for user in list(directly) + list(via_teams) if check_add(user)]
def get_private_repo_count(username):
- return (Repository
- .select()
- .join(Visibility)
- .switch(Repository)
- .join(Namespace, on=(Repository.namespace_user == Namespace.id))
- .where(Namespace.username == username, Visibility.name == 'private')
- .count())
+ return (
+ Repository.select()
+ .join(Visibility)
+ .switch(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .where(Namespace.username == username, Visibility.name == "private")
+ .count()
+ )
def get_active_users(disabled=True, deleted=False):
- query = (User
- .select()
- .where(User.organization == False, User.robot == False))
+ query = User.select().where(User.organization == False, User.robot == False)
- if not disabled:
- query = query.where(User.enabled == True)
+ if not disabled:
+ query = query.where(User.enabled == True)
- if not deleted:
- query = query.where(User.id.not_in(DeletedNamespace.select(DeletedNamespace.namespace)))
+ if not deleted:
+ query = query.where(
+ User.id.not_in(DeletedNamespace.select(DeletedNamespace.namespace))
+ )
- return query
+ return query
def get_active_user_count():
- return get_active_users().count()
+ return get_active_users().count()
def get_robot_count():
- return User.select().where(User.robot == True).count()
+ return User.select().where(User.robot == True).count()
def detach_external_login(user, service_name):
- try:
- service = LoginService.get(name=service_name)
- except LoginService.DoesNotExist:
- return
+ try:
+ service = LoginService.get(name=service_name)
+ except LoginService.DoesNotExist:
+ return
- FederatedLogin.delete().where(FederatedLogin.user == user,
- FederatedLogin.service == service).execute()
+ FederatedLogin.delete().where(
+ FederatedLogin.user == user, FederatedLogin.service == service
+ ).execute()
def get_solely_admined_organizations(user_obj):
- """ Returns the organizations admined solely by the given user. """
- orgs = (User.select()
- .where(User.organization == True)
- .join(Team)
- .join(TeamRole)
- .where(TeamRole.name == 'admin')
- .switch(Team)
- .join(TeamMember)
- .where(TeamMember.user == user_obj)
- .distinct())
+ """ Returns the organizations admined solely by the given user. """
+ orgs = (
+ User.select()
+ .where(User.organization == True)
+ .join(Team)
+ .join(TeamRole)
+ .where(TeamRole.name == "admin")
+ .switch(Team)
+ .join(TeamMember)
+ .where(TeamMember.user == user_obj)
+ .distinct()
+ )
- # Filter to organizations where the user is the sole admin.
- solely_admined = []
- for org in orgs:
- admin_user_count = (TeamMember.select()
- .join(Team)
- .join(TeamRole)
- .where(Team.organization == org, TeamRole.name == 'admin')
- .switch(TeamMember)
- .join(User)
- .where(User.robot == False)
- .distinct()
- .count())
+ # Filter to organizations where the user is the sole admin.
+ solely_admined = []
+ for org in orgs:
+ admin_user_count = (
+ TeamMember.select()
+ .join(Team)
+ .join(TeamRole)
+ .where(Team.organization == org, TeamRole.name == "admin")
+ .switch(TeamMember)
+ .join(User)
+ .where(User.robot == False)
+ .distinct()
+ .count()
+ )
- if admin_user_count == 1:
- solely_admined.append(org)
+ if admin_user_count == 1:
+ solely_admined.append(org)
- return solely_admined
+ return solely_admined
def mark_namespace_for_deletion(user, queues, namespace_gc_queue, force=False):
- """ Marks a namespace (as referenced by the given user) for deletion. A queue item will be added
+ """ Marks a namespace (as referenced by the given user) for deletion. A queue item will be added
to delete the namespace's repositories and storage, while the namespace itself will be
renamed, disabled, and delinked from other tables.
"""
- if not user.enabled:
- return None
+ if not user.enabled:
+ return None
- if not force and not user.organization:
- # Ensure that the user is not the sole admin for any organizations. If so, then the user
- # cannot be deleted before those organizations are deleted or reassigned.
- organizations = get_solely_admined_organizations(user)
- if len(organizations) > 0:
- message = 'Cannot delete %s as you are the only admin for organizations: ' % user.username
- for index, org in enumerate(organizations):
- if index > 0:
- message = message + ', '
+ if not force and not user.organization:
+ # Ensure that the user is not the sole admin for any organizations. If so, then the user
+ # cannot be deleted before those organizations are deleted or reassigned.
+ organizations = get_solely_admined_organizations(user)
+ if len(organizations) > 0:
+ message = (
+ "Cannot delete %s as you are the only admin for organizations: "
+ % user.username
+ )
+ for index, org in enumerate(organizations):
+ if index > 0:
+ message = message + ", "
- message = message + org.username
+ message = message + org.username
- raise DataModelException(message)
+ raise DataModelException(message)
- # Delete all queue items for the user.
- for queue in queues:
- queue.delete_namespaced_items(user.username)
+ # Delete all queue items for the user.
+ for queue in queues:
+ queue.delete_namespaced_items(user.username)
- # Delete non-repository related items. This operation is very quick, so we can do so here.
- _delete_user_linked_data(user)
+ # Delete non-repository related items. This operation is very quick, so we can do so here.
+ _delete_user_linked_data(user)
- with db_transaction():
- original_username = user.username
- user = db_for_update(User.select().where(User.id == user.id)).get()
+ with db_transaction():
+ original_username = user.username
+ user = db_for_update(User.select().where(User.id == user.id)).get()
- # Mark the namespace as deleted and ready for GC.
- try:
- marker = DeletedNamespace.create(namespace=user,
- original_username=original_username,
- original_email=user.email)
- except IntegrityError:
- return
-
- # Disable the namespace itself, and replace its various unique fields with UUIDs.
- user.enabled = False
- user.username = str(uuid4())
- user.email = str(uuid4())
- user.save()
+ # Mark the namespace as deleted and ready for GC.
+ try:
+ marker = DeletedNamespace.create(
+ namespace=user,
+ original_username=original_username,
+ original_email=user.email,
+ )
+ except IntegrityError:
+ return
- # Add a queueitem to delete the namespace.
- marker.queue_id = namespace_gc_queue.put([str(user.id)], json.dumps({
- 'marker_id': marker.id,
- 'original_username': original_username,
- }))
- marker.save()
- return marker.id
+ # Disable the namespace itself, and replace its various unique fields with UUIDs.
+ user.enabled = False
+ user.username = str(uuid4())
+ user.email = str(uuid4())
+ user.save()
+
+ # Add a queueitem to delete the namespace.
+ marker.queue_id = namespace_gc_queue.put(
+ [str(user.id)],
+ json.dumps({"marker_id": marker.id, "original_username": original_username}),
+ )
+ marker.save()
+ return marker.id
def delete_namespace_via_marker(marker_id, queues):
- """ Deletes a namespace referenced by the given DeletedNamespace marker ID. """
- try:
- marker = DeletedNamespace.get(id=marker_id)
- except DeletedNamespace.DoesNotExist:
- return
+ """ Deletes a namespace referenced by the given DeletedNamespace marker ID. """
+ try:
+ marker = DeletedNamespace.get(id=marker_id)
+ except DeletedNamespace.DoesNotExist:
+ return
- delete_user(marker.namespace, queues)
+ delete_user(marker.namespace, queues)
def delete_user(user, queues):
- """ Deletes a user/organization/robot. Should *not* be called by any user-facing API. Instead,
+ """ Deletes a user/organization/robot. Should *not* be called by any user-facing API. Instead,
mark_namespace_for_deletion should be used, and the queue should call this method.
"""
- # Delete all queue items for the user.
- for queue in queues:
- queue.delete_namespaced_items(user.username)
+ # Delete all queue items for the user.
+ for queue in queues:
+ queue.delete_namespaced_items(user.username)
- # Delete any repositories under the user's namespace.
- for repo in list(Repository.select().where(Repository.namespace_user == user)):
- gc.purge_repository(user.username, repo.name)
+ # Delete any repositories under the user's namespace.
+ for repo in list(Repository.select().where(Repository.namespace_user == user)):
+ gc.purge_repository(user.username, repo.name)
- # Delete non-repository related items.
- _delete_user_linked_data(user)
+ # Delete non-repository related items.
+ _delete_user_linked_data(user)
- # Delete the user itself.
- user.delete_instance(recursive=True, delete_nullable=True)
+ # Delete the user itself.
+ user.delete_instance(recursive=True, delete_nullable=True)
def _delete_user_linked_data(user):
- if user.organization:
- # Delete the organization's teams.
+ if user.organization:
+ # Delete the organization's teams.
+ with db_transaction():
+ for team in Team.select().where(Team.organization == user):
+ team.delete_instance(recursive=True)
+
+ # Delete any OAuth approvals and tokens associated with the user.
+ with db_transaction():
+ for app in OAuthApplication.select().where(
+ OAuthApplication.organization == user
+ ):
+ app.delete_instance(recursive=True)
+ else:
+ # Remove the user from any teams in which they are a member.
+ TeamMember.delete().where(TeamMember.user == user).execute()
+
+ # Delete any repository buildtriggers where the user is the connected user.
with db_transaction():
- for team in Team.select().where(Team.organization == user):
- team.delete_instance(recursive=True)
+ triggers = RepositoryBuildTrigger.select().where(
+ RepositoryBuildTrigger.connected_user == user
+ )
+ for trigger in triggers:
+ trigger.delete_instance(recursive=True, delete_nullable=False)
- # Delete any OAuth approvals and tokens associated with the user.
+ # Delete any mirrors with robots owned by this user.
with db_transaction():
- for app in OAuthApplication.select().where(OAuthApplication.organization == user):
- app.delete_instance(recursive=True)
- else:
- # Remove the user from any teams in which they are a member.
- TeamMember.delete().where(TeamMember.user == user).execute()
+ robots = list(list_namespace_robots(user.username))
+ RepoMirrorConfig.delete().where(
+ RepoMirrorConfig.internal_robot << robots
+ ).execute()
- # Delete any repository buildtriggers where the user is the connected user.
- with db_transaction():
- triggers = RepositoryBuildTrigger.select().where(RepositoryBuildTrigger.connected_user == user)
- for trigger in triggers:
- trigger.delete_instance(recursive=True, delete_nullable=False)
+ # Delete any robots owned by this user.
+ with db_transaction():
+ robots = list(list_namespace_robots(user.username))
+ for robot in robots:
+ robot.delete_instance(recursive=True, delete_nullable=True)
- # Delete any mirrors with robots owned by this user.
- with db_transaction():
- robots = list(list_namespace_robots(user.username))
- RepoMirrorConfig.delete().where(RepoMirrorConfig.internal_robot << robots).execute()
-
- # Delete any robots owned by this user.
- with db_transaction():
- robots = list(list_namespace_robots(user.username))
- for robot in robots:
- robot.delete_instance(recursive=True, delete_nullable=True)
-
- # Null out any service key approvals. We technically lose information here, but its better than
- # falling and only occurs if a superuser is being deleted.
- ServiceKeyApproval.update(approver=None).where(ServiceKeyApproval.approver == user).execute()
+ # Null out any service key approvals. We technically lose information here, but it's better than
+ # failing and only occurs if a superuser is being deleted.
+ ServiceKeyApproval.update(approver=None).where(
+ ServiceKeyApproval.approver == user
+ ).execute()
def get_pull_credentials(robotname):
- """ Returns the pull credentials for a robot with the given name. """
- try:
- robot = lookup_robot(robotname)
- except InvalidRobotException:
- return None
+ """ Returns the pull credentials for a robot with the given name. """
+ try:
+ robot = lookup_robot(robotname)
+ except InvalidRobotException:
+ return None
- token = retrieve_robot_token(robot)
+ token = retrieve_robot_token(robot)
+
+ return {
+ "username": robot.username,
+ "password": token,
+ "registry": "%s://%s/v1/"
+ % (
+ config.app_config["PREFERRED_URL_SCHEME"],
+ config.app_config["SERVER_HOSTNAME"],
+ ),
+ }
- return {
- 'username': robot.username,
- 'password': token,
- 'registry': '%s://%s/v1/' % (config.app_config['PREFERRED_URL_SCHEME'],
- config.app_config['SERVER_HOSTNAME']),
- }
def get_region_locations(user):
- """ Returns the locations defined as preferred storage for the given user. """
- query = UserRegion.select().join(ImageStorageLocation).where(UserRegion.user == user)
- return set([region.location.name for region in query])
+ """ Returns the locations defined as preferred storage for the given user. """
+ query = (
+ UserRegion.select().join(ImageStorageLocation).where(UserRegion.user == user)
+ )
+ return set([region.location.name for region in query])
+
def get_federated_logins(user_ids, service_name):
- """ Returns all federated logins for the given user ids under the given external service. """
- if not user_ids:
- return []
+ """ Returns all federated logins for the given user ids under the given external service. """
+ if not user_ids:
+ return []
- return (FederatedLogin
- .select()
- .join(User)
- .switch(FederatedLogin)
- .join(LoginService)
- .where(FederatedLogin.user << user_ids,
- LoginService.name == service_name))
+ return (
+ FederatedLogin.select()
+ .join(User)
+ .switch(FederatedLogin)
+ .join(LoginService)
+ .where(FederatedLogin.user << user_ids, LoginService.name == service_name)
+ )
def list_namespace_geo_restrictions(namespace_name):
- """ Returns all of the defined geographic restrictions for the given namespace. """
- return (NamespaceGeoRestriction
- .select()
- .join(User)
- .where(User.username == namespace_name))
+ """ Returns all of the defined geographic restrictions for the given namespace. """
+ return (
+ NamespaceGeoRestriction.select()
+ .join(User)
+ .where(User.username == namespace_name)
+ )
def get_minimum_user_id():
- return User.select(fn.Min(User.id)).tuples().get()[0]
+ return User.select(fn.Min(User.id)).tuples().get()[0]
class LoginWrappedDBUser(UserMixin):
- def __init__(self, user_uuid, db_user=None):
- self._uuid = user_uuid
- self._db_user = db_user
+ def __init__(self, user_uuid, db_user=None):
+ self._uuid = user_uuid
+ self._db_user = db_user
- def db_user(self):
- if not self._db_user:
- self._db_user = get_user_by_uuid(self._uuid)
- return self._db_user
+ def db_user(self):
+ if not self._db_user:
+ self._db_user = get_user_by_uuid(self._uuid)
+ return self._db_user
- @property
- def is_authenticated(self):
- return self.db_user() is not None
+ @property
+ def is_authenticated(self):
+ return self.db_user() is not None
- @property
- def is_active(self):
- return self.db_user() and self.db_user().verified
+ @property
+ def is_active(self):
+ return self.db_user() and self.db_user().verified
- def get_id(self):
- return unicode(self._uuid)
+ def get_id(self):
+ return unicode(self._uuid)
diff --git a/data/queue.py b/data/queue.py
index 289f4ad64..0a42a8320 100644
--- a/data/queue.py
+++ b/data/queue.py
@@ -12,367 +12,427 @@ DEFAULT_BATCH_SIZE = 1000
class BuildMetricQueueReporter(object):
- """ Metric queue reporter for the build system. """
- def __init__(self, metric_queue):
- self._metric_queue = metric_queue
+ """ Metric queue reporter for the build system. """
- def __call__(self, currently_processing, running_count, total_count):
- need_capacity_count = total_count - running_count
- self._metric_queue.put_deprecated('BuildCapacityShortage', need_capacity_count, unit='Count')
- self._metric_queue.build_capacity_shortage.Set(need_capacity_count)
+ def __init__(self, metric_queue):
+ self._metric_queue = metric_queue
- building_percent = 100 if currently_processing else 0
- self._metric_queue.percent_building.Set(building_percent)
+ def __call__(self, currently_processing, running_count, total_count):
+ need_capacity_count = total_count - running_count
+ self._metric_queue.put_deprecated(
+ "BuildCapacityShortage", need_capacity_count, unit="Count"
+ )
+ self._metric_queue.build_capacity_shortage.Set(need_capacity_count)
+
+ building_percent = 100 if currently_processing else 0
+ self._metric_queue.percent_building.Set(building_percent)
class WorkQueue(object):
- """ Work queue defines methods for interacting with a queue backed by the database. """
- def __init__(self, queue_name, transaction_factory,
- canonical_name_match_list=None, reporter=None, metric_queue=None,
- has_namespace=False):
- self._queue_name = queue_name
- self._reporter = reporter
- self._metric_queue = metric_queue
- self._transaction_factory = transaction_factory
- self._currently_processing = False
- self._has_namespaced_items = has_namespace
+ """ Work queue defines methods for interacting with a queue backed by the database. """
- if canonical_name_match_list is None:
- self._canonical_name_match_list = []
- else:
- self._canonical_name_match_list = canonical_name_match_list
+ def __init__(
+ self,
+ queue_name,
+ transaction_factory,
+ canonical_name_match_list=None,
+ reporter=None,
+ metric_queue=None,
+ has_namespace=False,
+ ):
+ self._queue_name = queue_name
+ self._reporter = reporter
+ self._metric_queue = metric_queue
+ self._transaction_factory = transaction_factory
+ self._currently_processing = False
+ self._has_namespaced_items = has_namespace
- @staticmethod
- def _canonical_name(name_list):
- return '/'.join(name_list) + '/'
+ if canonical_name_match_list is None:
+ self._canonical_name_match_list = []
+ else:
+ self._canonical_name_match_list = canonical_name_match_list
- @classmethod
- def _running_jobs(cls, now, name_match_query):
- return (cls
- ._running_jobs_where(QueueItem.select(QueueItem.queue_name), now)
- .where(QueueItem.queue_name ** name_match_query))
+ @staticmethod
+ def _canonical_name(name_list):
+ return "/".join(name_list) + "/"
- @classmethod
- def _available_jobs(cls, now, name_match_query):
- return (cls
- ._available_jobs_where(QueueItem.select(), now)
- .where(QueueItem.queue_name ** name_match_query))
+ @classmethod
+ def _running_jobs(cls, now, name_match_query):
+ return cls._running_jobs_where(
+ QueueItem.select(QueueItem.queue_name), now
+ ).where(QueueItem.queue_name ** name_match_query)
- @staticmethod
- def _running_jobs_where(query, now):
- return query.where(QueueItem.available == False, QueueItem.processing_expires > now)
+ @classmethod
+ def _available_jobs(cls, now, name_match_query):
+ return cls._available_jobs_where(QueueItem.select(), now).where(
+ QueueItem.queue_name ** name_match_query
+ )
- @staticmethod
- def _available_jobs_where(query, now):
- return query.where(QueueItem.available_after <= now,
- ((QueueItem.available == True) | (QueueItem.processing_expires <= now)),
- QueueItem.retries_remaining > 0)
+ @staticmethod
+ def _running_jobs_where(query, now):
+ return query.where(
+ QueueItem.available == False, QueueItem.processing_expires > now
+ )
- @classmethod
- def _available_jobs_not_running(cls, now, name_match_query, running_query):
- return (cls
- ._available_jobs(now, name_match_query)
- .where(~(QueueItem.queue_name << running_query)))
+ @staticmethod
+ def _available_jobs_where(query, now):
+ return query.where(
+ QueueItem.available_after <= now,
+ ((QueueItem.available == True) | (QueueItem.processing_expires <= now)),
+ QueueItem.retries_remaining > 0,
+ )
- def num_alive_jobs(self, canonical_name_list):
- """
+ @classmethod
+ def _available_jobs_not_running(cls, now, name_match_query, running_query):
+ return cls._available_jobs(now, name_match_query).where(
+ ~(QueueItem.queue_name << running_query)
+ )
+
+ def num_alive_jobs(self, canonical_name_list):
+ """
Returns the number of alive queue items with a given prefix.
"""
- def strip_slash(name):
- return name.lstrip('/')
- canonical_name_list = map(strip_slash, canonical_name_list)
- canonical_name_query = '/'.join([self._queue_name] + canonical_name_list) + '%'
- return (QueueItem
- .select()
+ def strip_slash(name):
+ return name.lstrip("/")
+
+ canonical_name_list = map(strip_slash, canonical_name_list)
+ canonical_name_query = "/".join([self._queue_name] + canonical_name_list) + "%"
+
+ return (
+ QueueItem.select()
.where(QueueItem.queue_name ** canonical_name_query)
.where(QueueItem.retries_remaining > 0)
- .count())
+ .count()
+ )
- def num_available_jobs_between(self, available_min_time, available_max_time, canonical_name_list):
- """
+ def num_available_jobs_between(
+ self, available_min_time, available_max_time, canonical_name_list
+ ):
+ """
Returns the number of available queue items with a given prefix, between the two provided times.
"""
- def strip_slash(name):
- return name.lstrip('/')
- canonical_name_list = map(strip_slash, canonical_name_list)
- available = self._available_jobs(available_max_time,
- '/'.join([self._queue_name] + canonical_name_list) + '%')
+ def strip_slash(name):
+ return name.lstrip("/")
- return available.where(QueueItem.available_after >= available_min_time).count()
+ canonical_name_list = map(strip_slash, canonical_name_list)
- def _name_match_query(self):
- return '%s%%' % self._canonical_name([self._queue_name] + self._canonical_name_match_list)
+ available = self._available_jobs(
+ available_max_time, "/".join([self._queue_name] + canonical_name_list) + "%"
+ )
- @staticmethod
- def _item_by_id_for_update(queue_id):
- return db_for_update(QueueItem.select().where(QueueItem.id == queue_id)).get()
+ return available.where(QueueItem.available_after >= available_min_time).count()
- def get_metrics(self):
- now = datetime.utcnow()
- name_match_query = self._name_match_query()
+ def _name_match_query(self):
+ return "%s%%" % self._canonical_name(
+ [self._queue_name] + self._canonical_name_match_list
+ )
- running_query = self._running_jobs(now, name_match_query)
- running_count = running_query.distinct().count()
+ @staticmethod
+ def _item_by_id_for_update(queue_id):
+ return db_for_update(QueueItem.select().where(QueueItem.id == queue_id)).get()
- available_query = self._available_jobs(now, name_match_query)
- available_count = available_query.select(QueueItem.queue_name).distinct().count()
+ def get_metrics(self):
+ now = datetime.utcnow()
+ name_match_query = self._name_match_query()
- available_not_running_query = self._available_jobs_not_running(now, name_match_query,
- running_query)
- available_not_running_count = (available_not_running_query
- .select(QueueItem.queue_name)
- .distinct()
- .count())
+ running_query = self._running_jobs(now, name_match_query)
+ running_count = running_query.distinct().count()
- return (running_count, available_not_running_count, available_count)
+ available_query = self._available_jobs(now, name_match_query)
+ available_count = (
+ available_query.select(QueueItem.queue_name).distinct().count()
+ )
- def update_metrics(self):
- if self._reporter is None and self._metric_queue is None:
- return
+ available_not_running_query = self._available_jobs_not_running(
+ now, name_match_query, running_query
+ )
+ available_not_running_count = (
+ available_not_running_query.select(QueueItem.queue_name).distinct().count()
+ )
- (running_count, available_not_running_count, available_count) = self.get_metrics()
+ return (running_count, available_not_running_count, available_count)
- if self._metric_queue:
- self._metric_queue.work_queue_running.Set(running_count, labelvalues=[self._queue_name])
- self._metric_queue.work_queue_available.Set(available_count, labelvalues=[self._queue_name])
- self._metric_queue.work_queue_available_not_running.Set(available_not_running_count,
- labelvalues=[self._queue_name])
+ def update_metrics(self):
+ if self._reporter is None and self._metric_queue is None:
+ return
+ (
+ running_count,
+ available_not_running_count,
+ available_count,
+ ) = self.get_metrics()
- if self._reporter:
- self._reporter(self._currently_processing, running_count,
- running_count + available_not_running_count)
+ if self._metric_queue:
+ self._metric_queue.work_queue_running.Set(
+ running_count, labelvalues=[self._queue_name]
+ )
+ self._metric_queue.work_queue_available.Set(
+ available_count, labelvalues=[self._queue_name]
+ )
+ self._metric_queue.work_queue_available_not_running.Set(
+ available_not_running_count, labelvalues=[self._queue_name]
+ )
- def has_retries_remaining(self, item_id):
- """ Returns whether the queue item with the given id has any retries remaining. If the
+ if self._reporter:
+ self._reporter(
+ self._currently_processing,
+ running_count,
+ running_count + available_not_running_count,
+ )
+
+ def has_retries_remaining(self, item_id):
+ """ Returns whether the queue item with the given id has any retries remaining. If the
queue item does not exist, returns False. """
- with self._transaction_factory(db):
- try:
- return QueueItem.get(id=item_id).retries_remaining > 0
- except QueueItem.DoesNotExist:
- return False
+ with self._transaction_factory(db):
+ try:
+ return QueueItem.get(id=item_id).retries_remaining > 0
+ except QueueItem.DoesNotExist:
+ return False
- def delete_namespaced_items(self, namespace, subpath=None):
- """ Deletes all items in this queue that exist under the given namespace. """
- if not self._has_namespaced_items:
- return False
+ def delete_namespaced_items(self, namespace, subpath=None):
+ """ Deletes all items in this queue that exist under the given namespace. """
+ if not self._has_namespaced_items:
+ return False
- subpath_query = '%s/' % subpath if subpath else ''
- queue_prefix = '%s/%s/%s%%' % (self._queue_name, namespace, subpath_query)
- return QueueItem.delete().where(QueueItem.queue_name ** queue_prefix).execute()
+ subpath_query = "%s/" % subpath if subpath else ""
+ queue_prefix = "%s/%s/%s%%" % (self._queue_name, namespace, subpath_query)
+ return QueueItem.delete().where(QueueItem.queue_name ** queue_prefix).execute()
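
# --- Illustrative sketch (not part of the diff): how canonical queue names compose.
# "examplequeue" and "acme" are hypothetical; real queue names are configured elsewhere.
from data.queue import WorkQueue

assert WorkQueue._canonical_name(["examplequeue", "acme", "repo"]) == "examplequeue/acme/repo/"
# delete_namespaced_items("acme") on a queue named "examplequeue" then deletes with the
# SQL LIKE pattern "examplequeue/acme/%", i.e. every item scoped under that namespace.
# --- end sketch
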
- def alive(self, canonical_name_list):
- """
+ def alive(self, canonical_name_list):
+ """
Returns True if a job matching the canonical name list is currently processing
or available.
"""
- canonical_name = self._canonical_name([self._queue_name] + canonical_name_list)
- try:
- select_query = QueueItem.select().where(QueueItem.queue_name == canonical_name)
- now = datetime.utcnow()
+ canonical_name = self._canonical_name([self._queue_name] + canonical_name_list)
+ try:
+ select_query = QueueItem.select().where(
+ QueueItem.queue_name == canonical_name
+ )
+ now = datetime.utcnow()
- overall_query = (self._available_jobs_where(select_query.clone(), now) |
- self._running_jobs_where(select_query.clone(), now))
- overall_query.get()
- return True
- except QueueItem.DoesNotExist:
- return False
+ overall_query = self._available_jobs_where(
+ select_query.clone(), now
+ ) | self._running_jobs_where(select_query.clone(), now)
+ overall_query.get()
+ return True
+ except QueueItem.DoesNotExist:
+ return False
- def _queue_dict(self, canonical_name_list, message, available_after, retries_remaining):
- return dict(
- queue_name=self._canonical_name([self._queue_name] + canonical_name_list),
- body=message,
- retries_remaining=retries_remaining,
- available_after=datetime.utcnow() + timedelta(seconds=available_after or 0),
- )
+ def _queue_dict(
+ self, canonical_name_list, message, available_after, retries_remaining
+ ):
+ return dict(
+ queue_name=self._canonical_name([self._queue_name] + canonical_name_list),
+ body=message,
+ retries_remaining=retries_remaining,
+ available_after=datetime.utcnow() + timedelta(seconds=available_after or 0),
+ )
- @contextmanager
- def batch_insert(self, batch_size=DEFAULT_BATCH_SIZE):
- items_to_insert = []
- def batch_put(canonical_name_list, message, available_after=0, retries_remaining=5):
- """
+ @contextmanager
+ def batch_insert(self, batch_size=DEFAULT_BATCH_SIZE):
+ items_to_insert = []
+
+ def batch_put(
+ canonical_name_list, message, available_after=0, retries_remaining=5
+ ):
+ """
Put an item, if it shouldn't be processed for some number of seconds,
specify that amount as available_after. Returns the ID of the queue item added.
"""
- items_to_insert.append(self._queue_dict(canonical_name_list, message, available_after,
- retries_remaining))
+ items_to_insert.append(
+ self._queue_dict(
+ canonical_name_list, message, available_after, retries_remaining
+ )
+ )
- yield batch_put
+ yield batch_put
- # Chunk the inserted items into batch_size chunks and insert_many
- remaining = list(items_to_insert)
- while remaining:
- QueueItem.insert_many(remaining[0:batch_size]).execute()
- remaining = remaining[batch_size:]
+ # Chunk the inserted items into batch_size chunks and insert_many
+ remaining = list(items_to_insert)
+ while remaining:
+ QueueItem.insert_many(remaining[0:batch_size]).execute()
+ remaining = remaining[batch_size:]
- def put(self, canonical_name_list, message, available_after=0, retries_remaining=5):
- """
+ def put(self, canonical_name_list, message, available_after=0, retries_remaining=5):
+ """
Put an item, if it shouldn't be processed for some number of seconds,
specify that amount as available_after. Returns the ID of the queue item added.
"""
- item = QueueItem.create(**self._queue_dict(canonical_name_list, message, available_after,
- retries_remaining))
- return str(item.id)
+ item = QueueItem.create(
+ **self._queue_dict(
+ canonical_name_list, message, available_after, retries_remaining
+ )
+ )
+ return str(item.id)
- def _select_available_item(self, ordering_required, now):
- """ Selects an available queue item from the queue table and returns it, if any. If none,
+ def _select_available_item(self, ordering_required, now):
+ """ Selects an available queue item from the queue table and returns it, if any. If none,
return None.
"""
- name_match_query = self._name_match_query()
+ name_match_query = self._name_match_query()
- try:
- if ordering_required:
- # The previous solution to this used a select for update in a
- # transaction to prevent multiple instances from processing the
- # same queue item. This suffered performance problems. This solution
- # instead has instances attempt to update the potential queue item to be
- # unavailable. However, since their update clause is restricted to items
- # that are available=False, only one instance's update will succeed, and
- # it will have a changed row count of 1. Instances that have 0 changed
- # rows know that another instance is already handling that item.
- running = self._running_jobs(now, name_match_query)
- avail = self._available_jobs_not_running(now, name_match_query, running)
- return avail.order_by(QueueItem.id).get()
- else:
- # If we don't require ordering, we grab a random item from any of the first 50 available.
- subquery = self._available_jobs(now, name_match_query).limit(50).alias('j1')
- return (QueueItem
- .select()
- .join(subquery, on=QueueItem.id == subquery.c.id)
- .order_by(db_random_func())
- .get())
+ try:
+ if ordering_required:
+ # The previous solution to this used a select for update in a
+ # transaction to prevent multiple instances from processing the
+ # same queue item. This suffered performance problems. This solution
+ # instead has instances attempt to update the potential queue item to be
+ # unavailable. However, since their update clause is restricted to items
+ # that are available=False, only one instance's update will succeed, and
+ # it will have a changed row count of 1. Instances that have 0 changed
+ # rows know that another instance is already handling that item.
+ running = self._running_jobs(now, name_match_query)
+ avail = self._available_jobs_not_running(now, name_match_query, running)
+ return avail.order_by(QueueItem.id).get()
+ else:
+ # If we don't require ordering, we grab a random item from any of the first 50 available.
+ subquery = (
+ self._available_jobs(now, name_match_query).limit(50).alias("j1")
+ )
+ return (
+ QueueItem.select()
+ .join(subquery, on=QueueItem.id == subquery.c.id)
+ .order_by(db_random_func())
+ .get()
+ )
- except QueueItem.DoesNotExist:
- # No available queue item was found.
- return None
+ except QueueItem.DoesNotExist:
+ # No available queue item was found.
+ return None
- def _attempt_to_claim_item(self, db_item, now, processing_time):
- """ Attempts to claim the specified queue item for this instance. Returns True on success and
+ def _attempt_to_claim_item(self, db_item, now, processing_time):
+ """ Attempts to claim the specified queue item for this instance. Returns True on success and
False on failure.
Note that the underlying QueueItem row in the database will be changed on success, but
the db_item object given as a parameter will *not* have its fields updated.
"""
- # Try to claim the item. We do so by updating the item's information only if its current
- # state ID matches that returned in the previous query. Since all updates to the QueueItem
- # must change the state ID, this is guarenteed to only succeed if the item has not yet been
- # claimed by another caller.
- #
- # Note that we use this method because InnoDB takes locks on *every* clause in the WHERE when
- # performing the update. Previously, we would check all these columns, resulting in a bunch
- # of lock contention. This change mitigates the problem significantly by only checking two
- # columns (id and state_id), both of which should be absolutely unique at all times.
- set_unavailable_query = (QueueItem
- .update(available=False,
- processing_expires=now + timedelta(seconds=processing_time),
- retries_remaining=QueueItem.retries_remaining - 1,
- state_id=str(uuid.uuid4()))
- .where(QueueItem.id == db_item.id,
- QueueItem.state_id == db_item.state_id))
+ # Try to claim the item. We do so by updating the item's information only if its current
+ # state ID matches that returned in the previous query. Since all updates to the QueueItem
+ # must change the state ID, this is guaranteed to only succeed if the item has not yet been
+ # claimed by another caller.
+ #
+ # Note that we use this method because InnoDB takes locks on *every* clause in the WHERE when
+ # performing the update. Previously, we would check all these columns, resulting in a bunch
+ # of lock contention. This change mitigates the problem significantly by only checking two
+ # columns (id and state_id), both of which should be absolutely unique at all times.
+ set_unavailable_query = QueueItem.update(
+ available=False,
+ processing_expires=now + timedelta(seconds=processing_time),
+ retries_remaining=QueueItem.retries_remaining - 1,
+ state_id=str(uuid.uuid4()),
+ ).where(QueueItem.id == db_item.id, QueueItem.state_id == db_item.state_id)
- changed = set_unavailable_query.execute()
- return changed == 1
+ changed = set_unavailable_query.execute()
+ return changed == 1
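
# --- Illustrative sketch (not part of the diff): the claim above is an optimistic
# compare-and-swap expressed as a guarded UPDATE. In plain SQL the same idea looks
# roughly like the statement below; the table/column names follow the QueueItem model,
# but the exact SQL peewee emits may differ.
CLAIM_SQL = """
UPDATE queueitem
   SET available = FALSE,
       processing_expires = %(expires)s,
       retries_remaining = retries_remaining - 1,
       state_id = %(new_state_id)s
 WHERE id = %(item_id)s AND state_id = %(old_state_id)s
"""
# A rowcount of 1 means this worker won the claim; 0 means another worker changed
# state_id first, so the item is skipped.
# --- end sketch
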
-
- def get(self, processing_time=300, ordering_required=False):
- """
+ def get(self, processing_time=300, ordering_required=False):
+ """
Get an available item and mark it as unavailable for the default of five
minutes. The result of this method must always be composed of simple
python objects which are JSON serializable for network portability reasons.
"""
- now = datetime.utcnow()
+ now = datetime.utcnow()
- # Select an available queue item.
- db_item = self._select_available_item(ordering_required, now)
- if db_item is None:
- self._currently_processing = False
- return None
+ # Select an available queue item.
+ db_item = self._select_available_item(ordering_required, now)
+ if db_item is None:
+ self._currently_processing = False
+ return None
- # Attempt to claim the item for this instance.
- was_claimed = self._attempt_to_claim_item(db_item, now, processing_time)
- if not was_claimed:
- self._currently_processing = False
- return None
+ # Attempt to claim the item for this instance.
+ was_claimed = self._attempt_to_claim_item(db_item, now, processing_time)
+ if not was_claimed:
+ self._currently_processing = False
+ return None
- self._currently_processing = True
+ self._currently_processing = True
- # Return a view of the queue item rather than an active db object
- return AttrDict({
- 'id': db_item.id,
- 'body': db_item.body,
- 'retries_remaining': db_item.retries_remaining - 1,
- })
+ # Return a view of the queue item rather than an active db object
+ return AttrDict(
+ {
+ "id": db_item.id,
+ "body": db_item.body,
+ "retries_remaining": db_item.retries_remaining - 1,
+ }
+ )
- def cancel(self, item_id):
- """ Attempts to cancel the queue item with the given ID from the queue. Returns true on success
+ def cancel(self, item_id):
+ """ Attempts to cancel the queue item with the given ID from the queue. Returns true on success
and false if the queue item could not be canceled.
"""
- count_removed = QueueItem.delete().where(QueueItem.id == item_id).execute()
- return count_removed > 0
+ count_removed = QueueItem.delete().where(QueueItem.id == item_id).execute()
+ return count_removed > 0
- def complete(self, completed_item):
- self._currently_processing = not self.cancel(completed_item.id)
+ def complete(self, completed_item):
+ self._currently_processing = not self.cancel(completed_item.id)
- def incomplete(self, incomplete_item, retry_after=300, restore_retry=False):
- with self._transaction_factory(db):
- retry_date = datetime.utcnow() + timedelta(seconds=retry_after)
+ def incomplete(self, incomplete_item, retry_after=300, restore_retry=False):
+ with self._transaction_factory(db):
+ retry_date = datetime.utcnow() + timedelta(seconds=retry_after)
- try:
- incomplete_item_obj = self._item_by_id_for_update(incomplete_item.id)
- incomplete_item_obj.available_after = retry_date
- incomplete_item_obj.available = True
+ try:
+ incomplete_item_obj = self._item_by_id_for_update(incomplete_item.id)
+ incomplete_item_obj.available_after = retry_date
+ incomplete_item_obj.available = True
- if restore_retry:
- incomplete_item_obj.retries_remaining += 1
+ if restore_retry:
+ incomplete_item_obj.retries_remaining += 1
- incomplete_item_obj.save()
- self._currently_processing = False
- return incomplete_item_obj.retries_remaining > 0
- except QueueItem.DoesNotExist:
- return False
+ incomplete_item_obj.save()
+ self._currently_processing = False
+ return incomplete_item_obj.retries_remaining > 0
+ except QueueItem.DoesNotExist:
+ return False
- def extend_processing(self, item, seconds_from_now, minimum_extension=MINIMUM_EXTENSION,
- updated_data=None):
- with self._transaction_factory(db):
- try:
- queue_item = self._item_by_id_for_update(item.id)
- new_expiration = datetime.utcnow() + timedelta(seconds=seconds_from_now)
- has_change = False
+ def extend_processing(
+ self,
+ item,
+ seconds_from_now,
+ minimum_extension=MINIMUM_EXTENSION,
+ updated_data=None,
+ ):
+ with self._transaction_factory(db):
+ try:
+ queue_item = self._item_by_id_for_update(item.id)
+ new_expiration = datetime.utcnow() + timedelta(seconds=seconds_from_now)
+ has_change = False
- # Only actually write the new expiration to the db if it moves the expiration some minimum
- if new_expiration - queue_item.processing_expires > minimum_extension:
- queue_item.processing_expires = new_expiration
- has_change = True
+ # Only actually write the new expiration to the db if it moves the expiration some minimum
+ if new_expiration - queue_item.processing_expires > minimum_extension:
+ queue_item.processing_expires = new_expiration
+ has_change = True
- if updated_data is not None and queue_item.body != updated_data:
- queue_item.body = updated_data
- has_change = True
+ if updated_data is not None and queue_item.body != updated_data:
+ queue_item.body = updated_data
+ has_change = True
- if has_change:
- queue_item.save()
+ if has_change:
+ queue_item.save()
- return has_change
- except QueueItem.DoesNotExist:
- return False
+ return has_change
+ except QueueItem.DoesNotExist:
+ return False
def delete_expired(expiration_threshold, deletion_threshold, batch_size):
- """
+ """
Deletes all queue items that are older than the provided expiration threshold in batches of the
provided size. If there are less items than the deletion threshold, this method does nothing.
Returns the number of items deleted.
"""
- to_delete = list(QueueItem
- .select()
- .where(QueueItem.processing_expires <= expiration_threshold)
- .limit(batch_size))
+ to_delete = list(
+ QueueItem.select()
+ .where(QueueItem.processing_expires <= expiration_threshold)
+ .limit(batch_size)
+ )
- if len(to_delete) < deletion_threshold:
- return 0
+ if len(to_delete) < deletion_threshold:
+ return 0
- QueueItem.delete().where(QueueItem.id << to_delete).execute()
- return len(to_delete)
+ QueueItem.delete().where(QueueItem.id << to_delete).execute()
+ return len(to_delete)
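
The hunk above replaces the multi-column WHERE with an optimistic, state-ID-guarded claim: every writer rotates state_id, so an UPDATE conditioned on the previously observed state_id can succeed for at most one caller. A minimal standalone sketch of that compare-and-swap pattern (plain sqlite3 with an invented schema, not Quay code):

import sqlite3
import uuid

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE queueitem (id INTEGER PRIMARY KEY, available INTEGER, state_id TEXT)")
conn.execute("INSERT INTO queueitem (id, available, state_id) VALUES (1, 1, ?)", (str(uuid.uuid4()),))
conn.commit()

def try_claim(connection, item_id, observed_state_id):
    # Only the two unique columns appear in the WHERE clause, mirroring the
    # lock-contention argument in the comment above.
    cursor = connection.execute(
        "UPDATE queueitem SET available = 0, state_id = ? WHERE id = ? AND state_id = ?",
        (str(uuid.uuid4()), item_id, observed_state_id),
    )
    connection.commit()
    return cursor.rowcount == 1

item_id, state_id = conn.execute("SELECT id, state_id FROM queueitem WHERE available = 1").fetchone()
assert try_claim(conn, item_id, state_id)        # the first caller wins the race
assert not try_claim(conn, item_id, state_id)    # a stale state_id can no longer claim the item
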
diff --git a/data/readreplica.py b/data/readreplica.py
index 33abff2ed..1dfb6d26a 100644
--- a/data/readreplica.py
+++ b/data/readreplica.py
@@ -4,46 +4,48 @@ from collections import namedtuple
from peewee import Model, SENTINEL, OperationalError, Proxy
-ReadOnlyConfig = namedtuple('ReadOnlyConfig', ['is_readonly', 'read_replicas'])
+ReadOnlyConfig = namedtuple("ReadOnlyConfig", ["is_readonly", "read_replicas"])
+
class ReadOnlyModeException(Exception):
- """ Exception raised if a write operation was attempted when in read only mode.
+ """ Exception raised if a write operation was attempted when in read only mode.
"""
class AutomaticFailoverWrapper(object):
- """ Class which wraps a peewee database driver and (optionally) a second driver.
+ """ Class which wraps a peewee database driver and (optionally) a second driver.
When executing SQL, if an OperationalError occurs, if a second driver is given,
the query is attempted again on the fallback DB. Otherwise, the exception is raised.
"""
- def __init__(self, primary_db, fallback_db=None):
- self._primary_db = primary_db
- self._fallback_db = fallback_db
- def __getattr__(self, attribute):
- if attribute != 'execute_sql' and hasattr(self._primary_db, attribute):
- return getattr(self._primary_db, attribute)
+ def __init__(self, primary_db, fallback_db=None):
+ self._primary_db = primary_db
+ self._fallback_db = fallback_db
- return getattr(self, attribute)
+ def __getattr__(self, attribute):
+ if attribute != "execute_sql" and hasattr(self._primary_db, attribute):
+ return getattr(self._primary_db, attribute)
- def execute(self, query, commit=SENTINEL, **context_options):
- ctx = self.get_sql_context(**context_options)
- sql, params = ctx.sql(query).query()
- return self.execute_sql(sql, params, commit=commit)
+ return getattr(self, attribute)
- def execute_sql(self, sql, params=None, commit=SENTINEL):
- try:
- return self._primary_db.execute_sql(sql, params, commit)
- except OperationalError:
- if self._fallback_db is not None:
+ def execute(self, query, commit=SENTINEL, **context_options):
+ ctx = self.get_sql_context(**context_options)
+ sql, params = ctx.sql(query).query()
+ return self.execute_sql(sql, params, commit=commit)
+
+ def execute_sql(self, sql, params=None, commit=SENTINEL):
try:
- return self._fallback_db.execute_sql(sql, params, commit)
+ return self._primary_db.execute_sql(sql, params, commit)
except OperationalError:
- raise
+ if self._fallback_db is not None:
+ try:
+ return self._fallback_db.execute_sql(sql, params, commit)
+ except OperationalError:
+ raise
class ReadReplicaSupportedModel(Model):
- """ Base model for peewee data models that support using a read replica for SELECT
+ """ Base model for peewee data models that support using a read replica for SELECT
requests not under transactions, and automatic failover to the master if the
read replica fails.
@@ -57,73 +59,74 @@ class ReadReplicaSupportedModel(Model):
If the system is configured into read only mode, then all non-read-only queries
will raise a ReadOnlyModeException.
"""
- @classmethod
- def _read_only_config(cls):
- read_only_config = getattr(cls._meta, 'read_only_config', None)
- if read_only_config is None:
- return ReadOnlyConfig(False, [])
- if isinstance(read_only_config, Proxy) and read_only_config.obj is None:
- return ReadOnlyConfig(False, [])
+ @classmethod
+ def _read_only_config(cls):
+ read_only_config = getattr(cls._meta, "read_only_config", None)
+ if read_only_config is None:
+ return ReadOnlyConfig(False, [])
- return read_only_config.obj or ReadOnlyConfig(False, [])
+ if isinstance(read_only_config, Proxy) and read_only_config.obj is None:
+ return ReadOnlyConfig(False, [])
- @classmethod
- def _in_readonly_mode(cls):
- return cls._read_only_config().is_readonly
+ return read_only_config.obj or ReadOnlyConfig(False, [])
- @classmethod
- def _select_database(cls):
- """ Selects a read replica database if we're configured to support read replicas.
+ @classmethod
+ def _in_readonly_mode(cls):
+ return cls._read_only_config().is_readonly
+
+ @classmethod
+ def _select_database(cls):
+ """ Selects a read replica database if we're configured to support read replicas.
Otherwise, selects the master database.
"""
- # Select the master DB if read replica support is not enabled.
- read_only_config = cls._read_only_config()
- if not read_only_config.read_replicas:
- return cls._meta.database
+ # Select the master DB if read replica support is not enabled.
+ read_only_config = cls._read_only_config()
+ if not read_only_config.read_replicas:
+ return cls._meta.database
- # Select the master DB if we're ever under a transaction.
- if cls._meta.database.transaction_depth() > 0:
- return cls._meta.database
+ # Select the master DB if we're ever under a transaction.
+ if cls._meta.database.transaction_depth() > 0:
+ return cls._meta.database
- # Otherwise, return a read replica database with auto-retry onto the main database.
- replicas = read_only_config.read_replicas
- selected_read_replica = replicas[random.randrange(len(replicas))]
- return AutomaticFailoverWrapper(selected_read_replica, cls._meta.database)
+ # Otherwise, return a read replica database with auto-retry onto the main database.
+ replicas = read_only_config.read_replicas
+ selected_read_replica = replicas[random.randrange(len(replicas))]
+ return AutomaticFailoverWrapper(selected_read_replica, cls._meta.database)
- @classmethod
- def select(cls, *args, **kwargs):
- query = super(ReadReplicaSupportedModel, cls).select(*args, **kwargs)
- query._database = cls._select_database()
- return query
+ @classmethod
+ def select(cls, *args, **kwargs):
+ query = super(ReadReplicaSupportedModel, cls).select(*args, **kwargs)
+ query._database = cls._select_database()
+ return query
- @classmethod
- def insert(cls, *args, **kwargs):
- query = super(ReadReplicaSupportedModel, cls).insert(*args, **kwargs)
- if cls._in_readonly_mode():
- raise ReadOnlyModeException()
- return query
+ @classmethod
+ def insert(cls, *args, **kwargs):
+ query = super(ReadReplicaSupportedModel, cls).insert(*args, **kwargs)
+ if cls._in_readonly_mode():
+ raise ReadOnlyModeException()
+ return query
- @classmethod
- def update(cls, *args, **kwargs):
- query = super(ReadReplicaSupportedModel, cls).update(*args, **kwargs)
- if cls._in_readonly_mode():
- raise ReadOnlyModeException()
- return query
+ @classmethod
+ def update(cls, *args, **kwargs):
+ query = super(ReadReplicaSupportedModel, cls).update(*args, **kwargs)
+ if cls._in_readonly_mode():
+ raise ReadOnlyModeException()
+ return query
- @classmethod
- def delete(cls, *args, **kwargs):
- query = super(ReadReplicaSupportedModel, cls).delete(*args, **kwargs)
- if cls._in_readonly_mode():
- raise ReadOnlyModeException()
- return query
+ @classmethod
+ def delete(cls, *args, **kwargs):
+ query = super(ReadReplicaSupportedModel, cls).delete(*args, **kwargs)
+ if cls._in_readonly_mode():
+ raise ReadOnlyModeException()
+ return query
- @classmethod
- def raw(cls, *args, **kwargs):
- query = super(ReadReplicaSupportedModel, cls).raw(*args, **kwargs)
- if query._sql.lower().startswith('select '):
- query._database = cls._select_database()
- elif cls._in_readonly_mode():
- raise ReadOnlyModeException()
+ @classmethod
+ def raw(cls, *args, **kwargs):
+ query = super(ReadReplicaSupportedModel, cls).raw(*args, **kwargs)
+ if query._sql.lower().startswith("select "):
+ query._database = cls._select_database()
+ elif cls._in_readonly_mode():
+ raise ReadOnlyModeException()
- return query
+ return query
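
AutomaticFailoverWrapper above retries a failed replica query on the master. A minimal sketch of that failover idea with plain callables (illustrative only; unlike the wrapper above, this version re-raises when no fallback is configured):

class ReplicaError(Exception):
    """ Stands in for peewee's OperationalError in this sketch. """

def execute_with_failover(primary, fallback, sql):
    # Try the read replica first; on failure, re-run the same SQL on the master.
    try:
        return primary(sql)
    except ReplicaError:
        if fallback is None:
            raise
        return fallback(sql)

def failing_replica(sql):
    raise ReplicaError("replica unavailable")

result = execute_with_failover(failing_replica, lambda sql: ["row-from-master"], "SELECT 1")
assert result == ["row-from-master"]
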
diff --git a/data/registry_model/__init__.py b/data/registry_model/__init__.py
index ffac9dd59..5c841c1dc 100644
--- a/data/registry_model/__init__.py
+++ b/data/registry_model/__init__.py
@@ -9,35 +9,48 @@ logger = logging.getLogger(__name__)
class RegistryModelProxy(object):
- def __init__(self):
- self._model = oci_model if os.getenv('OCI_DATA_MODEL') == 'true' else pre_oci_model
+ def __init__(self):
+ self._model = (
+ oci_model if os.getenv("OCI_DATA_MODEL") == "true" else pre_oci_model
+ )
- def setup_split(self, oci_model_proportion, oci_whitelist, v22_whitelist, upgrade_mode):
- if os.getenv('OCI_DATA_MODEL') == 'true':
- return
+ def setup_split(
+ self, oci_model_proportion, oci_whitelist, v22_whitelist, upgrade_mode
+ ):
+ if os.getenv("OCI_DATA_MODEL") == "true":
+ return
- if upgrade_mode == 'complete':
- logger.info('===============================')
- logger.info('Full V2_2 + OCI model is enabled')
- logger.info('===============================')
- self._model = oci_model
- return
+ if upgrade_mode == "complete":
+ logger.info("===============================")
+ logger.info("Full V2_2 + OCI model is enabled")
+ logger.info("===============================")
+ self._model = oci_model
+ return
- logger.info('===============================')
- logger.info('Split registry model: OCI %s proportion and whitelist `%s` and V22 whitelist `%s`',
- oci_model_proportion, oci_whitelist, v22_whitelist)
- logger.info('===============================')
- self._model = SplitModel(oci_model_proportion, oci_whitelist, v22_whitelist,
- upgrade_mode == 'post-oci-rollout')
+ logger.info("===============================")
+ logger.info(
+ "Split registry model: OCI %s proportion and whitelist `%s` and V22 whitelist `%s`",
+ oci_model_proportion,
+ oci_whitelist,
+ v22_whitelist,
+ )
+ logger.info("===============================")
+ self._model = SplitModel(
+ oci_model_proportion,
+ oci_whitelist,
+ v22_whitelist,
+ upgrade_mode == "post-oci-rollout",
+ )
- def set_for_testing(self, use_oci_model):
- self._model = oci_model if use_oci_model else pre_oci_model
- logger.debug('Changed registry model to `%s` for testing', self._model)
+ def set_for_testing(self, use_oci_model):
+ self._model = oci_model if use_oci_model else pre_oci_model
+ logger.debug("Changed registry model to `%s` for testing", self._model)
+
+ def __getattr__(self, attr):
+ return getattr(self._model, attr)
- def __getattr__(self, attr):
- return getattr(self._model, attr)
registry_model = RegistryModelProxy()
-logger.info('===============================')
-logger.info('Using registry model `%s`', registry_model._model)
-logger.info('===============================')
+logger.info("===============================")
+logger.info("Using registry model `%s`", registry_model._model)
+logger.info("===============================")
diff --git a/data/registry_model/blobuploader.py b/data/registry_model/blobuploader.py
index 5f99d3ec8..d4760b1f6 100644
--- a/data/registry_model/blobuploader.py
+++ b/data/registry_model/blobuploader.py
@@ -18,115 +18,140 @@ from util.registry.torrent import PieceHasher
logger = logging.getLogger(__name__)
-BLOB_CONTENT_TYPE = 'application/octet-stream'
+BLOB_CONTENT_TYPE = "application/octet-stream"
class BlobUploadException(Exception):
- """ Base for all exceptions raised when uploading blobs. """
+ """ Base for all exceptions raised when uploading blobs. """
+
class BlobRangeMismatchException(BlobUploadException):
- """ Exception raised if the range to be uploaded does not match. """
+ """ Exception raised if the range to be uploaded does not match. """
+
class BlobDigestMismatchException(BlobUploadException):
- """ Exception raised if the digest requested does not match that of the contents uploaded. """
+ """ Exception raised if the digest requested does not match that of the contents uploaded. """
+
class BlobTooLargeException(BlobUploadException):
- """ Exception raised if the data uploaded exceeds the maximum_blob_size. """
- def __init__(self, uploaded, max_allowed):
- super(BlobTooLargeException, self).__init__()
- self.uploaded = uploaded
- self.max_allowed = max_allowed
+ """ Exception raised if the data uploaded exceeds the maximum_blob_size. """
+
+ def __init__(self, uploaded, max_allowed):
+ super(BlobTooLargeException, self).__init__()
+ self.uploaded = uploaded
+ self.max_allowed = max_allowed
-BlobUploadSettings = namedtuple('BlobUploadSettings', ['maximum_blob_size', 'bittorrent_piece_size',
- 'committed_blob_expiration'])
+BlobUploadSettings = namedtuple(
+ "BlobUploadSettings",
+ ["maximum_blob_size", "bittorrent_piece_size", "committed_blob_expiration"],
+)
-def create_blob_upload(repository_ref, storage, settings, extra_blob_stream_handlers=None):
- """ Creates a new blob upload in the specified repository and returns a manager for interacting
+def create_blob_upload(
+ repository_ref, storage, settings, extra_blob_stream_handlers=None
+):
+ """ Creates a new blob upload in the specified repository and returns a manager for interacting
with that upload. Returns None if a new blob upload could not be started.
"""
- location_name = storage.preferred_locations[0]
- new_upload_uuid, upload_metadata = storage.initiate_chunked_upload(location_name)
- blob_upload = registry_model.create_blob_upload(repository_ref, new_upload_uuid, location_name,
- upload_metadata)
- if blob_upload is None:
- return None
+ location_name = storage.preferred_locations[0]
+ new_upload_uuid, upload_metadata = storage.initiate_chunked_upload(location_name)
+ blob_upload = registry_model.create_blob_upload(
+ repository_ref, new_upload_uuid, location_name, upload_metadata
+ )
+ if blob_upload is None:
+ return None
- return _BlobUploadManager(repository_ref, blob_upload, settings, storage,
- extra_blob_stream_handlers)
+ return _BlobUploadManager(
+ repository_ref, blob_upload, settings, storage, extra_blob_stream_handlers
+ )
def retrieve_blob_upload_manager(repository_ref, blob_upload_id, storage, settings):
- """ Retrieves the manager for an in-progress blob upload with the specified ID under the given
+ """ Retrieves the manager for an in-progress blob upload with the specified ID under the given
repository or None if none.
"""
- blob_upload = registry_model.lookup_blob_upload(repository_ref, blob_upload_id)
- if blob_upload is None:
- return None
+ blob_upload = registry_model.lookup_blob_upload(repository_ref, blob_upload_id)
+ if blob_upload is None:
+ return None
+
+ return _BlobUploadManager(repository_ref, blob_upload, settings, storage)
- return _BlobUploadManager(repository_ref, blob_upload, settings, storage)
@contextmanager
def complete_when_uploaded(blob_upload):
- """ Wraps the given blob upload in a context manager that completes the upload when the context
+ """ Wraps the given blob upload in a context manager that completes the upload when the context
closes.
"""
- try:
- yield blob_upload
- except Exception as ex:
- logger.exception('Exception when uploading blob `%s`', blob_upload.blob_upload_id)
- raise ex
- finally:
- # Cancel the upload if something went wrong or it was not commit to a blob.
- if blob_upload.committed_blob is None:
- blob_upload.cancel_upload()
+ try:
+ yield blob_upload
+ except Exception as ex:
+ logger.exception(
+ "Exception when uploading blob `%s`", blob_upload.blob_upload_id
+ )
+ raise ex
+ finally:
+ # Cancel the upload if something went wrong or it was not committed to a blob.
+ if blob_upload.committed_blob is None:
+ blob_upload.cancel_upload()
+
@contextmanager
def upload_blob(repository_ref, storage, settings, extra_blob_stream_handlers=None):
- """ Starts a new blob upload in the specified repository and yields a manager for interacting
+ """ Starts a new blob upload in the specified repository and yields a manager for interacting
with that upload. When the context manager completes, the blob upload is deleted, whether
committed to a blob or not. Yields None if a blob upload could not be started.
"""
- created = create_blob_upload(repository_ref, storage, settings, extra_blob_stream_handlers)
- if not created:
- yield None
- return
+ created = create_blob_upload(
+ repository_ref, storage, settings, extra_blob_stream_handlers
+ )
+ if not created:
+ yield None
+ return
- try:
- yield created
- except Exception as ex:
- logger.exception('Exception when uploading blob `%s`', created.blob_upload_id)
- raise ex
- finally:
- # Cancel the upload if something went wrong or it was not commit to a blob.
- if created.committed_blob is None:
- created.cancel_upload()
+ try:
+ yield created
+ except Exception as ex:
+ logger.exception("Exception when uploading blob `%s`", created.blob_upload_id)
+ raise ex
+ finally:
+ # Cancel the upload if something went wrong or it was not committed to a blob.
+ if created.committed_blob is None:
+ created.cancel_upload()
class _BlobUploadManager(object):
- """ Defines a helper class for easily interacting with blob uploads in progress, including
+ """ Defines a helper class for easily interacting with blob uploads in progress, including
handling of database and storage calls.
"""
- def __init__(self, repository_ref, blob_upload, settings, storage,
- extra_blob_stream_handlers=None):
- assert repository_ref is not None
- assert blob_upload is not None
- self.repository_ref = repository_ref
- self.blob_upload = blob_upload
- self.settings = settings
- self.storage = storage
- self.extra_blob_stream_handlers = extra_blob_stream_handlers
- self.committed_blob = None
+ def __init__(
+ self,
+ repository_ref,
+ blob_upload,
+ settings,
+ storage,
+ extra_blob_stream_handlers=None,
+ ):
+ assert repository_ref is not None
+ assert blob_upload is not None
- @property
- def blob_upload_id(self):
- """ Returns the unique ID for the blob upload. """
- return self.blob_upload.upload_id
+ self.repository_ref = repository_ref
+ self.blob_upload = blob_upload
+ self.settings = settings
+ self.storage = storage
+ self.extra_blob_stream_handlers = extra_blob_stream_handlers
+ self.committed_blob = None
- def upload_chunk(self, app_config, input_fp, start_offset=0, length=-1, metric_queue=None):
- """ Uploads a chunk of data found in the given input file-like interface. start_offset and
+ @property
+ def blob_upload_id(self):
+ """ Returns the unique ID for the blob upload. """
+ return self.blob_upload.upload_id
+
+ def upload_chunk(
+ self, app_config, input_fp, start_offset=0, length=-1, metric_queue=None
+ ):
+ """ Uploads a chunk of data found in the given input file-like interface. start_offset and
length are optional and should match a range header if any was given.
If metric_queue is given, the upload time and chunk size are written into the metrics in
@@ -135,201 +160,250 @@ class _BlobUploadManager(object):
Returns the total number of bytes uploaded after this upload has completed. Raises
a BlobUploadException if the upload failed.
"""
- assert start_offset is not None
- assert length is not None
+ assert start_offset is not None
+ assert length is not None
- if start_offset > 0 and start_offset > self.blob_upload.byte_count:
- logger.error('start_offset provided greater than blob_upload.byte_count')
- raise BlobRangeMismatchException()
+ if start_offset > 0 and start_offset > self.blob_upload.byte_count:
+ logger.error("start_offset provided greater than blob_upload.byte_count")
+ raise BlobRangeMismatchException()
- # Ensure that we won't go over the allowed maximum size for blobs.
- max_blob_size = bitmath.parse_string_unsafe(self.settings.maximum_blob_size)
- uploaded = bitmath.Byte(length + start_offset)
- if length > -1 and uploaded > max_blob_size:
- raise BlobTooLargeException(uploaded=uploaded.bytes, max_allowed=max_blob_size.bytes)
+ # Ensure that we won't go over the allowed maximum size for blobs.
+ max_blob_size = bitmath.parse_string_unsafe(self.settings.maximum_blob_size)
+ uploaded = bitmath.Byte(length + start_offset)
+ if length > -1 and uploaded > max_blob_size:
+ raise BlobTooLargeException(
+ uploaded=uploaded.bytes, max_allowed=max_blob_size.bytes
+ )
- location_set = {self.blob_upload.location_name}
- upload_error = None
- with CloseForLongOperation(app_config):
- if start_offset > 0 and start_offset < self.blob_upload.byte_count:
- # Skip the bytes which were received on a previous push, which are already stored and
- # included in the sha calculation
- overlap_size = self.blob_upload.byte_count - start_offset
- input_fp = StreamSlice(input_fp, overlap_size)
+ location_set = {self.blob_upload.location_name}
+ upload_error = None
+ with CloseForLongOperation(app_config):
+ if start_offset > 0 and start_offset < self.blob_upload.byte_count:
+ # Skip the bytes which were received on a previous push, which are already stored and
+ # included in the sha calculation
+ overlap_size = self.blob_upload.byte_count - start_offset
+ input_fp = StreamSlice(input_fp, overlap_size)
- # Update our upload bounds to reflect the skipped portion of the overlap
- start_offset = self.blob_upload.byte_count
- length = max(length - overlap_size, 0)
+ # Update our upload bounds to reflect the skipped portion of the overlap
+ start_offset = self.blob_upload.byte_count
+ length = max(length - overlap_size, 0)
- # We use this to escape early in case we have already processed all of the bytes the user
- # wants to upload.
- if length == 0:
- return self.blob_upload.byte_count
+ # We use this to escape early in case we have already processed all of the bytes the user
+ # wants to upload.
+ if length == 0:
+ return self.blob_upload.byte_count
- input_fp = wrap_with_handler(input_fp, self.blob_upload.sha_state.update)
+ input_fp = wrap_with_handler(input_fp, self.blob_upload.sha_state.update)
- if self.extra_blob_stream_handlers:
- for handler in self.extra_blob_stream_handlers:
- input_fp = wrap_with_handler(input_fp, handler)
+ if self.extra_blob_stream_handlers:
+ for handler in self.extra_blob_stream_handlers:
+ input_fp = wrap_with_handler(input_fp, handler)
- # Add a hasher for calculating SHA1s for torrents if this is the first chunk and/or we have
- # already calculated hash data for the previous chunk(s).
- piece_hasher = None
- if self.blob_upload.chunk_count == 0 or self.blob_upload.piece_sha_state:
- initial_sha1_value = self.blob_upload.piece_sha_state or resumablehashlib.sha1()
- initial_sha1_pieces_value = self.blob_upload.piece_hashes or ''
+ # Add a hasher for calculating SHA1s for torrents if this is the first chunk and/or we have
+ # already calculated hash data for the previous chunk(s).
+ piece_hasher = None
+ if self.blob_upload.chunk_count == 0 or self.blob_upload.piece_sha_state:
+ initial_sha1_value = (
+ self.blob_upload.piece_sha_state or resumablehashlib.sha1()
+ )
+ initial_sha1_pieces_value = self.blob_upload.piece_hashes or ""
- piece_hasher = PieceHasher(self.settings.bittorrent_piece_size, start_offset,
- initial_sha1_pieces_value, initial_sha1_value)
- input_fp = wrap_with_handler(input_fp, piece_hasher.update)
+ piece_hasher = PieceHasher(
+ self.settings.bittorrent_piece_size,
+ start_offset,
+ initial_sha1_pieces_value,
+ initial_sha1_value,
+ )
+ input_fp = wrap_with_handler(input_fp, piece_hasher.update)
- # If this is the first chunk and we're starting at the 0 offset, add a handler to gunzip the
- # stream so we can determine the uncompressed size. We'll throw out this data if another chunk
- # comes in, but in the common case the docker client only sends one chunk.
- size_info = None
- if start_offset == 0 and self.blob_upload.chunk_count == 0:
- size_info, fn = calculate_size_handler()
- input_fp = wrap_with_handler(input_fp, fn)
+ # If this is the first chunk and we're starting at the 0 offset, add a handler to gunzip the
+ # stream so we can determine the uncompressed size. We'll throw out this data if another chunk
+ # comes in, but in the common case the docker client only sends one chunk.
+ size_info = None
+ if start_offset == 0 and self.blob_upload.chunk_count == 0:
+ size_info, fn = calculate_size_handler()
+ input_fp = wrap_with_handler(input_fp, fn)
- start_time = time.time()
- length_written, new_metadata, upload_error = self.storage.stream_upload_chunk(
- location_set,
- self.blob_upload.upload_id,
- start_offset,
- length,
- input_fp,
- self.blob_upload.storage_metadata,
- content_type=BLOB_CONTENT_TYPE,
- )
+ start_time = time.time()
+ length_written, new_metadata, upload_error = self.storage.stream_upload_chunk(
+ location_set,
+ self.blob_upload.upload_id,
+ start_offset,
+ length,
+ input_fp,
+ self.blob_upload.storage_metadata,
+ content_type=BLOB_CONTENT_TYPE,
+ )
- if upload_error is not None:
- logger.error('storage.stream_upload_chunk returned error %s', upload_error)
- raise BlobUploadException(upload_error)
+ if upload_error is not None:
+ logger.error(
+ "storage.stream_upload_chunk returned error %s", upload_error
+ )
+ raise BlobUploadException(upload_error)
- # Update the chunk upload time and push bytes metrics.
- if metric_queue is not None:
- metric_queue.chunk_upload_time.Observe(time.time() - start_time, labelvalues=[
- length_written, list(location_set)[0]])
+ # Update the chunk upload time and push bytes metrics.
+ if metric_queue is not None:
+ metric_queue.chunk_upload_time.Observe(
+ time.time() - start_time,
+ labelvalues=[length_written, list(location_set)[0]],
+ )
- metric_queue.push_byte_count.Inc(length_written)
+ metric_queue.push_byte_count.Inc(length_written)
- # Ensure we have not gone beyond the max layer size.
- new_blob_bytes = self.blob_upload.byte_count + length_written
- new_blob_size = bitmath.Byte(new_blob_bytes)
- if new_blob_size > max_blob_size:
- raise BlobTooLargeException(uploaded=new_blob_size, max_allowed=max_blob_size.bytes)
+ # Ensure we have not gone beyond the max layer size.
+ new_blob_bytes = self.blob_upload.byte_count + length_written
+ new_blob_size = bitmath.Byte(new_blob_bytes)
+ if new_blob_size > max_blob_size:
+ raise BlobTooLargeException(
+ uploaded=new_blob_size, max_allowed=max_blob_size.bytes
+ )
- # If we determined an uncompressed size and this is the first chunk, add it to the blob.
- # Otherwise, we clear the size from the blob as it was uploaded in multiple chunks.
- uncompressed_byte_count = self.blob_upload.uncompressed_byte_count
- if size_info is not None and self.blob_upload.chunk_count == 0 and size_info.is_valid:
- uncompressed_byte_count = size_info.uncompressed_size
- elif length_written > 0:
- # Otherwise, if we wrote some bytes and the above conditions were not met, then we don't
- # know the uncompressed size.
- uncompressed_byte_count = None
+ # If we determined an uncompressed size and this is the first chunk, add it to the blob.
+ # Otherwise, we clear the size from the blob as it was uploaded in multiple chunks.
+ uncompressed_byte_count = self.blob_upload.uncompressed_byte_count
+ if (
+ size_info is not None
+ and self.blob_upload.chunk_count == 0
+ and size_info.is_valid
+ ):
+ uncompressed_byte_count = size_info.uncompressed_size
+ elif length_written > 0:
+ # Otherwise, if we wrote some bytes and the above conditions were not met, then we don't
+ # know the uncompressed size.
+ uncompressed_byte_count = None
- piece_hashes = None
- piece_sha_state = None
- if piece_hasher is not None:
- piece_hashes = piece_hasher.piece_hashes
- piece_sha_state = piece_hasher.hash_fragment
+ piece_hashes = None
+ piece_sha_state = None
+ if piece_hasher is not None:
+ piece_hashes = piece_hasher.piece_hashes
+ piece_sha_state = piece_hasher.hash_fragment
- self.blob_upload = registry_model.update_blob_upload(self.blob_upload,
- uncompressed_byte_count,
- piece_hashes,
- piece_sha_state,
- new_metadata,
- new_blob_bytes,
- self.blob_upload.chunk_count + 1,
- self.blob_upload.sha_state)
- if self.blob_upload is None:
- raise BlobUploadException('Could not complete upload of chunk')
+ self.blob_upload = registry_model.update_blob_upload(
+ self.blob_upload,
+ uncompressed_byte_count,
+ piece_hashes,
+ piece_sha_state,
+ new_metadata,
+ new_blob_bytes,
+ self.blob_upload.chunk_count + 1,
+ self.blob_upload.sha_state,
+ )
+ if self.blob_upload is None:
+ raise BlobUploadException("Could not complete upload of chunk")
- return new_blob_bytes
+ return new_blob_bytes
- def cancel_upload(self):
- """ Cancels the blob upload, deleting any data uploaded and removing the upload itself. """
- if self.blob_upload is None:
- return
+ def cancel_upload(self):
+ """ Cancels the blob upload, deleting any data uploaded and removing the upload itself. """
+ if self.blob_upload is None:
+ return
- # Tell storage to cancel the chunked upload, deleting its contents.
- self.storage.cancel_chunked_upload({self.blob_upload.location_name},
- self.blob_upload.upload_id,
- self.blob_upload.storage_metadata)
+ # Tell storage to cancel the chunked upload, deleting its contents.
+ self.storage.cancel_chunked_upload(
+ {self.blob_upload.location_name},
+ self.blob_upload.upload_id,
+ self.blob_upload.storage_metadata,
+ )
- # Remove the blob upload record itself.
- registry_model.delete_blob_upload(self.blob_upload)
+ # Remove the blob upload record itself.
+ registry_model.delete_blob_upload(self.blob_upload)
- def commit_to_blob(self, app_config, expected_digest=None):
- """ Commits the blob upload to a blob under the repository. The resulting blob will be marked
+ def commit_to_blob(self, app_config, expected_digest=None):
+ """ Commits the blob upload to a blob under the repository. The resulting blob will be marked
to not be GCed for some period of time (as configured by `committed_blob_expiration`).
If expected_digest is specified, the content digest of the data uploaded for the blob is
compared to that given and, if it does not match, a BlobDigestMismatchException is
raised. The digest given must be of type `Digest` and not a string.
"""
- # Compare the content digest.
- if expected_digest is not None:
- self._validate_digest(expected_digest)
+ # Compare the content digest.
+ if expected_digest is not None:
+ self._validate_digest(expected_digest)
- # Finalize the storage.
- storage_already_existed = self._finalize_blob_storage(app_config)
+ # Finalize the storage.
+ storage_already_existed = self._finalize_blob_storage(app_config)
- # Convert the upload to a blob.
- computed_digest_str = digest_tools.sha256_digest_from_hashlib(self.blob_upload.sha_state)
+ # Convert the upload to a blob.
+ computed_digest_str = digest_tools.sha256_digest_from_hashlib(
+ self.blob_upload.sha_state
+ )
- with db_transaction():
- blob = registry_model.commit_blob_upload(self.blob_upload, computed_digest_str,
- self.settings.committed_blob_expiration)
- if blob is None:
- return None
+ with db_transaction():
+ blob = registry_model.commit_blob_upload(
+ self.blob_upload,
+ computed_digest_str,
+ self.settings.committed_blob_expiration,
+ )
+ if blob is None:
+ return None
- # Save torrent hash information (if available).
- if self.blob_upload.piece_sha_state is not None and not storage_already_existed:
- piece_bytes = self.blob_upload.piece_hashes + self.blob_upload.piece_sha_state.digest()
- registry_model.set_torrent_info(blob, self.settings.bittorrent_piece_size, piece_bytes)
+ # Save torrent hash information (if available).
+ if (
+ self.blob_upload.piece_sha_state is not None
+ and not storage_already_existed
+ ):
+ piece_bytes = (
+ self.blob_upload.piece_hashes
+ + self.blob_upload.piece_sha_state.digest()
+ )
+ registry_model.set_torrent_info(
+ blob, self.settings.bittorrent_piece_size, piece_bytes
+ )
- self.committed_blob = blob
- return blob
+ self.committed_blob = blob
+ return blob
- def _validate_digest(self, expected_digest):
- """
+ def _validate_digest(self, expected_digest):
+ """
Verifies that the digest's SHA matches that of the uploaded data.
"""
- computed_digest = digest_tools.sha256_digest_from_hashlib(self.blob_upload.sha_state)
- if not digest_tools.digests_equal(computed_digest, expected_digest):
- logger.error('Digest mismatch for upload %s: Expected digest %s, found digest %s',
- self.blob_upload.upload_id, expected_digest, computed_digest)
- raise BlobDigestMismatchException()
+ computed_digest = digest_tools.sha256_digest_from_hashlib(
+ self.blob_upload.sha_state
+ )
+ if not digest_tools.digests_equal(computed_digest, expected_digest):
+ logger.error(
+ "Digest mismatch for upload %s: Expected digest %s, found digest %s",
+ self.blob_upload.upload_id,
+ expected_digest,
+ computed_digest,
+ )
+ raise BlobDigestMismatchException()
- def _finalize_blob_storage(self, app_config):
- """
+ def _finalize_blob_storage(self, app_config):
+ """
When an upload is successful, this ends the uploading process from the
storage's perspective.
Returns True if the blob already existed.
"""
- computed_digest = digest_tools.sha256_digest_from_hashlib(self.blob_upload.sha_state)
- final_blob_location = digest_tools.content_path(computed_digest)
+ computed_digest = digest_tools.sha256_digest_from_hashlib(
+ self.blob_upload.sha_state
+ )
+ final_blob_location = digest_tools.content_path(computed_digest)
- # Close the database connection before we perform this operation, as it can take a while
- # and we shouldn't hold the connection during that time.
- with CloseForLongOperation(app_config):
- # Move the storage into place, or if this was a re-upload, cancel it
- already_existed = self.storage.exists({self.blob_upload.location_name}, final_blob_location)
- if already_existed:
- # It already existed, clean up our upload which served as proof that the
- # uploader had the blob.
- self.storage.cancel_chunked_upload({self.blob_upload.location_name},
- self.blob_upload.upload_id,
- self.blob_upload.storage_metadata)
- else:
- # We were the first ones to upload this image (at least to this location)
- # Let's copy it into place
- self.storage.complete_chunked_upload({self.blob_upload.location_name},
- self.blob_upload.upload_id,
- final_blob_location,
- self.blob_upload.storage_metadata)
+ # Close the database connection before we perform this operation, as it can take a while
+ # and we shouldn't hold the connection during that time.
+ with CloseForLongOperation(app_config):
+ # Move the storage into place, or if this was a re-upload, cancel it
+ already_existed = self.storage.exists(
+ {self.blob_upload.location_name}, final_blob_location
+ )
+ if already_existed:
+ # It already existed, clean up our upload which served as proof that the
+ # uploader had the blob.
+ self.storage.cancel_chunked_upload(
+ {self.blob_upload.location_name},
+ self.blob_upload.upload_id,
+ self.blob_upload.storage_metadata,
+ )
+ else:
+ # We were the first ones to upload this image (at least to this location)
+ # Let's copy it into place
+ self.storage.complete_chunked_upload(
+ {self.blob_upload.location_name},
+ self.blob_upload.upload_id,
+ final_blob_location,
+ self.blob_upload.storage_metadata,
+ )
- return already_existed
+ return already_existed
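
The upload path above wraps the input stream so every chunk also feeds sha_state.update, letting commit_to_blob compare the computed digest against the expected one. A standalone sketch of that teeing idea, with a simplified stand-in for wrap_with_handler (not the real util.registry helper):

import hashlib
import io

def wrap_with_handler(fp, handler):
    # Simplified stand-in: forward reads while feeding each chunk to the handler.
    class _Wrapped(object):
        def read(self, size=-1):
            data = fp.read(size)
            if data:
                handler(data)
            return data

    return _Wrapped()

sha_state = hashlib.sha256()
source = io.BytesIO(b"layer-bytes")
wrapped = wrap_with_handler(source, sha_state.update)

while wrapped.read(4):  # simulate chunked consumption by the storage engine
    pass

assert sha_state.hexdigest() == hashlib.sha256(b"layer-bytes").hexdigest()
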
diff --git a/data/registry_model/datatype.py b/data/registry_model/datatype.py
index 091776bb1..43b996285 100644
--- a/data/registry_model/datatype.py
+++ b/data/registry_model/datatype.py
@@ -2,85 +2,93 @@
from functools import wraps, total_ordering
+
class FromDictionaryException(Exception):
- """ Exception raised if constructing a data type from a dictionary fails due to
+ """ Exception raised if constructing a data type from a dictionary fails due to
missing data.
"""
+
def datatype(name, static_fields):
- """ Defines a base class for a datatype that will represent a row from the database,
+ """ Defines a base class for a datatype that will represent a row from the database,
in an abstracted form.
"""
- @total_ordering
- class DataType(object):
- __name__ = name
- def __init__(self, **kwargs):
- self._db_id = kwargs.pop('db_id', None)
- self._inputs = kwargs.pop('inputs', None)
- self._fields = kwargs
+ @total_ordering
+ class DataType(object):
+ __name__ = name
- for name in static_fields:
- assert name in self._fields, 'Missing field %s' % name
+ def __init__(self, **kwargs):
+ self._db_id = kwargs.pop("db_id", None)
+ self._inputs = kwargs.pop("inputs", None)
+ self._fields = kwargs
- def __eq__(self, other):
- return self._db_id == other._db_id
+ for name in static_fields:
+ assert name in self._fields, "Missing field %s" % name
- def __lt__(self, other):
- return self._db_id < other._db_id
+ def __eq__(self, other):
+ return self._db_id == other._db_id
- def __getattr__(self, name):
- if name in static_fields:
- return self._fields[name]
+ def __lt__(self, other):
+ return self._db_id < other._db_id
- raise AttributeError('Unknown field `%s`' % name)
+ def __getattr__(self, name):
+ if name in static_fields:
+ return self._fields[name]
- def __repr__(self):
- return '<%s> #%s' % (name, self._db_id)
+ raise AttributeError("Unknown field `%s`" % name)
- @classmethod
- def from_dict(cls, dict_data):
- try:
- return cls(**dict_data)
- except:
- raise FromDictionaryException()
+ def __repr__(self):
+ return "<%s> #%s" % (name, self._db_id)
- def asdict(self):
- dictionary_rep = dict(self._fields)
- assert ('db_id' not in dictionary_rep and
- 'inputs' not in dictionary_rep)
+ @classmethod
+ def from_dict(cls, dict_data):
+ try:
+ return cls(**dict_data)
+ except:
+ raise FromDictionaryException()
- dictionary_rep['db_id'] = self._db_id
- dictionary_rep['inputs'] = self._inputs
- return dictionary_rep
+ def asdict(self):
+ dictionary_rep = dict(self._fields)
+ assert "db_id" not in dictionary_rep and "inputs" not in dictionary_rep
- return DataType
+ dictionary_rep["db_id"] = self._db_id
+ dictionary_rep["inputs"] = self._inputs
+ return dictionary_rep
+
+ return DataType
def requiresinput(input_name):
- """ Marks a property on the data type as requiring an input to be invoked. """
- def inner(func):
- @wraps(func)
- def wrapper(self, *args, **kwargs):
- if self._inputs.get(input_name) is None:
- raise Exception('Cannot invoke function with missing input `%s`' % input_name)
+ """ Marks a property on the data type as requiring an input to be invoked. """
- kwargs[input_name] = self._inputs[input_name]
- result = func(self, *args, **kwargs)
- return result
+ def inner(func):
+ @wraps(func)
+ def wrapper(self, *args, **kwargs):
+ if self._inputs.get(input_name) is None:
+ raise Exception(
+ "Cannot invoke function with missing input `%s`" % input_name
+ )
- return wrapper
- return inner
+ kwargs[input_name] = self._inputs[input_name]
+ result = func(self, *args, **kwargs)
+ return result
+
+ return wrapper
+
+ return inner
def optionalinput(input_name):
- """ Marks a property on the data type as having an input be optional when invoked. """
- def inner(func):
- @wraps(func)
- def wrapper(self, *args, **kwargs):
- kwargs[input_name] = self._inputs.get(input_name)
- result = func(self, *args, **kwargs)
- return result
+ """ Marks a property on the data type as having an input be optional when invoked. """
- return wrapper
- return inner
+ def inner(func):
+ @wraps(func)
+ def wrapper(self, *args, **kwargs):
+ kwargs[input_name] = self._inputs.get(input_name)
+ result = func(self, *args, **kwargs)
+ return result
+
+ return wrapper
+
+ return inner
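
A hedged usage sketch of the datatype()/optionalinput() helpers above; the Widget class and its fields are invented for illustration, and the import path assumes the module location shown in this diff:

from data.registry_model.datatype import datatype, optionalinput

class Widget(datatype("Widget", ["name"])):
    @property
    @optionalinput("owner")
    def owner(self, owner):
        # Falls back to a default when the optional input was not supplied.
        return owner or "<unowned>"

w = Widget(db_id=1, name="gear", inputs=dict(owner="alice"))
assert w.name == "gear"
assert w.owner == "alice"
assert Widget(db_id=2, name="cog", inputs={}).owner == "<unowned>"
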
diff --git a/data/registry_model/datatypes.py b/data/registry_model/datatypes.py
index b732fbefc..94caec410 100644
--- a/data/registry_model/datatypes.py
+++ b/data/registry_model/datatypes.py
@@ -15,490 +15,606 @@ from image.docker.schema2 import DOCKER_SCHEMA2_MANIFESTLIST_CONTENT_TYPE
from util.bytes import Bytes
-class RepositoryReference(datatype('Repository', [])):
- """ RepositoryReference is a reference to a repository, passed to registry interface methods. """
- @classmethod
- def for_repo_obj(cls, repo_obj, namespace_name=None, repo_name=None, is_free_namespace=None,
- state=None):
- if repo_obj is None:
- return None
+class RepositoryReference(datatype("Repository", [])):
+ """ RepositoryReference is a reference to a repository, passed to registry interface methods. """
- return RepositoryReference(db_id=repo_obj.id,
- inputs=dict(
- kind=model.repository.get_repo_kind_name(repo_obj),
- is_public=model.repository.is_repository_public(repo_obj),
- namespace_name=namespace_name,
- repo_name=repo_name,
- is_free_namespace=is_free_namespace,
- state=state
- ))
+ @classmethod
+ def for_repo_obj(
+ cls,
+ repo_obj,
+ namespace_name=None,
+ repo_name=None,
+ is_free_namespace=None,
+ state=None,
+ ):
+ if repo_obj is None:
+ return None
- @classmethod
- def for_id(cls, repo_id, namespace_name=None, repo_name=None, is_free_namespace=None, state=None):
- return RepositoryReference(db_id=repo_id,
- inputs=dict(
- kind=None,
- is_public=None,
- namespace_name=namespace_name,
- repo_name=repo_name,
- is_free_namespace=is_free_namespace,
- state=state
- ))
+ return RepositoryReference(
+ db_id=repo_obj.id,
+ inputs=dict(
+ kind=model.repository.get_repo_kind_name(repo_obj),
+ is_public=model.repository.is_repository_public(repo_obj),
+ namespace_name=namespace_name,
+ repo_name=repo_name,
+ is_free_namespace=is_free_namespace,
+ state=state,
+ ),
+ )
- @property
- @lru_cache(maxsize=1)
- def _repository_obj(self):
- return model.repository.lookup_repository(self._db_id)
+ @classmethod
+ def for_id(
+ cls,
+ repo_id,
+ namespace_name=None,
+ repo_name=None,
+ is_free_namespace=None,
+ state=None,
+ ):
+ return RepositoryReference(
+ db_id=repo_id,
+ inputs=dict(
+ kind=None,
+ is_public=None,
+ namespace_name=namespace_name,
+ repo_name=repo_name,
+ is_free_namespace=is_free_namespace,
+ state=state,
+ ),
+ )
- @property
- @optionalinput('kind')
- def kind(self, kind):
- """ Returns the kind of the repository. """
- return kind or model.repository.get_repo_kind_name(self._repositry_obj)
+ @property
+ @lru_cache(maxsize=1)
+ def _repository_obj(self):
+ return model.repository.lookup_repository(self._db_id)
- @property
- @optionalinput('is_public')
- def is_public(self, is_public):
- """ Returns whether the repository is public. """
- if is_public is not None:
- return is_public
+ @property
+ @optionalinput("kind")
+ def kind(self, kind):
+ """ Returns the kind of the repository. """
+ return kind or model.repository.get_repo_kind_name(self._repository_obj)
- return model.repository.is_repository_public(self._repository_obj)
+ @property
+ @optionalinput("is_public")
+ def is_public(self, is_public):
+ """ Returns whether the repository is public. """
+ if is_public is not None:
+ return is_public
- @property
- def trust_enabled(self):
- """ Returns whether trust is enabled in this repository. """
- repository = self._repository_obj
- if repository is None:
- return None
+ return model.repository.is_repository_public(self._repository_obj)
- return repository.trust_enabled
+ @property
+ def trust_enabled(self):
+ """ Returns whether trust is enabled in this repository. """
+ repository = self._repository_obj
+ if repository is None:
+ return None
- @property
- def id(self):
- """ Returns the database ID of the repository. """
- return self._db_id
+ return repository.trust_enabled
- @property
- @optionalinput('namespace_name')
- def namespace_name(self, namespace_name=None):
- """ Returns the namespace name of this repository.
+ @property
+ def id(self):
+ """ Returns the database ID of the repository. """
+ return self._db_id
+
+ @property
+ @optionalinput("namespace_name")
+ def namespace_name(self, namespace_name=None):
+ """ Returns the namespace name of this repository.
"""
- if namespace_name is not None:
- return namespace_name
+ if namespace_name is not None:
+ return namespace_name
- repository = self._repository_obj
- if repository is None:
- return None
+ repository = self._repository_obj
+ if repository is None:
+ return None
- return repository.namespace_user.username
+ return repository.namespace_user.username
- @property
- @optionalinput('is_free_namespace')
- def is_free_namespace(self, is_free_namespace=None):
- """ Returns whether the namespace of the repository is on a free plan.
+ @property
+ @optionalinput("is_free_namespace")
+ def is_free_namespace(self, is_free_namespace=None):
+ """ Returns whether the namespace of the repository is on a free plan.
"""
- if is_free_namespace is not None:
- return is_free_namespace
+ if is_free_namespace is not None:
+ return is_free_namespace
- repository = self._repository_obj
- if repository is None:
- return None
+ repository = self._repository_obj
+ if repository is None:
+ return None
- return repository.namespace_user.stripe_id is None
+ return repository.namespace_user.stripe_id is None
- @property
- @optionalinput('repo_name')
- def name(self, repo_name=None):
- """ Returns the name of this repository.
+ @property
+ @optionalinput("repo_name")
+ def name(self, repo_name=None):
+ """ Returns the name of this repository.
"""
- if repo_name is not None:
- return repo_name
+ if repo_name is not None:
+ return repo_name
- repository = self._repository_obj
- if repository is None:
- return None
+ repository = self._repository_obj
+ if repository is None:
+ return None
- return repository.name
+ return repository.name
- @property
- @optionalinput('state')
- def state(self, state=None):
- """ Return the state of the Repository. """
- if state is not None:
- return state
+ @property
+ @optionalinput("state")
+ def state(self, state=None):
+ """ Return the state of the Repository. """
+ if state is not None:
+ return state
- repository = self._repository_obj
- if repository is None:
- return None
+ repository = self._repository_obj
+ if repository is None:
+ return None
- return repository.state
+ return repository.state
-class Label(datatype('Label', ['key', 'value', 'uuid', 'source_type_name', 'media_type_name'])):
- """ Label represents a label on a manifest. """
- @classmethod
- def for_label(cls, label):
- if label is None:
- return None
+class Label(
+ datatype("Label", ["key", "value", "uuid", "source_type_name", "media_type_name"])
+):
+ """ Label represents a label on a manifest. """
- return Label(db_id=label.id, key=label.key, value=label.value,
- uuid=label.uuid, media_type_name=label.media_type.name,
- source_type_name=label.source_type.name)
+ @classmethod
+ def for_label(cls, label):
+ if label is None:
+ return None
+
+ return Label(
+ db_id=label.id,
+ key=label.key,
+ value=label.value,
+ uuid=label.uuid,
+ media_type_name=label.media_type.name,
+ source_type_name=label.source_type.name,
+ )
-class ShallowTag(datatype('ShallowTag', ['name'])):
- """ ShallowTag represents a tag in a repository, but only contains basic information. """
- @classmethod
- def for_tag(cls, tag):
- if tag is None:
- return None
+class ShallowTag(datatype("ShallowTag", ["name"])):
+ """ ShallowTag represents a tag in a repository, but only contains basic information. """
- return ShallowTag(db_id=tag.id, name=tag.name)
+ @classmethod
+ def for_tag(cls, tag):
+ if tag is None:
+ return None
- @classmethod
- def for_repository_tag(cls, repository_tag):
- if repository_tag is None:
- return None
+ return ShallowTag(db_id=tag.id, name=tag.name)
- return ShallowTag(db_id=repository_tag.id, name=repository_tag.name)
+ @classmethod
+ def for_repository_tag(cls, repository_tag):
+ if repository_tag is None:
+ return None
- @property
- def id(self):
- """ The ID of this tag for pagination purposes only. """
- return self._db_id
+ return ShallowTag(db_id=repository_tag.id, name=repository_tag.name)
+
+ @property
+ def id(self):
+ """ The ID of this tag for pagination purposes only. """
+ return self._db_id
-class Tag(datatype('Tag', ['name', 'reversion', 'manifest_digest', 'lifetime_start_ts',
- 'lifetime_end_ts', 'lifetime_start_ms', 'lifetime_end_ms'])):
- """ Tag represents a tag in a repository, which points to a manifest or image. """
- @classmethod
- def for_tag(cls, tag, legacy_image=None):
- if tag is None:
- return None
+class Tag(
+ datatype(
+ "Tag",
+ [
+ "name",
+ "reversion",
+ "manifest_digest",
+ "lifetime_start_ts",
+ "lifetime_end_ts",
+ "lifetime_start_ms",
+ "lifetime_end_ms",
+ ],
+ )
+):
+ """ Tag represents a tag in a repository, which points to a manifest or image. """
- return Tag(db_id=tag.id,
- name=tag.name,
- reversion=tag.reversion,
- lifetime_start_ms=tag.lifetime_start_ms,
- lifetime_end_ms=tag.lifetime_end_ms,
- lifetime_start_ts=tag.lifetime_start_ms / 1000,
- lifetime_end_ts=tag.lifetime_end_ms / 1000 if tag.lifetime_end_ms else None,
- manifest_digest=tag.manifest.digest,
- inputs=dict(legacy_image=legacy_image,
- manifest=tag.manifest,
- repository=RepositoryReference.for_id(tag.repository_id)))
+ @classmethod
+ def for_tag(cls, tag, legacy_image=None):
+ if tag is None:
+ return None
- @classmethod
- def for_repository_tag(cls, repository_tag, manifest_digest=None, legacy_image=None):
- if repository_tag is None:
- return None
+ return Tag(
+ db_id=tag.id,
+ name=tag.name,
+ reversion=tag.reversion,
+ lifetime_start_ms=tag.lifetime_start_ms,
+ lifetime_end_ms=tag.lifetime_end_ms,
+ lifetime_start_ts=tag.lifetime_start_ms / 1000,
+ lifetime_end_ts=tag.lifetime_end_ms / 1000 if tag.lifetime_end_ms else None,
+ manifest_digest=tag.manifest.digest,
+ inputs=dict(
+ legacy_image=legacy_image,
+ manifest=tag.manifest,
+ repository=RepositoryReference.for_id(tag.repository_id),
+ ),
+ )
- return Tag(db_id=repository_tag.id,
- name=repository_tag.name,
- reversion=repository_tag.reversion,
- lifetime_start_ts=repository_tag.lifetime_start_ts,
- lifetime_end_ts=repository_tag.lifetime_end_ts,
- lifetime_start_ms=repository_tag.lifetime_start_ts * 1000,
- lifetime_end_ms=(repository_tag.lifetime_end_ts * 1000
- if repository_tag.lifetime_end_ts else None),
- manifest_digest=manifest_digest,
- inputs=dict(legacy_image=legacy_image,
- repository=RepositoryReference.for_id(repository_tag.repository_id)))
+ @classmethod
+ def for_repository_tag(
+ cls, repository_tag, manifest_digest=None, legacy_image=None
+ ):
+ if repository_tag is None:
+ return None
- @property
- @requiresinput('manifest')
- def _manifest(self, manifest):
- """ Returns the manifest for this tag. Will only apply to new-style OCI tags. """
- return manifest
+ return Tag(
+ db_id=repository_tag.id,
+ name=repository_tag.name,
+ reversion=repository_tag.reversion,
+ lifetime_start_ts=repository_tag.lifetime_start_ts,
+ lifetime_end_ts=repository_tag.lifetime_end_ts,
+ lifetime_start_ms=repository_tag.lifetime_start_ts * 1000,
+ lifetime_end_ms=(
+ repository_tag.lifetime_end_ts * 1000
+ if repository_tag.lifetime_end_ts
+ else None
+ ),
+ manifest_digest=manifest_digest,
+ inputs=dict(
+ legacy_image=legacy_image,
+ repository=RepositoryReference.for_id(repository_tag.repository_id),
+ ),
+ )
- @property
- @optionalinput('manifest')
- def manifest(self, manifest):
- """ Returns the manifest for this tag or None if none. Will only apply to new-style OCI tags.
+ @property
+ @requiresinput("manifest")
+ def _manifest(self, manifest):
+ """ Returns the manifest for this tag. Will only apply to new-style OCI tags. """
+ return manifest
+
+ @property
+ @optionalinput("manifest")
+ def manifest(self, manifest):
+ """ Returns the manifest for this tag or None if none. Will only apply to new-style OCI tags.
"""
- return Manifest.for_manifest(manifest, self.legacy_image_if_present)
+ return Manifest.for_manifest(manifest, self.legacy_image_if_present)
- @property
- @requiresinput('repository')
- def repository(self, repository):
- """ Returns the repository under which this tag lives.
+ @property
+ @requiresinput("repository")
+ def repository(self, repository):
+ """ Returns the repository under which this tag lives.
"""
- return repository
+ return repository
- @property
- @requiresinput('legacy_image')
- def legacy_image(self, legacy_image):
- """ Returns the legacy Docker V1-style image for this tag. Note that this
+ @property
+ @requiresinput("legacy_image")
+ def legacy_image(self, legacy_image):
+ """ Returns the legacy Docker V1-style image for this tag. Note that this
will be None for tags whose manifests point to other manifests instead of images.
"""
- return legacy_image
+ return legacy_image
- @property
- @optionalinput('legacy_image')
- def legacy_image_if_present(self, legacy_image):
- """ Returns the legacy Docker V1-style image for this tag. Note that this
+ @property
+ @optionalinput("legacy_image")
+ def legacy_image_if_present(self, legacy_image):
+ """ Returns the legacy Docker V1-style image for this tag. Note that this
will be None for tags whose manifests point to other manifests instead of images.
"""
- return legacy_image
+ return legacy_image
- @property
- def id(self):
- """ The ID of this tag for pagination purposes only. """
- return self._db_id
+ @property
+ def id(self):
+ """ The ID of this tag for pagination purposes only. """
+ return self._db_id
-class Manifest(datatype('Manifest', ['digest', 'media_type', 'internal_manifest_bytes'])):
- """ Manifest represents a manifest in a repository. """
- @classmethod
- def for_tag_manifest(cls, tag_manifest, legacy_image=None):
- if tag_manifest is None:
- return None
+class Manifest(
+ datatype("Manifest", ["digest", "media_type", "internal_manifest_bytes"])
+):
+ """ Manifest represents a manifest in a repository. """
- return Manifest(db_id=tag_manifest.id, digest=tag_manifest.digest,
- internal_manifest_bytes=Bytes.for_string_or_unicode(tag_manifest.json_data),
- media_type=DOCKER_SCHEMA1_SIGNED_MANIFEST_CONTENT_TYPE, # Always in legacy.
- inputs=dict(legacy_image=legacy_image, tag_manifest=True))
+ @classmethod
+ def for_tag_manifest(cls, tag_manifest, legacy_image=None):
+ if tag_manifest is None:
+ return None
- @classmethod
- def for_manifest(cls, manifest, legacy_image):
- if manifest is None:
- return None
+ return Manifest(
+ db_id=tag_manifest.id,
+ digest=tag_manifest.digest,
+ internal_manifest_bytes=Bytes.for_string_or_unicode(tag_manifest.json_data),
+ media_type=DOCKER_SCHEMA1_SIGNED_MANIFEST_CONTENT_TYPE, # Always in legacy.
+ inputs=dict(legacy_image=legacy_image, tag_manifest=True),
+ )
- # NOTE: `manifest_bytes` will be None if not selected by certain join queries.
- manifest_bytes = (Bytes.for_string_or_unicode(manifest.manifest_bytes)
- if manifest.manifest_bytes is not None else None)
- return Manifest(db_id=manifest.id,
- digest=manifest.digest,
- internal_manifest_bytes=manifest_bytes,
- media_type=ManifestTable.media_type.get_name(manifest.media_type_id),
- inputs=dict(legacy_image=legacy_image, tag_manifest=False))
+ @classmethod
+ def for_manifest(cls, manifest, legacy_image):
+ if manifest is None:
+ return None
- @property
- @requiresinput('tag_manifest')
- def _is_tag_manifest(self, tag_manifest):
- return tag_manifest
+ # NOTE: `manifest_bytes` will be None if not selected by certain join queries.
+ manifest_bytes = (
+ Bytes.for_string_or_unicode(manifest.manifest_bytes)
+ if manifest.manifest_bytes is not None
+ else None
+ )
+ return Manifest(
+ db_id=manifest.id,
+ digest=manifest.digest,
+ internal_manifest_bytes=manifest_bytes,
+ media_type=ManifestTable.media_type.get_name(manifest.media_type_id),
+ inputs=dict(legacy_image=legacy_image, tag_manifest=False),
+ )
- @property
- @requiresinput('legacy_image')
- def legacy_image(self, legacy_image):
- """ Returns the legacy Docker V1-style image for this manifest.
+ @property
+ @requiresinput("tag_manifest")
+ def _is_tag_manifest(self, tag_manifest):
+ return tag_manifest
+
+ @property
+ @requiresinput("legacy_image")
+ def legacy_image(self, legacy_image):
+ """ Returns the legacy Docker V1-style image for this manifest.
"""
- return legacy_image
+ return legacy_image
- @property
- @optionalinput('legacy_image')
- def legacy_image_if_present(self, legacy_image):
- """ Returns the legacy Docker V1-style image for this manifest. Note that this
+ @property
+ @optionalinput("legacy_image")
+ def legacy_image_if_present(self, legacy_image):
+ """ Returns the legacy Docker V1-style image for this manifest. Note that this
will be None for manifests that point to other manifests instead of images.
"""
- return legacy_image
+ return legacy_image
- def get_parsed_manifest(self, validate=True):
- """ Returns the parsed manifest for this manifest. """
- assert self.internal_manifest_bytes
- return parse_manifest_from_bytes(self.internal_manifest_bytes, self.media_type,
- validate=validate)
+ def get_parsed_manifest(self, validate=True):
+ """ Returns the parsed manifest for this manifest. """
+ assert self.internal_manifest_bytes
+ return parse_manifest_from_bytes(
+ self.internal_manifest_bytes, self.media_type, validate=validate
+ )
- @property
- def layers_compressed_size(self):
- """ Returns the total compressed size of the layers in the manifest or None if this could not
+ @property
+ def layers_compressed_size(self):
+ """ Returns the total compressed size of the layers in the manifest or None if this could not
be computed.
"""
- try:
- return self.get_parsed_manifest().layers_compressed_size
- except ManifestException:
- return None
+ try:
+ return self.get_parsed_manifest().layers_compressed_size
+ except ManifestException:
+ return None
- @property
- def is_manifest_list(self):
- """ Returns True if this manifest points to a list (instead of an image). """
- return self.media_type == DOCKER_SCHEMA2_MANIFESTLIST_CONTENT_TYPE
+ @property
+ def is_manifest_list(self):
+ """ Returns True if this manifest points to a list (instead of an image). """
+ return self.media_type == DOCKER_SCHEMA2_MANIFESTLIST_CONTENT_TYPE
-class LegacyImage(datatype('LegacyImage', ['docker_image_id', 'created', 'comment', 'command',
- 'image_size', 'aggregate_size', 'uploading',
- 'v1_metadata_string'])):
- """ LegacyImage represents a Docker V1-style image found in a repository. """
- @classmethod
- def for_image(cls, image, images_map=None, tags_map=None, blob=None):
- if image is None:
- return None
+class LegacyImage(
+ datatype(
+ "LegacyImage",
+ [
+ "docker_image_id",
+ "created",
+ "comment",
+ "command",
+ "image_size",
+ "aggregate_size",
+ "uploading",
+ "v1_metadata_string",
+ ],
+ )
+):
+ """ LegacyImage represents a Docker V1-style image found in a repository. """
- return LegacyImage(db_id=image.id,
- inputs=dict(images_map=images_map, tags_map=tags_map,
- ancestor_id_list=image.ancestor_id_list(),
- blob=blob),
- docker_image_id=image.docker_image_id,
- created=image.created,
- comment=image.comment,
- command=image.command,
- v1_metadata_string=image.v1_json_metadata,
- image_size=image.storage.image_size,
- aggregate_size=image.aggregate_size,
- uploading=image.storage.uploading)
+ @classmethod
+ def for_image(cls, image, images_map=None, tags_map=None, blob=None):
+ if image is None:
+ return None
- @property
- def id(self):
- """ Returns the database ID of the legacy image. """
- return self._db_id
+ return LegacyImage(
+ db_id=image.id,
+ inputs=dict(
+ images_map=images_map,
+ tags_map=tags_map,
+ ancestor_id_list=image.ancestor_id_list(),
+ blob=blob,
+ ),
+ docker_image_id=image.docker_image_id,
+ created=image.created,
+ comment=image.comment,
+ command=image.command,
+ v1_metadata_string=image.v1_json_metadata,
+ image_size=image.storage.image_size,
+ aggregate_size=image.aggregate_size,
+ uploading=image.storage.uploading,
+ )
- @property
- @requiresinput('images_map')
- @requiresinput('ancestor_id_list')
- def parents(self, images_map, ancestor_id_list):
- """ Returns the parent images for this image. Raises an exception if the parents have
+ @property
+ def id(self):
+ """ Returns the database ID of the legacy image. """
+ return self._db_id
+
+ @property
+ @requiresinput("images_map")
+ @requiresinput("ancestor_id_list")
+ def parents(self, images_map, ancestor_id_list):
+ """ Returns the parent images for this image. Raises an exception if the parents have
not been loaded before this property is invoked. Parents are returned starting at the
leaf image.
"""
- return [LegacyImage.for_image(images_map[ancestor_id], images_map=images_map)
+ return [
+ LegacyImage.for_image(images_map[ancestor_id], images_map=images_map)
for ancestor_id in reversed(ancestor_id_list)
- if images_map.get(ancestor_id)]
+ if images_map.get(ancestor_id)
+ ]
- @property
- @requiresinput('blob')
- def blob(self, blob):
- """ Returns the blob for this image. Raises an exception if the blob has
+ @property
+ @requiresinput("blob")
+ def blob(self, blob):
+ """ Returns the blob for this image. Raises an exception if the blob has
not been loaded before this property is invoked.
"""
- return blob
+ return blob
- @property
- @requiresinput('tags_map')
- def tags(self, tags_map):
- """ Returns the tags pointing to this image. Raises an exception if the tags have
+ @property
+ @requiresinput("tags_map")
+ def tags(self, tags_map):
+ """ Returns the tags pointing to this image. Raises an exception if the tags have
not been loaded before this property is invoked.
"""
- tags = tags_map.get(self._db_id)
- if not tags:
- return []
+ tags = tags_map.get(self._db_id)
+ if not tags:
+ return []
- return [Tag.for_repository_tag(tag) for tag in tags]
+ return [Tag.for_repository_tag(tag) for tag in tags]
@unique
class SecurityScanStatus(Enum):
- """ Security scan status enum """
- SCANNED = 'scanned'
- FAILED = 'failed'
- QUEUED = 'queued'
- UNSUPPORTED = 'unsupported'
+ """ Security scan status enum """
+
+ SCANNED = "scanned"
+ FAILED = "failed"
+ QUEUED = "queued"
+ UNSUPPORTED = "unsupported"
-class ManifestLayer(namedtuple('ManifestLayer', ['layer_info', 'blob'])):
- """ Represents a single layer in a manifest. The `layer_info` data will be manifest-type specific,
+class ManifestLayer(namedtuple("ManifestLayer", ["layer_info", "blob"])):
+ """ Represents a single layer in a manifest. The `layer_info` data will be manifest-type specific,
but will have a few expected fields (such as `digest`). The `blob` represents the associated
blob for this layer, optionally with placements. If the layer is a remote layer, the blob will
be None.
"""
- def estimated_size(self, estimate_multiplier):
- """ Returns the estimated size of this layer. If the layers' blob has an uncompressed size,
+ def estimated_size(self, estimate_multiplier):
+ """ Returns the estimated size of this layer. If the layers' blob has an uncompressed size,
it is used. Otherwise, the compressed_size field in the layer is multiplied by the
multiplier.
"""
- if self.blob.uncompressed_size:
- return self.blob.uncompressed_size
+ if self.blob.uncompressed_size:
+ return self.blob.uncompressed_size
- return (self.layer_info.compressed_size or 0) * estimate_multiplier
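+ # Worked example (sketch): a layer whose blob has no recorded uncompressed size but whose
+ # layer_info.compressed_size is 10 MB is reported as 10 MB * estimate_multiplier.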
+ return (self.layer_info.compressed_size or 0) * estimate_multiplier
-class Blob(datatype('Blob', ['uuid', 'digest', 'compressed_size', 'uncompressed_size',
- 'uploading'])):
- """ Blob represents a content-addressable piece of storage. """
- @classmethod
- def for_image_storage(cls, image_storage, storage_path, placements=None):
- if image_storage is None:
- return None
+class Blob(
+ datatype(
+ "Blob", ["uuid", "digest", "compressed_size", "uncompressed_size", "uploading"]
+ )
+):
+ """ Blob represents a content-addressable piece of storage. """
- return Blob(db_id=image_storage.id,
- uuid=image_storage.uuid,
- inputs=dict(placements=placements, storage_path=storage_path),
- digest=image_storage.content_checksum,
- compressed_size=image_storage.image_size,
- uncompressed_size=image_storage.uncompressed_size,
- uploading=image_storage.uploading)
+ @classmethod
+ def for_image_storage(cls, image_storage, storage_path, placements=None):
+ if image_storage is None:
+ return None
- @property
- @requiresinput('storage_path')
- def storage_path(self, storage_path):
- """ Returns the path of this blob in storage. """
- # TODO: change this to take in the storage engine?
- return storage_path
+ return Blob(
+ db_id=image_storage.id,
+ uuid=image_storage.uuid,
+ inputs=dict(placements=placements, storage_path=storage_path),
+ digest=image_storage.content_checksum,
+ compressed_size=image_storage.image_size,
+ uncompressed_size=image_storage.uncompressed_size,
+ uploading=image_storage.uploading,
+ )
- @property
- @requiresinput('placements')
- def placements(self, placements):
- """ Returns all the storage placements at which the Blob can be found. """
- return placements
+ @property
+ @requiresinput("storage_path")
+ def storage_path(self, storage_path):
+ """ Returns the path of this blob in storage. """
+ # TODO: change this to take in the storage engine?
+ return storage_path
+
+ @property
+ @requiresinput("placements")
+ def placements(self, placements):
+ """ Returns all the storage placements at which the Blob can be found. """
+ return placements
-class DerivedImage(datatype('DerivedImage', ['verb', 'varying_metadata', 'blob'])):
- """ DerivedImage represents an image derived from a manifest via some form of verb. """
- @classmethod
- def for_derived_storage(cls, derived, verb, varying_metadata, blob):
- return DerivedImage(db_id=derived.id,
- verb=verb,
- varying_metadata=varying_metadata,
- blob=blob)
+class DerivedImage(datatype("DerivedImage", ["verb", "varying_metadata", "blob"])):
+ """ DerivedImage represents an image derived from a manifest via some form of verb. """
- @property
- def unique_id(self):
- """ Returns a unique ID for this derived image. This call will consistently produce the same
+ @classmethod
+ def for_derived_storage(cls, derived, verb, varying_metadata, blob):
+ return DerivedImage(
+ db_id=derived.id, verb=verb, varying_metadata=varying_metadata, blob=blob
+ )
+
+ @property
+ def unique_id(self):
+ """ Returns a unique ID for this derived image. This call will consistently produce the same
unique ID across calls in the same code base.
"""
- return hashlib.sha256('%s:%s' % (self.verb, self._db_id)).hexdigest()
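+ # The ID is a pure function of (verb, database id) -- e.g. sha256("squash:42") -- so repeated
+ # calls for the same derived image always produce the same value.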
+ return hashlib.sha256("%s:%s" % (self.verb, self._db_id)).hexdigest()
-class TorrentInfo(datatype('TorrentInfo', ['pieces', 'piece_length'])):
- """ TorrentInfo represents information to pull a blob via torrent. """
- @classmethod
- def for_torrent_info(cls, torrent_info):
- return TorrentInfo(db_id=torrent_info.id,
- pieces=torrent_info.pieces,
- piece_length=torrent_info.piece_length)
+class TorrentInfo(datatype("TorrentInfo", ["pieces", "piece_length"])):
+ """ TorrentInfo represents information to pull a blob via torrent. """
+
+ @classmethod
+ def for_torrent_info(cls, torrent_info):
+ return TorrentInfo(
+ db_id=torrent_info.id,
+ pieces=torrent_info.pieces,
+ piece_length=torrent_info.piece_length,
+ )
-class BlobUpload(datatype('BlobUpload', ['upload_id', 'byte_count', 'uncompressed_byte_count',
- 'chunk_count', 'sha_state', 'location_name',
- 'storage_metadata', 'piece_sha_state', 'piece_hashes'])):
- """ BlobUpload represents information about an in-progress upload to create a blob. """
- @classmethod
- def for_upload(cls, blob_upload, location_name=None):
- return BlobUpload(db_id=blob_upload.id,
- upload_id=blob_upload.uuid,
- byte_count=blob_upload.byte_count,
- uncompressed_byte_count=blob_upload.uncompressed_byte_count,
- chunk_count=blob_upload.chunk_count,
- sha_state=blob_upload.sha_state,
- location_name=location_name or blob_upload.location.name,
- storage_metadata=blob_upload.storage_metadata,
- piece_sha_state=blob_upload.piece_sha_state,
- piece_hashes=blob_upload.piece_hashes)
+class BlobUpload(
+ datatype(
+ "BlobUpload",
+ [
+ "upload_id",
+ "byte_count",
+ "uncompressed_byte_count",
+ "chunk_count",
+ "sha_state",
+ "location_name",
+ "storage_metadata",
+ "piece_sha_state",
+ "piece_hashes",
+ ],
+ )
+):
+ """ BlobUpload represents information about an in-progress upload to create a blob. """
+
+ @classmethod
+ def for_upload(cls, blob_upload, location_name=None):
+ return BlobUpload(
+ db_id=blob_upload.id,
+ upload_id=blob_upload.uuid,
+ byte_count=blob_upload.byte_count,
+ uncompressed_byte_count=blob_upload.uncompressed_byte_count,
+ chunk_count=blob_upload.chunk_count,
+ sha_state=blob_upload.sha_state,
+ location_name=location_name or blob_upload.location.name,
+ storage_metadata=blob_upload.storage_metadata,
+ piece_sha_state=blob_upload.piece_sha_state,
+ piece_hashes=blob_upload.piece_hashes,
+ )
-class LikelyVulnerableTag(datatype('LikelyVulnerableTag', ['layer_id', 'name'])):
- """ LikelyVulnerableTag represents a tag in a repository that is likely vulnerable to a notified
+class LikelyVulnerableTag(datatype("LikelyVulnerableTag", ["layer_id", "name"])):
+ """ LikelyVulnerableTag represents a tag in a repository that is likely vulnerable to a notified
vulnerability.
"""
- # TODO: Remove all of this once we're on the new security model exclusively.
- @classmethod
- def for_tag(cls, tag, repository, docker_image_id, storage_uuid):
- layer_id = '%s.%s' % (docker_image_id, storage_uuid)
- return LikelyVulnerableTag(db_id=tag.id,
- name=tag.name,
- layer_id=layer_id,
- inputs=dict(repository=repository))
- @classmethod
- def for_repository_tag(cls, tag, repository):
- tag_layer_id = '%s.%s' % (tag.image.docker_image_id, tag.image.storage.uuid)
- return LikelyVulnerableTag(db_id=tag.id,
- name=tag.name,
- layer_id=tag_layer_id,
- inputs=dict(repository=repository))
+ # TODO: Remove all of this once we're on the new security model exclusively.
+ @classmethod
+ def for_tag(cls, tag, repository, docker_image_id, storage_uuid):
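+ # The layer ID is the Docker V1 image ID joined with the storage UUID, matching the
+ # layer ID pairs used when yielding tags for vulnerability notifications.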
+ layer_id = "%s.%s" % (docker_image_id, storage_uuid)
+ return LikelyVulnerableTag(
+ db_id=tag.id,
+ name=tag.name,
+ layer_id=layer_id,
+ inputs=dict(repository=repository),
+ )
- @property
- @requiresinput('repository')
- def repository(self, repository):
- return RepositoryReference.for_repo_obj(repository)
+ @classmethod
+ def for_repository_tag(cls, tag, repository):
+ tag_layer_id = "%s.%s" % (tag.image.docker_image_id, tag.image.storage.uuid)
+ return LikelyVulnerableTag(
+ db_id=tag.id,
+ name=tag.name,
+ layer_id=tag_layer_id,
+ inputs=dict(repository=repository),
+ )
+
+ @property
+ @requiresinput("repository")
+ def repository(self, repository):
+ return RepositoryReference.for_repo_obj(repository)
diff --git a/data/registry_model/interface.py b/data/registry_model/interface.py
index 8862f88bc..505cd3899 100644
--- a/data/registry_model/interface.py
+++ b/data/registry_model/interface.py
@@ -1,64 +1,80 @@
from abc import ABCMeta, abstractmethod
from six import add_metaclass
+
@add_metaclass(ABCMeta)
class RegistryDataInterface(object):
- """ Interface for code to work with the registry data model. The registry data model consists
+ """ Interface for code to work with the registry data model. The registry data model consists
of all tables that store registry-specific information, such as Manifests, Blobs, Images,
and Labels.
"""
- @abstractmethod
- def supports_schema2(self, namespace_name):
- """ Returns whether the implementation of the data interface supports schema 2 format
+
+ @abstractmethod
+ def supports_schema2(self, namespace_name):
+ """ Returns whether the implementation of the data interface supports schema 2 format
manifests. """
- @abstractmethod
- def get_tag_legacy_image_id(self, repository_ref, tag_name, storage):
- """ Returns the legacy image ID for the tag with a legacy images in
+ @abstractmethod
+ def get_tag_legacy_image_id(self, repository_ref, tag_name, storage):
+ """ Returns the legacy image ID for the tag with a legacy images in
the repository. Returns None if None.
"""
- @abstractmethod
- def get_legacy_tags_map(self, repository_ref, storage):
- """ Returns a map from tag name to its legacy image ID, for all tags with legacy images in
+ @abstractmethod
+ def get_legacy_tags_map(self, repository_ref, storage):
+ """ Returns a map from tag name to its legacy image ID, for all tags with legacy images in
the repository. Note that this can be a *very* heavy operation.
"""
- @abstractmethod
- def find_matching_tag(self, repository_ref, tag_names):
- """ Finds an alive tag in the repository matching one of the given tag names and returns it
+ @abstractmethod
+ def find_matching_tag(self, repository_ref, tag_names):
+ """ Finds an alive tag in the repository matching one of the given tag names and returns it
or None if none.
"""
- @abstractmethod
- def get_most_recent_tag(self, repository_ref):
- """ Returns the most recently pushed alive tag in the repository, if any. If none, returns
+ @abstractmethod
+ def get_most_recent_tag(self, repository_ref):
+ """ Returns the most recently pushed alive tag in the repository, if any. If none, returns
None.
"""
- @abstractmethod
- def lookup_repository(self, namespace_name, repo_name, kind_filter=None):
- """ Looks up and returns a reference to the repository with the given namespace and name,
+ @abstractmethod
+ def lookup_repository(self, namespace_name, repo_name, kind_filter=None):
+ """ Looks up and returns a reference to the repository with the given namespace and name,
or None if none. """
- @abstractmethod
- def get_manifest_for_tag(self, tag, backfill_if_necessary=False, include_legacy_image=False):
- """ Returns the manifest associated with the given tag. """
+ @abstractmethod
+ def get_manifest_for_tag(
+ self, tag, backfill_if_necessary=False, include_legacy_image=False
+ ):
+ """ Returns the manifest associated with the given tag. """
- @abstractmethod
- def lookup_manifest_by_digest(self, repository_ref, manifest_digest, allow_dead=False,
- include_legacy_image=False, require_available=False):
- """ Looks up the manifest with the given digest under the given repository and returns it
+ @abstractmethod
+ def lookup_manifest_by_digest(
+ self,
+ repository_ref,
+ manifest_digest,
+ allow_dead=False,
+ include_legacy_image=False,
+ require_available=False,
+ ):
+ """ Looks up the manifest with the given digest under the given repository and returns it
or None if none. If allow_dead is True, manifests pointed to by dead tags will also
be returned. If require_available is True, a temporary tag will be added onto the
returned manifest (before it is returned) to ensure it is available until another
tagging or manifest operation is taken.
"""
- @abstractmethod
- def create_manifest_and_retarget_tag(self, repository_ref, manifest_interface_instance, tag_name,
- storage, raise_on_error=False):
- """ Creates a manifest in a repository, adding all of the necessary data in the model.
+ @abstractmethod
+ def create_manifest_and_retarget_tag(
+ self,
+ repository_ref,
+ manifest_interface_instance,
+ tag_name,
+ storage,
+ raise_on_error=False,
+ ):
+ """ Creates a manifest in a repository, adding all of the necessary data in the model.
The `manifest_interface_instance` parameter must be an instance of the manifest
interface as returned by the image/docker package.
@@ -69,275 +85,317 @@ class RegistryDataInterface(object):
Returns a reference to the (created manifest, tag) or (None, None) on error.
"""
- @abstractmethod
- def get_legacy_images(self, repository_ref):
- """
+ @abstractmethod
+ def get_legacy_images(self, repository_ref):
+ """
Returns an iterator of all the LegacyImage's defined in the matching repository.
"""
- @abstractmethod
- def get_legacy_image(self, repository_ref, docker_image_id, include_parents=False,
- include_blob=False):
- """
+ @abstractmethod
+ def get_legacy_image(
+ self, repository_ref, docker_image_id, include_parents=False, include_blob=False
+ ):
+ """
Returns the matching LegacyImage under the matching repository, if any. If none,
returns None.
"""
- @abstractmethod
- def create_manifest_label(self, manifest, key, value, source_type_name, media_type_name=None):
- """ Creates a label on the manifest with the given key and value.
+ @abstractmethod
+ def create_manifest_label(
+ self, manifest, key, value, source_type_name, media_type_name=None
+ ):
+ """ Creates a label on the manifest with the given key and value.
Can raise InvalidLabelKeyException or InvalidMediaTypeException depending
on the validation errors.
"""
- @abstractmethod
- def batch_create_manifest_labels(self, manifest):
- """ Returns a context manager for batch creation of labels on a manifest.
+ @abstractmethod
+ def batch_create_manifest_labels(self, manifest):
+ """ Returns a context manager for batch creation of labels on a manifest.
Can raise InvalidLabelKeyException or InvalidMediaTypeException depending
on the validation errors.
"""
- @abstractmethod
- def list_manifest_labels(self, manifest, key_prefix=None):
- """ Returns all labels found on the manifest. If specified, the key_prefix will filter the
+ @abstractmethod
+ def list_manifest_labels(self, manifest, key_prefix=None):
+ """ Returns all labels found on the manifest. If specified, the key_prefix will filter the
labels returned to those keys that start with the given prefix.
"""
- @abstractmethod
- def get_manifest_label(self, manifest, label_uuid):
- """ Returns the label with the specified UUID on the manifest or None if none. """
+ @abstractmethod
+ def get_manifest_label(self, manifest, label_uuid):
+ """ Returns the label with the specified UUID on the manifest or None if none. """
- @abstractmethod
- def delete_manifest_label(self, manifest, label_uuid):
- """ Delete the label with the specified UUID on the manifest. Returns the label deleted
+ @abstractmethod
+ def delete_manifest_label(self, manifest, label_uuid):
+ """ Delete the label with the specified UUID on the manifest. Returns the label deleted
or None if none.
"""
- @abstractmethod
- def lookup_cached_active_repository_tags(self, model_cache, repository_ref, start_pagination_id,
- limit):
- """
+ @abstractmethod
+ def lookup_cached_active_repository_tags(
+ self, model_cache, repository_ref, start_pagination_id, limit
+ ):
+ """
Returns a page of active tags in a repository. Note that the tags returned by this method
are ShallowTag objects, which only contain the tag name. This method will automatically cache
the result and check the cache before making a call.
"""
- @abstractmethod
- def lookup_active_repository_tags(self, repository_ref, start_pagination_id, limit):
- """
+ @abstractmethod
+ def lookup_active_repository_tags(self, repository_ref, start_pagination_id, limit):
+ """
Returns a page of active tags in a repository. Note that the tags returned by this method
are ShallowTag objects, which only contain the tag name.
"""
- @abstractmethod
- def list_all_active_repository_tags(self, repository_ref, include_legacy_images=False):
- """
+ @abstractmethod
+ def list_all_active_repository_tags(
+ self, repository_ref, include_legacy_images=False
+ ):
+ """
Returns a list of all the active tags in the repository. Note that this is a *HEAVY*
operation on repositories with a lot of tags, and should only be used for testing or
where other more specific operations are not possible.
"""
- @abstractmethod
- def list_repository_tag_history(self, repository_ref, page=1, size=100, specific_tag_name=None,
- active_tags_only=False, since_time_ms=None):
- """
+ @abstractmethod
+ def list_repository_tag_history(
+ self,
+ repository_ref,
+ page=1,
+ size=100,
+ specific_tag_name=None,
+ active_tags_only=False,
+ since_time_ms=None,
+ ):
+ """
Returns the history of all tags in the repository (unless filtered). This includes tags that
have been made inactive due to newer versions of those tags coming into service.
"""
- @abstractmethod
- def get_most_recent_tag_lifetime_start(self, repository_refs):
- """
+ @abstractmethod
+ def get_most_recent_tag_lifetime_start(self, repository_refs):
+ """
Returns a map from repository ID to the last modified time (seconds from epoch, UTC)
for each repository in the given repository reference list.
"""
- @abstractmethod
- def get_repo_tag(self, repository_ref, tag_name, include_legacy_image=False):
- """
+ @abstractmethod
+ def get_repo_tag(self, repository_ref, tag_name, include_legacy_image=False):
+ """
Returns the latest, *active* tag found in the repository, with the matching name
or None if none.
"""
- @abstractmethod
- def has_expired_tag(self, repository_ref, tag_name):
- """
+ @abstractmethod
+ def has_expired_tag(self, repository_ref, tag_name):
+ """
Returns true if and only if the repository contains a tag with the given name that is expired.
"""
- @abstractmethod
- def retarget_tag(self, repository_ref, tag_name, manifest_or_legacy_image,
- storage, legacy_manifest_key, is_reversion=False):
- """
+ @abstractmethod
+ def retarget_tag(
+ self,
+ repository_ref,
+ tag_name,
+ manifest_or_legacy_image,
+ storage,
+ legacy_manifest_key,
+ is_reversion=False,
+ ):
+ """
Creates, updates or moves a tag to a new entry in history, pointing to the manifest or
legacy image specified. If is_reversion is set to True, this operation is considered a
reversion over a previous tag move operation. Returns the updated Tag or None on error.
"""
- @abstractmethod
- def delete_tag(self, repository_ref, tag_name):
- """
+ @abstractmethod
+ def delete_tag(self, repository_ref, tag_name):
+ """
Deletes the latest, *active* tag with the given name in the repository.
"""
- @abstractmethod
- def delete_tags_for_manifest(self, manifest):
- """
+ @abstractmethod
+ def delete_tags_for_manifest(self, manifest):
+ """
Deletes all tags pointing to the given manifest, making the manifest inaccessible for pulling.
Returns the tags deleted, if any. Returns None on error.
"""
- @abstractmethod
- def change_repository_tag_expiration(self, tag, expiration_date):
- """ Sets the expiration date of the tag under the matching repository to that given. If the
+ @abstractmethod
+ def change_repository_tag_expiration(self, tag, expiration_date):
+ """ Sets the expiration date of the tag under the matching repository to that given. If the
expiration date is None, then the tag will not expire. Returns a tuple of the previous
expiration timestamp in seconds (if any), and whether the operation succeeded.
"""
- @abstractmethod
- def get_legacy_images_owned_by_tag(self, tag):
- """ Returns all legacy images *solely owned and used* by the given tag. """
+ @abstractmethod
+ def get_legacy_images_owned_by_tag(self, tag):
+ """ Returns all legacy images *solely owned and used* by the given tag. """
- @abstractmethod
- def get_security_status(self, manifest_or_legacy_image):
- """ Returns the security status for the given manifest or legacy image or None if none. """
+ @abstractmethod
+ def get_security_status(self, manifest_or_legacy_image):
+ """ Returns the security status for the given manifest or legacy image or None if none. """
- @abstractmethod
- def reset_security_status(self, manifest_or_legacy_image):
- """ Resets the security status for the given manifest or legacy image, ensuring that it will
+ @abstractmethod
+ def reset_security_status(self, manifest_or_legacy_image):
+ """ Resets the security status for the given manifest or legacy image, ensuring that it will
get re-indexed.
"""
- @abstractmethod
- def backfill_manifest_for_tag(self, tag):
- """ Backfills a manifest for the V1 tag specified.
+ @abstractmethod
+ def backfill_manifest_for_tag(self, tag):
+ """ Backfills a manifest for the V1 tag specified.
If a manifest already exists for the tag, returns that manifest.
NOTE: This method will only be necessary until we've completed the backfill, at which point
it should be removed.
"""
- @abstractmethod
- def is_existing_disabled_namespace(self, namespace_name):
- """ Returns whether the given namespace exists and is disabled. """
+ @abstractmethod
+ def is_existing_disabled_namespace(self, namespace_name):
+ """ Returns whether the given namespace exists and is disabled. """
- @abstractmethod
- def is_namespace_enabled(self, namespace_name):
- """ Returns whether the given namespace exists and is enabled. """
+ @abstractmethod
+ def is_namespace_enabled(self, namespace_name):
+ """ Returns whether the given namespace exists and is enabled. """
- @abstractmethod
- def get_manifest_local_blobs(self, manifest, include_placements=False):
- """ Returns the set of local blobs for the given manifest or None if none. """
+ @abstractmethod
+ def get_manifest_local_blobs(self, manifest, include_placements=False):
+ """ Returns the set of local blobs for the given manifest or None if none. """
- @abstractmethod
- def list_manifest_layers(self, manifest, storage, include_placements=False):
- """ Returns an *ordered list* of the layers found in the manifest, starting at the base
+ @abstractmethod
+ def list_manifest_layers(self, manifest, storage, include_placements=False):
+ """ Returns an *ordered list* of the layers found in the manifest, starting at the base
and working towards the leaf, including the associated Blob and its placements
(if specified). The layer information in `layer_info` will be of type
`image.docker.types.ManifestImageLayer`. Should not be called for a manifest list.
"""
- @abstractmethod
- def list_parsed_manifest_layers(self, repository_ref, parsed_manifest, storage,
- include_placements=False):
- """ Returns an *ordered list* of the layers found in the parsed manifest, starting at the base
+ @abstractmethod
+ def list_parsed_manifest_layers(
+ self, repository_ref, parsed_manifest, storage, include_placements=False
+ ):
+ """ Returns an *ordered list* of the layers found in the parsed manifest, starting at the base
and working towards the leaf, including the associated Blob and its placements
(if specified). The layer information in `layer_info` will be of type
`image.docker.types.ManifestImageLayer`. Should not be called for a manifest list.
"""
- @abstractmethod
- def lookup_derived_image(self, manifest, verb, storage, varying_metadata=None,
- include_placements=False):
- """
+ @abstractmethod
+ def lookup_derived_image(
+ self, manifest, verb, storage, varying_metadata=None, include_placements=False
+ ):
+ """
Looks up the derived image for the given manifest, verb and optional varying metadata and
returns it or None if none.
"""
- @abstractmethod
- def lookup_or_create_derived_image(self, manifest, verb, storage_location, storage,
- varying_metadata=None, include_placements=False):
- """
+ @abstractmethod
+ def lookup_or_create_derived_image(
+ self,
+ manifest,
+ verb,
+ storage_location,
+ storage,
+ varying_metadata=None,
+ include_placements=False,
+ ):
+ """
Looks up the derived image for the given maniest, verb and optional varying metadata
and returns it. If none exists, a new derived image is created.
"""
- @abstractmethod
- def get_derived_image_signature(self, derived_image, signer_name):
- """
+ @abstractmethod
+ def get_derived_image_signature(self, derived_image, signer_name):
+ """
Returns the signature associated with the derived image and a specific signer or None if none.
"""
- @abstractmethod
- def set_derived_image_signature(self, derived_image, signer_name, signature):
- """
+ @abstractmethod
+ def set_derived_image_signature(self, derived_image, signer_name, signature):
+ """
Sets the calculated signature for the given derived image and signer to that specified.
"""
- @abstractmethod
- def delete_derived_image(self, derived_image):
- """
+ @abstractmethod
+ def delete_derived_image(self, derived_image):
+ """
Deletes a derived image and all of its storage.
"""
- @abstractmethod
- def set_derived_image_size(self, derived_image, compressed_size):
- """
+ @abstractmethod
+ def set_derived_image_size(self, derived_image, compressed_size):
+ """
Sets the compressed size on the given derived image.
"""
- @abstractmethod
- def get_torrent_info(self, blob):
- """
+ @abstractmethod
+ def get_torrent_info(self, blob):
+ """
Returns the torrent information associated with the given blob or None if none.
"""
- @abstractmethod
- def set_torrent_info(self, blob, piece_length, pieces):
- """
+ @abstractmethod
+ def set_torrent_info(self, blob, piece_length, pieces):
+ """
Sets the torrent information associated with the given blob to that specified.
"""
- @abstractmethod
- def get_repo_blob_by_digest(self, repository_ref, blob_digest, include_placements=False):
- """
+ @abstractmethod
+ def get_repo_blob_by_digest(
+ self, repository_ref, blob_digest, include_placements=False
+ ):
+ """
Returns the blob in the repository with the given digest, if any or None if none. Note that
there may be multiple records in the same repository for the same blob digest, so the return
value of this function may change.
"""
- @abstractmethod
- def create_blob_upload(self, repository_ref, upload_id, location_name, storage_metadata):
- """ Creates a new blob upload and returns a reference. If the blob upload could not be
+ @abstractmethod
+ def create_blob_upload(
+ self, repository_ref, upload_id, location_name, storage_metadata
+ ):
+ """ Creates a new blob upload and returns a reference. If the blob upload could not be
created, returns None. """
- @abstractmethod
- def lookup_blob_upload(self, repository_ref, blob_upload_id):
- """ Looks up the blob upload with the given ID under the specified repository and returns it
+ @abstractmethod
+ def lookup_blob_upload(self, repository_ref, blob_upload_id):
+ """ Looks up the blob upload with the given ID under the specified repository and returns it
or None if none.
"""
- @abstractmethod
- def update_blob_upload(self, blob_upload, uncompressed_byte_count, piece_hashes, piece_sha_state,
- storage_metadata, byte_count, chunk_count, sha_state):
- """ Updates the fields of the blob upload to match those given. Returns the updated blob upload
+ @abstractmethod
+ def update_blob_upload(
+ self,
+ blob_upload,
+ uncompressed_byte_count,
+ piece_hashes,
+ piece_sha_state,
+ storage_metadata,
+ byte_count,
+ chunk_count,
+ sha_state,
+ ):
+ """ Updates the fields of the blob upload to match those given. Returns the updated blob upload
or None if the record does not exist.
"""
- @abstractmethod
- def delete_blob_upload(self, blob_upload):
- """ Deletes a blob upload record. """
+ @abstractmethod
+ def delete_blob_upload(self, blob_upload):
+ """ Deletes a blob upload record. """
- @abstractmethod
- def commit_blob_upload(self, blob_upload, blob_digest_str, blob_expiration_seconds):
- """ Commits the blob upload into a blob and sets an expiration before that blob will be GCed.
+ @abstractmethod
+ def commit_blob_upload(self, blob_upload, blob_digest_str, blob_expiration_seconds):
+ """ Commits the blob upload into a blob and sets an expiration before that blob will be GCed.
"""
- @abstractmethod
- def mount_blob_into_repository(self, blob, target_repository_ref, expiration_sec):
- """
+ @abstractmethod
+ def mount_blob_into_repository(self, blob, target_repository_ref, expiration_sec):
+ """
Mounts the blob from another repository into the specified target repository, and adds an
expiration before that blob is automatically GCed. This function is useful during push
operations if an existing blob from another repository is being pushed. Returns False if
@@ -346,39 +404,43 @@ class RegistryDataInterface(object):
endpoints/v2/blob.py).
"""
- @abstractmethod
- def set_tags_expiration_for_manifest(self, manifest, expiration_sec):
- """
+ @abstractmethod
+ def set_tags_expiration_for_manifest(self, manifest, expiration_sec):
+ """
Sets the expiration on all tags that point to the given manifest to that specified.
"""
- @abstractmethod
- def get_schema1_parsed_manifest(self, manifest, namespace_name, repo_name, tag_name, storage):
- """ Returns the schema 1 version of this manifest, or None if none. """
+ @abstractmethod
+ def get_schema1_parsed_manifest(
+ self, manifest, namespace_name, repo_name, tag_name, storage
+ ):
+ """ Returns the schema 1 version of this manifest, or None if none. """
- @abstractmethod
- def create_manifest_with_temp_tag(self, repository_ref, manifest_interface_instance,
- expiration_sec, storage):
- """ Creates a manifest under the repository and sets a temporary tag to point to it.
+ @abstractmethod
+ def create_manifest_with_temp_tag(
+ self, repository_ref, manifest_interface_instance, expiration_sec, storage
+ ):
+ """ Creates a manifest under the repository and sets a temporary tag to point to it.
Returns the manifest object created or None on error.
"""
- @abstractmethod
- def get_cached_namespace_region_blacklist(self, model_cache, namespace_name):
- """ Returns a cached set of ISO country codes blacklisted for pulls for the namespace
+ @abstractmethod
+ def get_cached_namespace_region_blacklist(self, model_cache, namespace_name):
+ """ Returns a cached set of ISO country codes blacklisted for pulls for the namespace
or None if the list could not be loaded.
"""
- @abstractmethod
- def convert_manifest(self, manifest, namespace_name, repo_name, tag_name, allowed_mediatypes,
- storage):
- """ Attempts to convert the specified into a parsed manifest with a media type
+ @abstractmethod
+ def convert_manifest(
+ self, manifest, namespace_name, repo_name, tag_name, allowed_mediatypes, storage
+ ):
+ """ Attempts to convert the specified into a parsed manifest with a media type
in the allowed_mediatypes set. If not possible, or an error occurs, returns None.
"""
- @abstractmethod
- def yield_tags_for_vulnerability_notification(self, layer_id_pairs):
- """ Yields tags that contain one (or more) of the given layer ID pairs, in repositories
+ @abstractmethod
+ def yield_tags_for_vulnerability_notification(self, layer_id_pairs):
+ """ Yields tags that contain one (or more) of the given layer ID pairs, in repositories
which have been registered for vulnerability_found notifications. Returns an iterator
of LikelyVulnerableTag instances.
"""
diff --git a/data/registry_model/label_handlers.py b/data/registry_model/label_handlers.py
index 96afe0d94..18190334c 100644
--- a/data/registry_model/label_handlers.py
+++ b/data/registry_model/label_handlers.py
@@ -4,25 +4,25 @@ from util.timedeltastring import convert_to_timedelta
logger = logging.getLogger(__name__)
+
def _expires_after(label_dict, manifest, model):
- """ Sets the expiration of a manifest based on the quay.expires-in label. """
- try:
- timedelta = convert_to_timedelta(label_dict['value'])
- except ValueError:
- logger.exception('Could not convert %s to timedeltastring', label_dict['value'])
- return
+ """ Sets the expiration of a manifest based on the quay.expires-in label. """
+ try:
+ timedelta = convert_to_timedelta(label_dict["value"])
+ except ValueError:
+ logger.exception("Could not convert %s to timedeltastring", label_dict["value"])
+ return
- total_seconds = timedelta.total_seconds()
- logger.debug('Labeling manifest %s with expiration of %s', manifest, total_seconds)
- model.set_tags_expiration_for_manifest(manifest, total_seconds)
+ total_seconds = timedelta.total_seconds()
+ logger.debug("Labeling manifest %s with expiration of %s", manifest, total_seconds)
+ model.set_tags_expiration_for_manifest(manifest, total_seconds)
-_LABEL_HANDLERS = {
- 'quay.expires-after': _expires_after,
-}
+_LABEL_HANDLERS = {"quay.expires-after": _expires_after}
+
def apply_label_to_manifest(label_dict, manifest, model):
- """ Runs the handler defined, if any, for the given label. """
- handler = _LABEL_HANDLERS.get(label_dict['key'])
- if handler is not None:
- handler(label_dict, manifest, model)
+ """ Runs the handler defined, if any, for the given label. """
+ handler = _LABEL_HANDLERS.get(label_dict["key"])
+ if handler is not None:
+ handler(label_dict, manifest, model)
diff --git a/data/registry_model/manifestbuilder.py b/data/registry_model/manifestbuilder.py
index 384ecb604..a704e65f7 100644
--- a/data/registry_model/manifestbuilder.py
+++ b/data/registry_model/manifestbuilder.py
@@ -13,208 +13,250 @@ from image.docker.schema2 import EMPTY_LAYER_BLOB_DIGEST
logger = logging.getLogger(__name__)
-ManifestLayer = namedtuple('ManifestLayer', ['layer_id', 'v1_metadata_string', 'db_id'])
-_BuilderState = namedtuple('_BuilderState', ['builder_id', 'images', 'tags', 'checksums',
- 'temp_storages'])
+ManifestLayer = namedtuple("ManifestLayer", ["layer_id", "v1_metadata_string", "db_id"])
+_BuilderState = namedtuple(
+ "_BuilderState", ["builder_id", "images", "tags", "checksums", "temp_storages"]
+)
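+# Note: `images` and `tags` map layer IDs / tag names to database IDs, `checksums` maps layer IDs
+# to their computed checksums, and `temp_storages` records placeholder storages cleaned up in done().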
-_SESSION_KEY = '__manifestbuilder'
+_SESSION_KEY = "__manifestbuilder"
def create_manifest_builder(repository_ref, storage, legacy_signing_key):
- """ Creates a new manifest builder for populating manifests under the specified repository
+ """ Creates a new manifest builder for populating manifests under the specified repository
and returns it. Returns None if the builder could not be constructed.
"""
- builder_id = str(uuid.uuid4())
- builder = _ManifestBuilder(repository_ref, _BuilderState(builder_id, {}, {}, {}, []), storage,
- legacy_signing_key)
- builder._save_to_session()
- return builder
+ builder_id = str(uuid.uuid4())
+ builder = _ManifestBuilder(
+ repository_ref,
+ _BuilderState(builder_id, {}, {}, {}, []),
+ storage,
+ legacy_signing_key,
+ )
+ builder._save_to_session()
+ return builder
def lookup_manifest_builder(repository_ref, builder_id, storage, legacy_signing_key):
- """ Looks up the manifest builder with the given ID under the specified repository and returns
+ """ Looks up the manifest builder with the given ID under the specified repository and returns
it or None if none.
"""
- builder_state_tuple = session.get(_SESSION_KEY)
- if builder_state_tuple is None:
- return None
+ builder_state_tuple = session.get(_SESSION_KEY)
+ if builder_state_tuple is None:
+ return None
- builder_state = _BuilderState(*builder_state_tuple)
- if builder_state.builder_id != builder_id:
- return None
+ builder_state = _BuilderState(*builder_state_tuple)
+ if builder_state.builder_id != builder_id:
+ return None
- return _ManifestBuilder(repository_ref, builder_state, storage, legacy_signing_key)
+ return _ManifestBuilder(repository_ref, builder_state, storage, legacy_signing_key)
class _ManifestBuilder(object):
- """ Helper class which provides an interface for bookkeeping the layers and configuration of
+ """ Helper class which provides an interface for bookkeeping the layers and configuration of
manifests being constructed.
"""
- def __init__(self, repository_ref, builder_state, storage, legacy_signing_key):
- self._repository_ref = repository_ref
- self._builder_state = builder_state
- self._storage = storage
- self._legacy_signing_key = legacy_signing_key
- @property
- def builder_id(self):
- """ Returns the unique ID for this builder. """
- return self._builder_state.builder_id
+ def __init__(self, repository_ref, builder_state, storage, legacy_signing_key):
+ self._repository_ref = repository_ref
+ self._builder_state = builder_state
+ self._storage = storage
+ self._legacy_signing_key = legacy_signing_key
- @property
- def committed_tags(self):
- """ Returns the tags committed by this builder, if any. """
- return [registry_model.get_repo_tag(self._repository_ref, tag_name, include_legacy_image=True)
- for tag_name in self._builder_state.tags.keys()]
+ @property
+ def builder_id(self):
+ """ Returns the unique ID for this builder. """
+ return self._builder_state.builder_id
- def start_layer(self, layer_id, v1_metadata_string, location_name, calling_user,
- temp_tag_expiration):
- """ Starts a new layer with the given ID to be placed into a manifest. Returns the layer
+ @property
+ def committed_tags(self):
+ """ Returns the tags committed by this builder, if any. """
+ return [
+ registry_model.get_repo_tag(
+ self._repository_ref, tag_name, include_legacy_image=True
+ )
+ for tag_name in self._builder_state.tags.keys()
+ ]
+
+ def start_layer(
+ self,
+ layer_id,
+ v1_metadata_string,
+ location_name,
+ calling_user,
+ temp_tag_expiration,
+ ):
+ """ Starts a new layer with the given ID to be placed into a manifest. Returns the layer
started or None if an error occurred.
"""
- # Ensure the repository still exists.
- repository = model.repository.lookup_repository(self._repository_ref._db_id)
- if repository is None:
- return None
+ # Ensure the repository still exists.
+ repository = model.repository.lookup_repository(self._repository_ref._db_id)
+ if repository is None:
+ return None
- namespace_name = repository.namespace_user.username
- repo_name = repository.name
+ namespace_name = repository.namespace_user.username
+ repo_name = repository.name
- try:
- v1_metadata = json.loads(v1_metadata_string)
- except ValueError:
- logger.exception('Exception when trying to parse V1 metadata JSON for layer %s', layer_id)
- return None
- except TypeError:
- logger.exception('Exception when trying to parse V1 metadata JSON for layer %s', layer_id)
- return None
+ try:
+ v1_metadata = json.loads(v1_metadata_string)
+ except ValueError:
+ logger.exception(
+ "Exception when trying to parse V1 metadata JSON for layer %s", layer_id
+ )
+ return None
+ except TypeError:
+ logger.exception(
+ "Exception when trying to parse V1 metadata JSON for layer %s", layer_id
+ )
+ return None
- # Sanity check that the ID matches the v1 metadata.
- if layer_id != v1_metadata['id']:
- return None
+ # Sanity check that the ID matches the v1 metadata.
+ if layer_id != v1_metadata["id"]:
+ return None
- # Ensure the parent already exists in the repository.
- parent_id = v1_metadata.get('parent', None)
- parent_image = None
+ # Ensure the parent already exists in the repository.
+ parent_id = v1_metadata.get("parent", None)
+ parent_image = None
- if parent_id is not None:
- parent_image = model.image.get_repo_image(namespace_name, repo_name, parent_id)
- if parent_image is None:
- return None
+ if parent_id is not None:
+ parent_image = model.image.get_repo_image(
+ namespace_name, repo_name, parent_id
+ )
+ if parent_image is None:
+ return None
- # Check to see if this layer already exists in the repository. If so, we can skip the creation.
- existing_image = registry_model.get_legacy_image(self._repository_ref, layer_id)
- if existing_image is not None:
- self._builder_state.images[layer_id] = existing_image.id
- self._save_to_session()
- return ManifestLayer(layer_id, v1_metadata_string, existing_image.id)
+ # Check to see if this layer already exists in the repository. If so, we can skip the creation.
+ existing_image = registry_model.get_legacy_image(self._repository_ref, layer_id)
+ if existing_image is not None:
+ self._builder_state.images[layer_id] = existing_image.id
+ self._save_to_session()
+ return ManifestLayer(layer_id, v1_metadata_string, existing_image.id)
- with db_transaction():
- # Otherwise, create a new legacy image and point a temporary tag at it.
- created = model.image.find_create_or_link_image(layer_id, repository, calling_user, {},
- location_name)
- model.tag.create_temporary_hidden_tag(repository, created, temp_tag_expiration)
+ with db_transaction():
+ # Otherwise, create a new legacy image and point a temporary tag at it.
+ created = model.image.find_create_or_link_image(
+ layer_id, repository, calling_user, {}, location_name
+ )
+ model.tag.create_temporary_hidden_tag(
+ repository, created, temp_tag_expiration
+ )
- # Save its V1 metadata.
- command_list = v1_metadata.get('container_config', {}).get('Cmd', None)
- command = json.dumps(command_list) if command_list else None
+ # Save its V1 metadata.
+ command_list = v1_metadata.get("container_config", {}).get("Cmd", None)
+ command = json.dumps(command_list) if command_list else None
- model.image.set_image_metadata(layer_id, namespace_name, repo_name,
- v1_metadata.get('created'),
- v1_metadata.get('comment'),
- command, v1_metadata_string,
- parent=parent_image)
+ model.image.set_image_metadata(
+ layer_id,
+ namespace_name,
+ repo_name,
+ v1_metadata.get("created"),
+ v1_metadata.get("comment"),
+ command,
+ v1_metadata_string,
+ parent=parent_image,
+ )
- # Save the changes to the builder.
- self._builder_state.images[layer_id] = created.id
- self._save_to_session()
+ # Save the changes to the builder.
+ self._builder_state.images[layer_id] = created.id
+ self._save_to_session()
- return ManifestLayer(layer_id, v1_metadata_string, created.id)
+ return ManifestLayer(layer_id, v1_metadata_string, created.id)
- def lookup_layer(self, layer_id):
- """ Returns a layer with the given ID under this builder. If none exists, returns None. """
- if layer_id not in self._builder_state.images:
- return None
+ def lookup_layer(self, layer_id):
+ """ Returns a layer with the given ID under this builder. If none exists, returns None. """
+ if layer_id not in self._builder_state.images:
+ return None
- image = model.image.get_image_by_db_id(self._builder_state.images[layer_id])
- if image is None:
- return None
+ image = model.image.get_image_by_db_id(self._builder_state.images[layer_id])
+ if image is None:
+ return None
- return ManifestLayer(layer_id, image.v1_json_metadata, image.id)
+ return ManifestLayer(layer_id, image.v1_json_metadata, image.id)
- def assign_layer_blob(self, layer, blob, computed_checksums):
- """ Assigns a blob to a layer. """
- assert blob
- assert not blob.uploading
+ def assign_layer_blob(self, layer, blob, computed_checksums):
+ """ Assigns a blob to a layer. """
+ assert blob
+ assert not blob.uploading
- repo_image = model.image.get_image_by_db_id(layer.db_id)
- if repo_image is None:
- return None
+ repo_image = model.image.get_image_by_db_id(layer.db_id)
+ if repo_image is None:
+ return None
- with db_transaction():
- existing_storage = repo_image.storage
- repo_image.storage = blob._db_id
- repo_image.save()
+ with db_transaction():
+ existing_storage = repo_image.storage
+ repo_image.storage = blob._db_id
+ repo_image.save()
- if existing_storage.uploading:
- self._builder_state.temp_storages.append(existing_storage.id)
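+ # If the displaced storage was only a placeholder still marked as uploading, remember it so
+ # done() can delete its placements and the storage row later.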
+ if existing_storage.uploading:
+ self._builder_state.temp_storages.append(existing_storage.id)
- self._builder_state.checksums[layer.layer_id] = computed_checksums
- self._save_to_session()
- return True
+ self._builder_state.checksums[layer.layer_id] = computed_checksums
+ self._save_to_session()
+ return True
- def validate_layer_checksum(self, layer, checksum):
- """ Returns whether the checksum for a layer matches that specified.
+ def validate_layer_checksum(self, layer, checksum):
+ """ Returns whether the checksum for a layer matches that specified.
"""
- return checksum in self.get_layer_checksums(layer)
+ return checksum in self.get_layer_checksums(layer)
- def get_layer_checksums(self, layer):
- """ Returns the registered defined for the layer, if any. """
- return self._builder_state.checksums.get(layer.layer_id) or []
+ def get_layer_checksums(self, layer):
+ """ Returns the registered defined for the layer, if any. """
+ return self._builder_state.checksums.get(layer.layer_id) or []
- def save_precomputed_checksum(self, layer, checksum):
- """ Saves a precomputed checksum for a layer. """
- checksums = self._builder_state.checksums.get(layer.layer_id) or []
- checksums.append(checksum)
- self._builder_state.checksums[layer.layer_id] = checksums
- self._save_to_session()
+ def save_precomputed_checksum(self, layer, checksum):
+ """ Saves a precomputed checksum for a layer. """
+ checksums = self._builder_state.checksums.get(layer.layer_id) or []
+ checksums.append(checksum)
+ self._builder_state.checksums[layer.layer_id] = checksums
+ self._save_to_session()
- def commit_tag_and_manifest(self, tag_name, layer):
- """ Commits a new tag + manifest for that tag to the repository with the given name,
+ def commit_tag_and_manifest(self, tag_name, layer):
+ """ Commits a new tag + manifest for that tag to the repository with the given name,
pointing to the given layer.
"""
- legacy_image = registry_model.get_legacy_image(self._repository_ref, layer.layer_id)
- if legacy_image is None:
- return None
+ legacy_image = registry_model.get_legacy_image(
+ self._repository_ref, layer.layer_id
+ )
+ if legacy_image is None:
+ return None
- tag = registry_model.retarget_tag(self._repository_ref, tag_name, legacy_image, self._storage,
- self._legacy_signing_key)
- if tag is None:
- return None
+ tag = registry_model.retarget_tag(
+ self._repository_ref,
+ tag_name,
+ legacy_image,
+ self._storage,
+ self._legacy_signing_key,
+ )
+ if tag is None:
+ return None
- self._builder_state.tags[tag_name] = tag._db_id
- self._save_to_session()
- return tag
+ self._builder_state.tags[tag_name] = tag._db_id
+ self._save_to_session()
+ return tag
- def done(self):
- """ Marks the manifest builder as complete and disposes of any state. This call is optional
+ def done(self):
+ """ Marks the manifest builder as complete and disposes of any state. This call is optional
and it is expected manifest builders will eventually time out if unused for an
extended period of time.
"""
- temp_storages = self._builder_state.temp_storages
- for storage_id in temp_storages:
- try:
- storage = ImageStorage.get(id=storage_id)
- if storage.uploading and storage.content_checksum != EMPTY_LAYER_BLOB_DIGEST:
- # Delete all the placements pointing to the storage.
- ImageStoragePlacement.delete().where(ImageStoragePlacement.storage == storage).execute()
+ temp_storages = self._builder_state.temp_storages
+ for storage_id in temp_storages:
+ try:
+ storage = ImageStorage.get(id=storage_id)
+ if (
+ storage.uploading
+ and storage.content_checksum != EMPTY_LAYER_BLOB_DIGEST
+ ):
+ # Delete all the placements pointing to the storage.
+ ImageStoragePlacement.delete().where(
+ ImageStoragePlacement.storage == storage
+ ).execute()
- # Delete the storage.
- storage.delete_instance()
- except ImageStorage.DoesNotExist:
- pass
+ # Delete the storage.
+ storage.delete_instance()
+ except ImageStorage.DoesNotExist:
+ pass
- session.pop(_SESSION_KEY, None)
+ session.pop(_SESSION_KEY, None)
- def _save_to_session(self):
- session[_SESSION_KEY] = self._builder_state
+ def _save_to_session(self):
+ session[_SESSION_KEY] = self._builder_state
diff --git a/data/registry_model/modelsplitter.py b/data/registry_model/modelsplitter.py
index 675a66928..f4864fed7 100644
--- a/data/registry_model/modelsplitter.py
+++ b/data/registry_model/modelsplitter.py
@@ -12,101 +12,129 @@ logger = logging.getLogger(__name__)
class SplitModel(object):
- def __init__(self, oci_model_proportion, oci_namespace_whitelist, v22_namespace_whitelist,
- oci_only_mode):
- self.v22_namespace_whitelist = set(v22_namespace_whitelist)
+ def __init__(
+ self,
+ oci_model_proportion,
+ oci_namespace_whitelist,
+ v22_namespace_whitelist,
+ oci_only_mode,
+ ):
+ self.v22_namespace_whitelist = set(v22_namespace_whitelist)
- self.oci_namespace_whitelist = set(oci_namespace_whitelist)
- self.oci_namespace_whitelist.update(v22_namespace_whitelist)
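+ # Any namespace whitelisted for schema2 (v22) is treated as whitelisted for the OCI model as well.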
+ self.oci_namespace_whitelist = set(oci_namespace_whitelist)
+ self.oci_namespace_whitelist.update(v22_namespace_whitelist)
- self.oci_model_proportion = oci_model_proportion
- self.oci_only_mode = oci_only_mode
+ self.oci_model_proportion = oci_model_proportion
+ self.oci_only_mode = oci_only_mode
- def supports_schema2(self, namespace_name):
- """ Returns whether the implementation of the data interface supports schema 2 format
+ def supports_schema2(self, namespace_name):
+ """ Returns whether the implementation of the data interface supports schema 2 format
manifests. """
- return namespace_name in self.v22_namespace_whitelist
+ return namespace_name in self.v22_namespace_whitelist
- def _namespace_from_kwargs(self, args_dict):
- if 'namespace_name' in args_dict:
- return args_dict['namespace_name']
+ def _namespace_from_kwargs(self, args_dict):
+ if "namespace_name" in args_dict:
+ return args_dict["namespace_name"]
- if 'repository_ref' in args_dict:
- return args_dict['repository_ref'].namespace_name
+ if "repository_ref" in args_dict:
+ return args_dict["repository_ref"].namespace_name
- if 'tag' in args_dict:
- return args_dict['tag'].repository.namespace_name
+ if "tag" in args_dict:
+ return args_dict["tag"].repository.namespace_name
- if 'manifest' in args_dict:
- manifest = args_dict['manifest']
- if manifest._is_tag_manifest:
- return TagManifest.get(id=manifest._db_id).tag.repository.namespace_user.username
- else:
- return Manifest.get(id=manifest._db_id).repository.namespace_user.username
+ if "manifest" in args_dict:
+ manifest = args_dict["manifest"]
+ if manifest._is_tag_manifest:
+ return TagManifest.get(
+ id=manifest._db_id
+ ).tag.repository.namespace_user.username
+ else:
+ return Manifest.get(
+ id=manifest._db_id
+ ).repository.namespace_user.username
- if 'manifest_or_legacy_image' in args_dict:
- manifest_or_legacy_image = args_dict['manifest_or_legacy_image']
- if isinstance(manifest_or_legacy_image, LegacyImage):
- return Image.get(id=manifest_or_legacy_image._db_id).repository.namespace_user.username
- else:
- manifest = manifest_or_legacy_image
- if manifest._is_tag_manifest:
- return TagManifest.get(id=manifest._db_id).tag.repository.namespace_user.username
- else:
- return Manifest.get(id=manifest._db_id).repository.namespace_user.username
+ if "manifest_or_legacy_image" in args_dict:
+ manifest_or_legacy_image = args_dict["manifest_or_legacy_image"]
+ if isinstance(manifest_or_legacy_image, LegacyImage):
+ return Image.get(
+ id=manifest_or_legacy_image._db_id
+ ).repository.namespace_user.username
+ else:
+ manifest = manifest_or_legacy_image
+ if manifest._is_tag_manifest:
+ return TagManifest.get(
+ id=manifest._db_id
+ ).tag.repository.namespace_user.username
+ else:
+ return Manifest.get(
+ id=manifest._db_id
+ ).repository.namespace_user.username
- if 'derived_image' in args_dict:
- return (DerivedStorageForImage
- .get(id=args_dict['derived_image']._db_id)
- .source_image
- .repository
- .namespace_user
- .username)
+ if "derived_image" in args_dict:
+ return DerivedStorageForImage.get(
+ id=args_dict["derived_image"]._db_id
+ ).source_image.repository.namespace_user.username
- if 'blob' in args_dict:
- return '' # Blob functions are shared, so no need to do anything.
+ if "blob" in args_dict:
+ return "" # Blob functions are shared, so no need to do anything.
- if 'blob_upload' in args_dict:
- return '' # Blob functions are shared, so no need to do anything.
+ if "blob_upload" in args_dict:
+ return "" # Blob functions are shared, so no need to do anything.
- raise Exception('Unknown namespace for dict `%s`' % args_dict)
+ raise Exception("Unknown namespace for dict `%s`" % args_dict)
- def __getattr__(self, attr):
- def method(*args, **kwargs):
- if self.oci_model_proportion >= 1.0:
- if self.oci_only_mode:
- logger.debug('Calling method `%s` under full OCI data model for all namespaces', attr)
- return getattr(oci_model, attr)(*args, **kwargs)
- else:
- logger.debug('Calling method `%s` under compat OCI data model for all namespaces', attr)
- return getattr(back_compat_oci_model, attr)(*args, **kwargs)
+ def __getattr__(self, attr):
+ def method(*args, **kwargs):
+ if self.oci_model_proportion >= 1.0:
+ if self.oci_only_mode:
+ logger.debug(
+ "Calling method `%s` under full OCI data model for all namespaces",
+ attr,
+ )
+ return getattr(oci_model, attr)(*args, **kwargs)
+ else:
+ logger.debug(
+ "Calling method `%s` under compat OCI data model for all namespaces",
+ attr,
+ )
+ return getattr(back_compat_oci_model, attr)(*args, **kwargs)
- argnames = inspect.getargspec(getattr(back_compat_oci_model, attr))[0]
- if not argnames and isinstance(args[0], ManifestDataType):
- args_dict = dict(manifest=args[0])
- else:
- args_dict = {argnames[index + 1]: value for index, value in enumerate(args)}
+ argnames = inspect.getargspec(getattr(back_compat_oci_model, attr))[0]
+ if not argnames and isinstance(args[0], ManifestDataType):
+ args_dict = dict(manifest=args[0])
+ else:
+ args_dict = {
+ argnames[index + 1]: value for index, value in enumerate(args)
+ }
- if attr in ['yield_tags_for_vulnerability_notification', 'get_most_recent_tag_lifetime_start']:
- use_oci = self.oci_model_proportion >= 1.0
- namespace_name = '(implicit for ' + attr + ')'
- else:
- namespace_name = self._namespace_from_kwargs(args_dict)
- use_oci = namespace_name in self.oci_namespace_whitelist
+ if attr in [
+ "yield_tags_for_vulnerability_notification",
+ "get_most_recent_tag_lifetime_start",
+ ]:
+ use_oci = self.oci_model_proportion >= 1.0
+ namespace_name = "(implicit for " + attr + ")"
+ else:
+ namespace_name = self._namespace_from_kwargs(args_dict)
+ use_oci = namespace_name in self.oci_namespace_whitelist
- if not use_oci and self.oci_model_proportion:
- # Hash the namespace name and see if it falls into the proportion bucket.
- bucket = (int(hashlib.md5(namespace_name).hexdigest(), 16) % 100)
- if bucket <= int(self.oci_model_proportion * 100):
- logger.debug('Enabling OCI for namespace `%s` in proportional bucket',
- namespace_name)
- use_oci = True
+ if not use_oci and self.oci_model_proportion:
+ # Hash the namespace name and see if it falls into the proportion bucket.
+ bucket = int(hashlib.md5(namespace_name).hexdigest(), 16) % 100
+ if bucket <= int(self.oci_model_proportion * 100):
+ logger.debug(
+ "Enabling OCI for namespace `%s` in proportional bucket",
+ namespace_name,
+ )
+ use_oci = True
- if use_oci:
- logger.debug('Calling method `%s` under OCI data model for namespace `%s`',
- attr, namespace_name)
- return getattr(back_compat_oci_model, attr)(*args, **kwargs)
- else:
- return getattr(pre_oci_model, attr)(*args, **kwargs)
+ if use_oci:
+ logger.debug(
+ "Calling method `%s` under OCI data model for namespace `%s`",
+ attr,
+ namespace_name,
+ )
+ return getattr(back_compat_oci_model, attr)(*args, **kwargs)
+ else:
+ return getattr(pre_oci_model, attr)(*args, **kwargs)
- return method
+ return method
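Illustrative sketch (not part of the diff above): the routing rule that `SplitModel.__getattr__` applies, restated as a standalone function for clarity. The whitelist check and the md5 bucketing are taken directly from the code above; the function name is only for illustration, the real class additionally short-circuits when `oci_model_proportion >= 1.0` or `oci_only_mode` is set, and the md5-of-a-str call mirrors the Python 2 semantics used throughout this diff.

import hashlib


def routes_to_oci(namespace_name, oci_namespace_whitelist, oci_model_proportion):
    # Whitelisted namespaces always use the OCI-backed model.
    if namespace_name in oci_namespace_whitelist:
        return True
    if not oci_model_proportion:
        return False
    # Hash the namespace into one of 100 buckets; buckets at or below the
    # configured proportion are routed to the OCI data model.
    bucket = int(hashlib.md5(namespace_name).hexdigest(), 16) % 100
    return bucket <= int(oci_model_proportion * 100)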
diff --git a/data/registry_model/registry_oci_model.py b/data/registry_model/registry_oci_model.py
index 8821a747b..1e3cf0f12 100644
--- a/data/registry_model/registry_oci_model.py
+++ b/data/registry_model/registry_oci_model.py
@@ -10,8 +10,16 @@ from data.model import oci, DataModelException
from data.model.oci.retriever import RepositoryContentRetriever
from data.database import db_transaction, Image, IMAGE_NOT_SCANNED_ENGINE_VERSION
from data.registry_model.interface import RegistryDataInterface
-from data.registry_model.datatypes import (Tag, Manifest, LegacyImage, Label, SecurityScanStatus,
- Blob, ShallowTag, LikelyVulnerableTag)
+from data.registry_model.datatypes import (
+ Tag,
+ Manifest,
+ LegacyImage,
+ Label,
+ SecurityScanStatus,
+ Blob,
+ ShallowTag,
+ LikelyVulnerableTag,
+)
from data.registry_model.shared import SharedModel
from data.registry_model.label_handlers import apply_label_to_manifest
from image.docker import ManifestException
@@ -23,265 +31,346 @@ logger = logging.getLogger(__name__)
class OCIModel(SharedModel, RegistryDataInterface):
- """
+ """
OCIModel implements the data model for the registry API using a database schema
after it was changed to support the OCI specification.
"""
- def __init__(self, oci_model_only=True):
- self.oci_model_only = oci_model_only
- def supports_schema2(self, namespace_name):
- """ Returns whether the implementation of the data interface supports schema 2 format
+ def __init__(self, oci_model_only=True):
+ self.oci_model_only = oci_model_only
+
+ def supports_schema2(self, namespace_name):
+ """ Returns whether the implementation of the data interface supports schema 2 format
manifests. """
- return True
+ return True
- def get_tag_legacy_image_id(self, repository_ref, tag_name, storage):
- """ Returns the legacy image ID for the tag with a legacy images in
+ def get_tag_legacy_image_id(self, repository_ref, tag_name, storage):
+ """ Returns the legacy image ID for the tag with a legacy images in
the repository. Returns None if None.
"""
- tag = self.get_repo_tag(repository_ref, tag_name, include_legacy_image=True)
- if tag is None:
- return None
+ tag = self.get_repo_tag(repository_ref, tag_name, include_legacy_image=True)
+ if tag is None:
+ return None
- if tag.legacy_image_if_present is not None:
- return tag.legacy_image_if_present.docker_image_id
+ if tag.legacy_image_if_present is not None:
+ return tag.legacy_image_if_present.docker_image_id
- if tag.manifest.media_type == DOCKER_SCHEMA2_MANIFESTLIST_CONTENT_TYPE:
- # See if we can lookup a schema1 legacy image.
- v1_compatible = self.get_schema1_parsed_manifest(tag.manifest, '', '', '', storage)
- if v1_compatible is not None:
- return v1_compatible.leaf_layer_v1_image_id
+ if tag.manifest.media_type == DOCKER_SCHEMA2_MANIFESTLIST_CONTENT_TYPE:
+ # See if we can lookup a schema1 legacy image.
+ v1_compatible = self.get_schema1_parsed_manifest(
+ tag.manifest, "", "", "", storage
+ )
+ if v1_compatible is not None:
+ return v1_compatible.leaf_layer_v1_image_id
- return None
+ return None
- def get_legacy_tags_map(self, repository_ref, storage):
- """ Returns a map from tag name to its legacy image ID, for all tags with legacy images in
+ def get_legacy_tags_map(self, repository_ref, storage):
+ """ Returns a map from tag name to its legacy image ID, for all tags with legacy images in
the repository. Note that this can be a *very* heavy operation.
"""
- tags = oci.tag.list_alive_tags(repository_ref._db_id)
- legacy_images_map = oci.tag.get_legacy_images_for_tags(tags)
+ tags = oci.tag.list_alive_tags(repository_ref._db_id)
+ legacy_images_map = oci.tag.get_legacy_images_for_tags(tags)
- tags_map = {}
- for tag in tags:
- legacy_image = legacy_images_map.get(tag.id)
- if legacy_image is not None:
- tags_map[tag.name] = legacy_image.docker_image_id
- else:
- manifest = Manifest.for_manifest(tag.manifest, None)
- if legacy_image is None and manifest.media_type == DOCKER_SCHEMA2_MANIFESTLIST_CONTENT_TYPE:
- # See if we can lookup a schema1 legacy image.
- v1_compatible = self.get_schema1_parsed_manifest(manifest, '', '', '', storage)
- if v1_compatible is not None:
- v1_id = v1_compatible.leaf_layer_v1_image_id
- if v1_id is not None:
- tags_map[tag.name] = v1_id
+ tags_map = {}
+ for tag in tags:
+ legacy_image = legacy_images_map.get(tag.id)
+ if legacy_image is not None:
+ tags_map[tag.name] = legacy_image.docker_image_id
+ else:
+ manifest = Manifest.for_manifest(tag.manifest, None)
+ if (
+ legacy_image is None
+ and manifest.media_type == DOCKER_SCHEMA2_MANIFESTLIST_CONTENT_TYPE
+ ):
+ # See if we can lookup a schema1 legacy image.
+ v1_compatible = self.get_schema1_parsed_manifest(
+ manifest, "", "", "", storage
+ )
+ if v1_compatible is not None:
+ v1_id = v1_compatible.leaf_layer_v1_image_id
+ if v1_id is not None:
+ tags_map[tag.name] = v1_id
- return tags_map
+ return tags_map
- def _get_legacy_compatible_image_for_manifest(self, manifest, storage):
- # Check for a legacy image directly on the manifest.
- if manifest.media_type != DOCKER_SCHEMA2_MANIFESTLIST_CONTENT_TYPE:
- return oci.shared.get_legacy_image_for_manifest(manifest._db_id)
-
- # Otherwise, lookup a legacy image associated with the v1-compatible manifest
- # in the list.
- try:
- manifest_obj = database.Manifest.get(id=manifest._db_id)
- except database.Manifest.DoesNotExist:
- logger.exception('Could not find manifest for manifest `%s`', manifest._db_id)
- return None
+ def _get_legacy_compatible_image_for_manifest(self, manifest, storage):
+ # Check for a legacy image directly on the manifest.
+ if manifest.media_type != DOCKER_SCHEMA2_MANIFESTLIST_CONTENT_TYPE:
+ return oci.shared.get_legacy_image_for_manifest(manifest._db_id)
- # See if we can lookup a schema1 legacy image.
- v1_compatible = self.get_schema1_parsed_manifest(manifest, '', '', '', storage)
- if v1_compatible is None:
- return None
+ # Otherwise, lookup a legacy image associated with the v1-compatible manifest
+ # in the list.
+ try:
+ manifest_obj = database.Manifest.get(id=manifest._db_id)
+ except database.Manifest.DoesNotExist:
+ logger.exception(
+ "Could not find manifest for manifest `%s`", manifest._db_id
+ )
+ return None
- v1_id = v1_compatible.leaf_layer_v1_image_id
- if v1_id is None:
- return None
+ # See if we can lookup a schema1 legacy image.
+ v1_compatible = self.get_schema1_parsed_manifest(manifest, "", "", "", storage)
+ if v1_compatible is None:
+ return None
- return model.image.get_image(manifest_obj.repository_id, v1_id)
+ v1_id = v1_compatible.leaf_layer_v1_image_id
+ if v1_id is None:
+ return None
- def find_matching_tag(self, repository_ref, tag_names):
- """ Finds an alive tag in the repository matching one of the given tag names and returns it
+ return model.image.get_image(manifest_obj.repository_id, v1_id)
+
+ def find_matching_tag(self, repository_ref, tag_names):
+ """ Finds an alive tag in the repository matching one of the given tag names and returns it
or None if none.
"""
- found_tag = oci.tag.find_matching_tag(repository_ref._db_id, tag_names)
- assert found_tag is None or not found_tag.hidden
- return Tag.for_tag(found_tag)
+ found_tag = oci.tag.find_matching_tag(repository_ref._db_id, tag_names)
+ assert found_tag is None or not found_tag.hidden
+ return Tag.for_tag(found_tag)
- def get_most_recent_tag(self, repository_ref):
- """ Returns the most recently pushed alive tag in the repository, if any. If none, returns
+ def get_most_recent_tag(self, repository_ref):
+ """ Returns the most recently pushed alive tag in the repository, if any. If none, returns
None.
"""
- found_tag = oci.tag.get_most_recent_tag(repository_ref._db_id)
- assert found_tag is None or not found_tag.hidden
- return Tag.for_tag(found_tag)
+ found_tag = oci.tag.get_most_recent_tag(repository_ref._db_id)
+ assert found_tag is None or not found_tag.hidden
+ return Tag.for_tag(found_tag)
- def get_manifest_for_tag(self, tag, backfill_if_necessary=False, include_legacy_image=False):
- """ Returns the manifest associated with the given tag. """
- legacy_image = None
- if include_legacy_image:
- legacy_image = oci.shared.get_legacy_image_for_manifest(tag._manifest)
+ def get_manifest_for_tag(
+ self, tag, backfill_if_necessary=False, include_legacy_image=False
+ ):
+ """ Returns the manifest associated with the given tag. """
+ legacy_image = None
+ if include_legacy_image:
+ legacy_image = oci.shared.get_legacy_image_for_manifest(tag._manifest)
- return Manifest.for_manifest(tag._manifest, LegacyImage.for_image(legacy_image))
+ return Manifest.for_manifest(tag._manifest, LegacyImage.for_image(legacy_image))
- def lookup_manifest_by_digest(self, repository_ref, manifest_digest, allow_dead=False,
- include_legacy_image=False, require_available=False):
- """ Looks up the manifest with the given digest under the given repository and returns it
+ def lookup_manifest_by_digest(
+ self,
+ repository_ref,
+ manifest_digest,
+ allow_dead=False,
+ include_legacy_image=False,
+ require_available=False,
+ ):
+ """ Looks up the manifest with the given digest under the given repository and returns it
or None if none. """
- manifest = oci.manifest.lookup_manifest(repository_ref._db_id, manifest_digest,
- allow_dead=allow_dead,
- require_available=require_available)
- if manifest is None:
- return None
+ manifest = oci.manifest.lookup_manifest(
+ repository_ref._db_id,
+ manifest_digest,
+ allow_dead=allow_dead,
+ require_available=require_available,
+ )
+ if manifest is None:
+ return None
- legacy_image = None
- if include_legacy_image:
- try:
- legacy_image_id = database.ManifestLegacyImage.get(manifest=manifest).image.docker_image_id
- legacy_image = self.get_legacy_image(repository_ref, legacy_image_id, include_parents=True)
- except database.ManifestLegacyImage.DoesNotExist:
- pass
+ legacy_image = None
+ if include_legacy_image:
+ try:
+ legacy_image_id = database.ManifestLegacyImage.get(
+ manifest=manifest
+ ).image.docker_image_id
+ legacy_image = self.get_legacy_image(
+ repository_ref, legacy_image_id, include_parents=True
+ )
+ except database.ManifestLegacyImage.DoesNotExist:
+ pass
- return Manifest.for_manifest(manifest, legacy_image)
+ return Manifest.for_manifest(manifest, legacy_image)
- def create_manifest_label(self, manifest, key, value, source_type_name, media_type_name=None):
- """ Creates a label on the manifest with the given key and value. """
- label_data = dict(key=key, value=value, source_type_name=source_type_name,
- media_type_name=media_type_name)
+ def create_manifest_label(
+ self, manifest, key, value, source_type_name, media_type_name=None
+ ):
+ """ Creates a label on the manifest with the given key and value. """
+ label_data = dict(
+ key=key,
+ value=value,
+ source_type_name=source_type_name,
+ media_type_name=media_type_name,
+ )
- # Create the label itself.
- label = oci.label.create_manifest_label(manifest._db_id, key, value, source_type_name,
- media_type_name,
- adjust_old_model=not self.oci_model_only)
- if label is None:
- return None
-
- # Apply any changes to the manifest that the label prescribes.
- apply_label_to_manifest(label_data, manifest, self)
-
- return Label.for_label(label)
-
- @contextmanager
- def batch_create_manifest_labels(self, manifest):
- """ Returns a context manager for batch creation of labels on a manifest.
-
- Can raise InvalidLabelKeyException or InvalidMediaTypeException depending
- on the validation errors.
- """
- labels_to_add = []
- def add_label(key, value, source_type_name, media_type_name=None):
- labels_to_add.append(dict(key=key, value=value, source_type_name=source_type_name,
- media_type_name=media_type_name))
-
- yield add_label
-
- # TODO: make this truly batch once we've fully transitioned to V2_2 and no longer need
- # the mapping tables.
- for label_data in labels_to_add:
- with db_transaction():
# Create the label itself.
- oci.label.create_manifest_label(manifest._db_id, **label_data)
+ label = oci.label.create_manifest_label(
+ manifest._db_id,
+ key,
+ value,
+ source_type_name,
+ media_type_name,
+ adjust_old_model=not self.oci_model_only,
+ )
+ if label is None:
+ return None
# Apply any changes to the manifest that the label prescribes.
apply_label_to_manifest(label_data, manifest, self)
- def list_manifest_labels(self, manifest, key_prefix=None):
- """ Returns all labels found on the manifest. If specified, the key_prefix will filter the
+ return Label.for_label(label)
+
+ @contextmanager
+ def batch_create_manifest_labels(self, manifest):
+ """ Returns a context manager for batch creation of labels on a manifest.
+
+ Can raise InvalidLabelKeyException or InvalidMediaTypeException depending
+ on the validation errors.
+ """
+ labels_to_add = []
+
+ def add_label(key, value, source_type_name, media_type_name=None):
+ labels_to_add.append(
+ dict(
+ key=key,
+ value=value,
+ source_type_name=source_type_name,
+ media_type_name=media_type_name,
+ )
+ )
+
+ yield add_label
+
+ # TODO: make this truly batch once we've fully transitioned to V2_2 and no longer need
+ # the mapping tables.
+ for label_data in labels_to_add:
+ with db_transaction():
+ # Create the label itself.
+ oci.label.create_manifest_label(manifest._db_id, **label_data)
+
+ # Apply any changes to the manifest that the label prescribes.
+ apply_label_to_manifest(label_data, manifest, self)
+
+ def list_manifest_labels(self, manifest, key_prefix=None):
+ """ Returns all labels found on the manifest. If specified, the key_prefix will filter the
labels returned to those keys that start with the given prefix.
"""
- labels = oci.label.list_manifest_labels(manifest._db_id, prefix_filter=key_prefix)
- return [Label.for_label(l) for l in labels]
+ labels = oci.label.list_manifest_labels(
+ manifest._db_id, prefix_filter=key_prefix
+ )
+ return [Label.for_label(l) for l in labels]
- def get_manifest_label(self, manifest, label_uuid):
- """ Returns the label with the specified UUID on the manifest or None if none. """
- return Label.for_label(oci.label.get_manifest_label(label_uuid, manifest._db_id))
+ def get_manifest_label(self, manifest, label_uuid):
+ """ Returns the label with the specified UUID on the manifest or None if none. """
+ return Label.for_label(
+ oci.label.get_manifest_label(label_uuid, manifest._db_id)
+ )
- def delete_manifest_label(self, manifest, label_uuid):
- """ Delete the label with the specified UUID on the manifest. Returns the label deleted
+ def delete_manifest_label(self, manifest, label_uuid):
+ """ Delete the label with the specified UUID on the manifest. Returns the label deleted
or None if none.
"""
- return Label.for_label(oci.label.delete_manifest_label(label_uuid, manifest._db_id))
+ return Label.for_label(
+ oci.label.delete_manifest_label(label_uuid, manifest._db_id)
+ )
- def lookup_active_repository_tags(self, repository_ref, start_pagination_id, limit):
- """
+ def lookup_active_repository_tags(self, repository_ref, start_pagination_id, limit):
+ """
Returns a page of active tags in a repository. Note that the tags returned by this method
are ShallowTag objects, which only contain the tag name.
"""
- tags = oci.tag.lookup_alive_tags_shallow(repository_ref._db_id, start_pagination_id, limit)
- return [ShallowTag.for_tag(tag) for tag in tags]
+ tags = oci.tag.lookup_alive_tags_shallow(
+ repository_ref._db_id, start_pagination_id, limit
+ )
+ return [ShallowTag.for_tag(tag) for tag in tags]
- def list_all_active_repository_tags(self, repository_ref, include_legacy_images=False):
- """
+ def list_all_active_repository_tags(
+ self, repository_ref, include_legacy_images=False
+ ):
+ """
Returns a list of all the active tags in the repository. Note that this is a *HEAVY*
operation on repositories with a lot of tags, and should only be used for testing or
where other more specific operations are not possible.
"""
- tags = list(oci.tag.list_alive_tags(repository_ref._db_id))
- legacy_images_map = {}
- if include_legacy_images:
- legacy_images_map = oci.tag.get_legacy_images_for_tags(tags)
+ tags = list(oci.tag.list_alive_tags(repository_ref._db_id))
+ legacy_images_map = {}
+ if include_legacy_images:
+ legacy_images_map = oci.tag.get_legacy_images_for_tags(tags)
- return [Tag.for_tag(tag, legacy_image=LegacyImage.for_image(legacy_images_map.get(tag.id)))
- for tag in tags]
+ return [
+ Tag.for_tag(
+ tag, legacy_image=LegacyImage.for_image(legacy_images_map.get(tag.id))
+ )
+ for tag in tags
+ ]
- def list_repository_tag_history(self, repository_ref, page=1, size=100, specific_tag_name=None,
- active_tags_only=False, since_time_ms=None):
- """
+ def list_repository_tag_history(
+ self,
+ repository_ref,
+ page=1,
+ size=100,
+ specific_tag_name=None,
+ active_tags_only=False,
+ since_time_ms=None,
+ ):
+ """
Returns the history of all tags in the repository (unless filtered). This includes tags that
have been made inactive due to newer versions of those tags coming into service.
"""
- tags, has_more = oci.tag.list_repository_tag_history(repository_ref._db_id,
- page, size,
- specific_tag_name,
- active_tags_only,
- since_time_ms)
+ tags, has_more = oci.tag.list_repository_tag_history(
+ repository_ref._db_id,
+ page,
+ size,
+ specific_tag_name,
+ active_tags_only,
+ since_time_ms,
+ )
- # TODO: do we need legacy images here?
- legacy_images_map = oci.tag.get_legacy_images_for_tags(tags)
- return [Tag.for_tag(tag, LegacyImage.for_image(legacy_images_map.get(tag.id))) for tag in tags], has_more
+ # TODO: do we need legacy images here?
+ legacy_images_map = oci.tag.get_legacy_images_for_tags(tags)
+ return (
+ [
+ Tag.for_tag(tag, LegacyImage.for_image(legacy_images_map.get(tag.id)))
+ for tag in tags
+ ],
+ has_more,
+ )
- def has_expired_tag(self, repository_ref, tag_name):
- """
+ def has_expired_tag(self, repository_ref, tag_name):
+ """
Returns true if and only if the repository contains a tag with the given name that is expired.
"""
- return bool(oci.tag.get_expired_tag(repository_ref._db_id, tag_name))
+ return bool(oci.tag.get_expired_tag(repository_ref._db_id, tag_name))
- def get_most_recent_tag_lifetime_start(self, repository_refs):
- """
+ def get_most_recent_tag_lifetime_start(self, repository_refs):
+ """
Returns a map from repository ID to the last modified time (in seconds) for each repository in the
given repository reference list.
"""
- if not repository_refs:
- return {}
+ if not repository_refs:
+ return {}
- toSeconds = lambda ms: ms / 1000 if ms is not None else None
- last_modified = oci.tag.get_most_recent_tag_lifetime_start([r.id for r in repository_refs])
+ toSeconds = lambda ms: ms / 1000 if ms is not None else None
+ last_modified = oci.tag.get_most_recent_tag_lifetime_start(
+ [r.id for r in repository_refs]
+ )
- return {repo_id: toSeconds(ms) for repo_id, ms in last_modified.items()}
+ return {repo_id: toSeconds(ms) for repo_id, ms in last_modified.items()}
- def get_repo_tag(self, repository_ref, tag_name, include_legacy_image=False):
- """
+ def get_repo_tag(self, repository_ref, tag_name, include_legacy_image=False):
+ """
Returns the latest, *active* tag found in the repository, with the matching name
or None if none.
"""
- assert isinstance(tag_name, basestring)
+ assert isinstance(tag_name, basestring)
- tag = oci.tag.get_tag(repository_ref._db_id, tag_name)
- if tag is None:
- return None
+ tag = oci.tag.get_tag(repository_ref._db_id, tag_name)
+ if tag is None:
+ return None
- legacy_image = None
- if include_legacy_image:
- legacy_images = oci.tag.get_legacy_images_for_tags([tag])
- legacy_image = legacy_images.get(tag.id)
+ legacy_image = None
+ if include_legacy_image:
+ legacy_images = oci.tag.get_legacy_images_for_tags([tag])
+ legacy_image = legacy_images.get(tag.id)
- return Tag.for_tag(tag, legacy_image=LegacyImage.for_image(legacy_image))
+ return Tag.for_tag(tag, legacy_image=LegacyImage.for_image(legacy_image))
- def create_manifest_and_retarget_tag(self, repository_ref, manifest_interface_instance, tag_name,
- storage, raise_on_error=False):
- """ Creates a manifest in a repository, adding all of the necessary data in the model.
+ def create_manifest_and_retarget_tag(
+ self,
+ repository_ref,
+ manifest_interface_instance,
+ tag_name,
+ storage,
+ raise_on_error=False,
+ ):
+ """ Creates a manifest in a repository, adding all of the necessary data in the model.
The `manifest_interface_instance` parameter must be an instance of the manifest
interface as returned by the image/docker package.
@@ -293,376 +382,463 @@ class OCIModel(SharedModel, RegistryDataInterface):
raise_on_error is set to True, in which case a CreateManifestException may also be
raised.
"""
- # Get or create the manifest itself.
- created_manifest = oci.manifest.get_or_create_manifest(repository_ref._db_id,
- manifest_interface_instance,
- storage,
- for_tagging=True,
- raise_on_error=raise_on_error)
- if created_manifest is None:
- return (None, None)
+ # Get or create the manifest itself.
+ created_manifest = oci.manifest.get_or_create_manifest(
+ repository_ref._db_id,
+ manifest_interface_instance,
+ storage,
+ for_tagging=True,
+ raise_on_error=raise_on_error,
+ )
+ if created_manifest is None:
+ return (None, None)
- # Re-target the tag to it.
- tag = oci.tag.retarget_tag(tag_name, created_manifest.manifest,
- adjust_old_model=not self.oci_model_only)
- if tag is None:
- return (None, None)
+ # Re-target the tag to it.
+ tag = oci.tag.retarget_tag(
+ tag_name,
+ created_manifest.manifest,
+ adjust_old_model=not self.oci_model_only,
+ )
+ if tag is None:
+ return (None, None)
- legacy_image = oci.shared.get_legacy_image_for_manifest(created_manifest.manifest)
- li = LegacyImage.for_image(legacy_image)
- wrapped_manifest = Manifest.for_manifest(created_manifest.manifest, li)
+ legacy_image = oci.shared.get_legacy_image_for_manifest(
+ created_manifest.manifest
+ )
+ li = LegacyImage.for_image(legacy_image)
+ wrapped_manifest = Manifest.for_manifest(created_manifest.manifest, li)
- # Apply any labels that should modify the created tag.
- if created_manifest.labels_to_apply:
- for key, value in created_manifest.labels_to_apply.iteritems():
- apply_label_to_manifest(dict(key=key, value=value), wrapped_manifest, self)
+ # Apply any labels that should modify the created tag.
+ if created_manifest.labels_to_apply:
+ for key, value in created_manifest.labels_to_apply.iteritems():
+ apply_label_to_manifest(
+ dict(key=key, value=value), wrapped_manifest, self
+ )
- # Reload the tag in case any updates were applied.
- tag = database.Tag.get(id=tag.id)
+ # Reload the tag in case any updates were applied.
+ tag = database.Tag.get(id=tag.id)
- return (wrapped_manifest, Tag.for_tag(tag, li))
+ return (wrapped_manifest, Tag.for_tag(tag, li))
- def retarget_tag(self, repository_ref, tag_name, manifest_or_legacy_image, storage,
- legacy_manifest_key, is_reversion=False):
- """
+ def retarget_tag(
+ self,
+ repository_ref,
+ tag_name,
+ manifest_or_legacy_image,
+ storage,
+ legacy_manifest_key,
+ is_reversion=False,
+ ):
+ """
Creates, updates or moves a tag to a new entry in history, pointing to the manifest or
legacy image specified. If is_reversion is set to True, this operation is considered a
reversion over a previous tag move operation. Returns the updated Tag or None on error.
"""
- assert legacy_manifest_key is not None
- manifest_id = manifest_or_legacy_image._db_id
- if isinstance(manifest_or_legacy_image, LegacyImage):
- # If a legacy image was required, build a new manifest for it and move the tag to that.
- try:
- image_row = database.Image.get(id=manifest_or_legacy_image._db_id)
- except database.Image.DoesNotExist:
- return None
+ assert legacy_manifest_key is not None
+ manifest_id = manifest_or_legacy_image._db_id
+ if isinstance(manifest_or_legacy_image, LegacyImage):
+ # If a legacy image was required, build a new manifest for it and move the tag to that.
+ try:
+ image_row = database.Image.get(id=manifest_or_legacy_image._db_id)
+ except database.Image.DoesNotExist:
+ return None
- manifest_instance = self._build_manifest_for_legacy_image(tag_name, image_row)
- if manifest_instance is None:
- return None
+ manifest_instance = self._build_manifest_for_legacy_image(
+ tag_name, image_row
+ )
+ if manifest_instance is None:
+ return None
- created = oci.manifest.get_or_create_manifest(repository_ref._db_id, manifest_instance,
- storage)
- if created is None:
- return None
+ created = oci.manifest.get_or_create_manifest(
+ repository_ref._db_id, manifest_instance, storage
+ )
+ if created is None:
+ return None
- manifest_id = created.manifest.id
- else:
- # If the manifest is a schema 1 manifest and its tag name does not match that
- # specified, then we need to create a new manifest, but with that tag name.
- if manifest_or_legacy_image.media_type in DOCKER_SCHEMA1_CONTENT_TYPES:
- try:
- parsed = manifest_or_legacy_image.get_parsed_manifest()
- except ManifestException:
- logger.exception('Could not parse manifest `%s` in retarget_tag',
- manifest_or_legacy_image._db_id)
- return None
+ manifest_id = created.manifest.id
+ else:
+ # If the manifest is a schema 1 manifest and its tag name does not match that
+ # specified, then we need to create a new manifest, but with that tag name.
+ if manifest_or_legacy_image.media_type in DOCKER_SCHEMA1_CONTENT_TYPES:
+ try:
+ parsed = manifest_or_legacy_image.get_parsed_manifest()
+ except ManifestException:
+ logger.exception(
+ "Could not parse manifest `%s` in retarget_tag",
+ manifest_or_legacy_image._db_id,
+ )
+ return None
- if parsed.tag != tag_name:
- logger.debug('Rewriting manifest `%s` for tag named `%s`',
- manifest_or_legacy_image._db_id, tag_name)
+ if parsed.tag != tag_name:
+ logger.debug(
+ "Rewriting manifest `%s` for tag named `%s`",
+ manifest_or_legacy_image._db_id,
+ tag_name,
+ )
- repository_id = repository_ref._db_id
- updated = parsed.with_tag_name(tag_name, legacy_manifest_key)
- assert updated.is_signed
+ repository_id = repository_ref._db_id
+ updated = parsed.with_tag_name(tag_name, legacy_manifest_key)
+ assert updated.is_signed
- created = oci.manifest.get_or_create_manifest(repository_id, updated, storage)
- if created is None:
- return None
+ created = oci.manifest.get_or_create_manifest(
+ repository_id, updated, storage
+ )
+ if created is None:
+ return None
- manifest_id = created.manifest.id
+ manifest_id = created.manifest.id
- tag = oci.tag.retarget_tag(tag_name, manifest_id, is_reversion=is_reversion)
- legacy_image = LegacyImage.for_image(oci.shared.get_legacy_image_for_manifest(manifest_id))
- return Tag.for_tag(tag, legacy_image)
+ tag = oci.tag.retarget_tag(tag_name, manifest_id, is_reversion=is_reversion)
+ legacy_image = LegacyImage.for_image(
+ oci.shared.get_legacy_image_for_manifest(manifest_id)
+ )
+ return Tag.for_tag(tag, legacy_image)
- def delete_tag(self, repository_ref, tag_name):
- """
+ def delete_tag(self, repository_ref, tag_name):
+ """
Deletes the latest, *active* tag with the given name in the repository.
"""
- deleted_tag = oci.tag.delete_tag(repository_ref._db_id, tag_name)
- if deleted_tag is None:
- # TODO: This is only needed because preoci raises an exception. Remove and fix
- # expected status codes once PreOCIModel is gone.
- msg = ('Invalid repository tag \'%s\' on repository' % tag_name)
- raise DataModelException(msg)
+ deleted_tag = oci.tag.delete_tag(repository_ref._db_id, tag_name)
+ if deleted_tag is None:
+ # TODO: This is only needed because preoci raises an exception. Remove and fix
+ # expected status codes once PreOCIModel is gone.
+ msg = "Invalid repository tag '%s' on repository" % tag_name
+ raise DataModelException(msg)
- return Tag.for_tag(deleted_tag)
+ return Tag.for_tag(deleted_tag)
- def delete_tags_for_manifest(self, manifest):
- """
+ def delete_tags_for_manifest(self, manifest):
+ """
Deletes all tags pointing to the given manifest, making the manifest inaccessible for pulling.
Returns the tags deleted, if any. Returns None on error.
"""
- deleted_tags = oci.tag.delete_tags_for_manifest(manifest._db_id)
- return [Tag.for_tag(tag) for tag in deleted_tags]
+ deleted_tags = oci.tag.delete_tags_for_manifest(manifest._db_id)
+ return [Tag.for_tag(tag) for tag in deleted_tags]
- def change_repository_tag_expiration(self, tag, expiration_date):
- """ Sets the expiration date of the tag under the matching repository to that given. If the
+ def change_repository_tag_expiration(self, tag, expiration_date):
+ """ Sets the expiration date of the tag under the matching repository to that given. If the
expiration date is None, then the tag will not expire. Returns a tuple of the previous
expiration timestamp in seconds (if any), and whether the operation succeeded.
"""
- return oci.tag.change_tag_expiration(tag._db_id, expiration_date)
+ return oci.tag.change_tag_expiration(tag._db_id, expiration_date)
- def get_legacy_images_owned_by_tag(self, tag):
- """ Returns all legacy images *solely owned and used* by the given tag. """
- tag_obj = oci.tag.get_tag_by_id(tag._db_id)
- if tag_obj is None:
- return None
+ def get_legacy_images_owned_by_tag(self, tag):
+ """ Returns all legacy images *solely owned and used* by the given tag. """
+ tag_obj = oci.tag.get_tag_by_id(tag._db_id)
+ if tag_obj is None:
+ return None
- tags = oci.tag.list_alive_tags(tag_obj.repository_id)
- legacy_images = oci.tag.get_legacy_images_for_tags(tags)
+ tags = oci.tag.list_alive_tags(tag_obj.repository_id)
+ legacy_images = oci.tag.get_legacy_images_for_tags(tags)
- tag_legacy_image = legacy_images.get(tag._db_id)
- if tag_legacy_image is None:
- return None
+ tag_legacy_image = legacy_images.get(tag._db_id)
+ if tag_legacy_image is None:
+ return None
- assert isinstance(tag_legacy_image, Image)
+ assert isinstance(tag_legacy_image, Image)
- # Collect the IDs of all images that the tag uses.
- tag_image_ids = set()
- tag_image_ids.add(tag_legacy_image.id)
- tag_image_ids.update(tag_legacy_image.ancestor_id_list())
+ # Collect the IDs of all images that the tag uses.
+ tag_image_ids = set()
+ tag_image_ids.add(tag_legacy_image.id)
+ tag_image_ids.update(tag_legacy_image.ancestor_id_list())
- # Remove any images shared by other tags.
- for current in tags:
- if current == tag_obj:
- continue
+ # Remove any images shared by other tags.
+ for current in tags:
+ if current == tag_obj:
+ continue
- current_image = legacy_images.get(current.id)
- if current_image is None:
- continue
+ current_image = legacy_images.get(current.id)
+ if current_image is None:
+ continue
- tag_image_ids.discard(current_image.id)
- tag_image_ids = tag_image_ids.difference(current_image.ancestor_id_list())
- if not tag_image_ids:
- return []
+ tag_image_ids.discard(current_image.id)
+ tag_image_ids = tag_image_ids.difference(current_image.ancestor_id_list())
+ if not tag_image_ids:
+ return []
- if not tag_image_ids:
- return []
+ if not tag_image_ids:
+ return []
- # Load the images we need to return.
- images = database.Image.select().where(database.Image.id << list(tag_image_ids))
- all_image_ids = set()
- for image in images:
- all_image_ids.add(image.id)
- all_image_ids.update(image.ancestor_id_list())
+ # Load the images we need to return.
+ images = database.Image.select().where(database.Image.id << list(tag_image_ids))
+ all_image_ids = set()
+ for image in images:
+ all_image_ids.add(image.id)
+ all_image_ids.update(image.ancestor_id_list())
- # Build a map of all the images and their parents.
- images_map = {}
- all_images = database.Image.select().where(database.Image.id << list(all_image_ids))
- for image in all_images:
- images_map[image.id] = image
+ # Build a map of all the images and their parents.
+ images_map = {}
+ all_images = database.Image.select().where(
+ database.Image.id << list(all_image_ids)
+ )
+ for image in all_images:
+ images_map[image.id] = image
- return [LegacyImage.for_image(image, images_map=images_map) for image in images]
+ return [LegacyImage.for_image(image, images_map=images_map) for image in images]
- def get_security_status(self, manifest_or_legacy_image):
- """ Returns the security status for the given manifest or legacy image or None if none. """
- image = None
+ def get_security_status(self, manifest_or_legacy_image):
+ """ Returns the security status for the given manifest or legacy image or None if none. """
+ image = None
- if isinstance(manifest_or_legacy_image, Manifest):
- image = oci.shared.get_legacy_image_for_manifest(manifest_or_legacy_image._db_id)
- if image is None:
- return SecurityScanStatus.UNSUPPORTED
- else:
- try:
- image = database.Image.get(id=manifest_or_legacy_image._db_id)
- except database.Image.DoesNotExist:
- return None
+ if isinstance(manifest_or_legacy_image, Manifest):
+ image = oci.shared.get_legacy_image_for_manifest(
+ manifest_or_legacy_image._db_id
+ )
+ if image is None:
+ return SecurityScanStatus.UNSUPPORTED
+ else:
+ try:
+ image = database.Image.get(id=manifest_or_legacy_image._db_id)
+ except database.Image.DoesNotExist:
+ return None
- if image.security_indexed_engine is not None and image.security_indexed_engine >= 0:
- return SecurityScanStatus.SCANNED if image.security_indexed else SecurityScanStatus.FAILED
+ if (
+ image.security_indexed_engine is not None
+ and image.security_indexed_engine >= 0
+ ):
+ return (
+ SecurityScanStatus.SCANNED
+ if image.security_indexed
+ else SecurityScanStatus.FAILED
+ )
- return SecurityScanStatus.QUEUED
+ return SecurityScanStatus.QUEUED
- def reset_security_status(self, manifest_or_legacy_image):
- """ Resets the security status for the given manifest or legacy image, ensuring that it will
+ def reset_security_status(self, manifest_or_legacy_image):
+ """ Resets the security status for the given manifest or legacy image, ensuring that it will
get re-indexed.
"""
- image = None
+ image = None
- if isinstance(manifest_or_legacy_image, Manifest):
- image = oci.shared.get_legacy_image_for_manifest(manifest_or_legacy_image._db_id)
- if image is None:
- return None
- else:
- try:
- image = database.Image.get(id=manifest_or_legacy_image._db_id)
- except database.Image.DoesNotExist:
- return None
+ if isinstance(manifest_or_legacy_image, Manifest):
+ image = oci.shared.get_legacy_image_for_manifest(
+ manifest_or_legacy_image._db_id
+ )
+ if image is None:
+ return None
+ else:
+ try:
+ image = database.Image.get(id=manifest_or_legacy_image._db_id)
+ except database.Image.DoesNotExist:
+ return None
- assert image
- image.security_indexed = False
- image.security_indexed_engine = IMAGE_NOT_SCANNED_ENGINE_VERSION
- image.save()
+ assert image
+ image.security_indexed = False
+ image.security_indexed_engine = IMAGE_NOT_SCANNED_ENGINE_VERSION
+ image.save()
- def backfill_manifest_for_tag(self, tag):
- """ Backfills a manifest for the V1 tag specified.
+ def backfill_manifest_for_tag(self, tag):
+ """ Backfills a manifest for the V1 tag specified.
If a manifest already exists for the tag, returns that manifest.
NOTE: This method will only be necessary until we've completed the backfill, at which point
it should be removed.
"""
- # Nothing to do for OCI tags.
- manifest = tag.manifest
- if manifest is None:
- return None
+ # Nothing to do for OCI tags.
+ manifest = tag.manifest
+ if manifest is None:
+ return None
- legacy_image = oci.shared.get_legacy_image_for_manifest(manifest)
- return Manifest.for_manifest(manifest, LegacyImage.for_image(legacy_image))
+ legacy_image = oci.shared.get_legacy_image_for_manifest(manifest)
+ return Manifest.for_manifest(manifest, LegacyImage.for_image(legacy_image))
- def list_manifest_layers(self, manifest, storage, include_placements=False):
- try:
- manifest_obj = database.Manifest.get(id=manifest._db_id)
- except database.Manifest.DoesNotExist:
- logger.exception('Could not find manifest for manifest `%s`', manifest._db_id)
- return None
+ def list_manifest_layers(self, manifest, storage, include_placements=False):
+ try:
+ manifest_obj = database.Manifest.get(id=manifest._db_id)
+ except database.Manifest.DoesNotExist:
+ logger.exception(
+ "Could not find manifest for manifest `%s`", manifest._db_id
+ )
+ return None
- try:
- parsed = manifest.get_parsed_manifest()
- except ManifestException:
- logger.exception('Could not parse and validate manifest `%s`', manifest._db_id)
- return None
+ try:
+ parsed = manifest.get_parsed_manifest()
+ except ManifestException:
+ logger.exception(
+ "Could not parse and validate manifest `%s`", manifest._db_id
+ )
+ return None
- return self._list_manifest_layers(manifest_obj.repository_id, parsed, storage,
- include_placements, by_manifest=True)
+ return self._list_manifest_layers(
+ manifest_obj.repository_id,
+ parsed,
+ storage,
+ include_placements,
+ by_manifest=True,
+ )
- def lookup_derived_image(self, manifest, verb, storage, varying_metadata=None,
- include_placements=False):
- """
+ def lookup_derived_image(
+ self, manifest, verb, storage, varying_metadata=None, include_placements=False
+ ):
+ """
Looks up the derived image for the given manifest, verb and optional varying metadata and
returns it or None if none.
"""
- legacy_image = self._get_legacy_compatible_image_for_manifest(manifest, storage)
- if legacy_image is None:
- return None
+ legacy_image = self._get_legacy_compatible_image_for_manifest(manifest, storage)
+ if legacy_image is None:
+ return None
- derived = model.image.find_derived_storage_for_image(legacy_image, verb, varying_metadata)
- return self._build_derived(derived, verb, varying_metadata, include_placements)
+ derived = model.image.find_derived_storage_for_image(
+ legacy_image, verb, varying_metadata
+ )
+ return self._build_derived(derived, verb, varying_metadata, include_placements)
- def lookup_or_create_derived_image(self, manifest, verb, storage_location, storage,
- varying_metadata=None,
- include_placements=False):
- """
+ def lookup_or_create_derived_image(
+ self,
+ manifest,
+ verb,
+ storage_location,
+ storage,
+ varying_metadata=None,
+ include_placements=False,
+ ):
+ """
Looks up the derived image for the given manifest, verb and optional varying metadata
and returns it. If none exists, a new derived image is created.
"""
- legacy_image = self._get_legacy_compatible_image_for_manifest(manifest, storage)
- if legacy_image is None:
- return None
+ legacy_image = self._get_legacy_compatible_image_for_manifest(manifest, storage)
+ if legacy_image is None:
+ return None
- derived = model.image.find_or_create_derived_storage(legacy_image, verb, storage_location,
- varying_metadata)
- return self._build_derived(derived, verb, varying_metadata, include_placements)
+ derived = model.image.find_or_create_derived_storage(
+ legacy_image, verb, storage_location, varying_metadata
+ )
+ return self._build_derived(derived, verb, varying_metadata, include_placements)
- def set_tags_expiration_for_manifest(self, manifest, expiration_sec):
- """
+ def set_tags_expiration_for_manifest(self, manifest, expiration_sec):
+ """
Sets the expiration on all tags that point to the given manifest to that specified.
"""
- oci.tag.set_tag_expiration_sec_for_manifest(manifest._db_id, expiration_sec)
+ oci.tag.set_tag_expiration_sec_for_manifest(manifest._db_id, expiration_sec)
- def get_schema1_parsed_manifest(self, manifest, namespace_name, repo_name, tag_name, storage):
- """ Returns the schema 1 manifest for this manifest, or None if none. """
- try:
- parsed = manifest.get_parsed_manifest()
- except ManifestException:
- return None
+ def get_schema1_parsed_manifest(
+ self, manifest, namespace_name, repo_name, tag_name, storage
+ ):
+ """ Returns the schema 1 manifest for this manifest, or None if none. """
+ try:
+ parsed = manifest.get_parsed_manifest()
+ except ManifestException:
+ return None
- try:
- manifest_row = database.Manifest.get(id=manifest._db_id)
- except database.Manifest.DoesNotExist:
- return None
+ try:
+ manifest_row = database.Manifest.get(id=manifest._db_id)
+ except database.Manifest.DoesNotExist:
+ return None
- retriever = RepositoryContentRetriever(manifest_row.repository_id, storage)
- return parsed.get_schema1_manifest(namespace_name, repo_name, tag_name, retriever)
+ retriever = RepositoryContentRetriever(manifest_row.repository_id, storage)
+ return parsed.get_schema1_manifest(
+ namespace_name, repo_name, tag_name, retriever
+ )
- def convert_manifest(self, manifest, namespace_name, repo_name, tag_name, allowed_mediatypes,
- storage):
- try:
- parsed = manifest.get_parsed_manifest()
- except ManifestException:
- return None
+ def convert_manifest(
+ self, manifest, namespace_name, repo_name, tag_name, allowed_mediatypes, storage
+ ):
+ try:
+ parsed = manifest.get_parsed_manifest()
+ except ManifestException:
+ return None
- try:
- manifest_row = database.Manifest.get(id=manifest._db_id)
- except database.Manifest.DoesNotExist:
- return None
+ try:
+ manifest_row = database.Manifest.get(id=manifest._db_id)
+ except database.Manifest.DoesNotExist:
+ return None
- retriever = RepositoryContentRetriever(manifest_row.repository_id, storage)
- return parsed.convert_manifest(allowed_mediatypes, namespace_name, repo_name, tag_name,
- retriever)
+ retriever = RepositoryContentRetriever(manifest_row.repository_id, storage)
+ return parsed.convert_manifest(
+ allowed_mediatypes, namespace_name, repo_name, tag_name, retriever
+ )
- def create_manifest_with_temp_tag(self, repository_ref, manifest_interface_instance,
- expiration_sec, storage):
- """ Creates a manifest under the repository and sets a temporary tag to point to it.
+ def create_manifest_with_temp_tag(
+ self, repository_ref, manifest_interface_instance, expiration_sec, storage
+ ):
+ """ Creates a manifest under the repository and sets a temporary tag to point to it.
Returns the manifest object created or None on error.
"""
- # Get or create the manifest itself. get_or_create_manifest will take care of the
- # temporary tag work.
- created_manifest = oci.manifest.get_or_create_manifest(repository_ref._db_id,
- manifest_interface_instance,
- storage,
- temp_tag_expiration_sec=expiration_sec)
- if created_manifest is None:
- return None
+ # Get or create the manifest itself. get_or_create_manifest will take care of the
+ # temporary tag work.
+ created_manifest = oci.manifest.get_or_create_manifest(
+ repository_ref._db_id,
+ manifest_interface_instance,
+ storage,
+ temp_tag_expiration_sec=expiration_sec,
+ )
+ if created_manifest is None:
+ return None
- legacy_image = oci.shared.get_legacy_image_for_manifest(created_manifest.manifest)
- li = LegacyImage.for_image(legacy_image)
- return Manifest.for_manifest(created_manifest.manifest, li)
+ legacy_image = oci.shared.get_legacy_image_for_manifest(
+ created_manifest.manifest
+ )
+ li = LegacyImage.for_image(legacy_image)
+ return Manifest.for_manifest(created_manifest.manifest, li)
- def get_repo_blob_by_digest(self, repository_ref, blob_digest, include_placements=False):
- """
+ def get_repo_blob_by_digest(
+ self, repository_ref, blob_digest, include_placements=False
+ ):
+ """
Returns the blob in the repository with the given digest, if any, or None if none. Note that
there may be multiple records in the same repository for the same blob digest, so the return
value of this function may change.
"""
- image_storage = self._get_shared_storage(blob_digest)
- if image_storage is None:
- image_storage = oci.blob.get_repository_blob_by_digest(repository_ref._db_id, blob_digest)
- if image_storage is None:
- return None
+ image_storage = self._get_shared_storage(blob_digest)
+ if image_storage is None:
+ image_storage = oci.blob.get_repository_blob_by_digest(
+ repository_ref._db_id, blob_digest
+ )
+ if image_storage is None:
+ return None
- assert image_storage.cas_path is not None
+ assert image_storage.cas_path is not None
- placements = None
- if include_placements:
- placements = list(model.storage.get_storage_locations(image_storage.uuid))
+ placements = None
+ if include_placements:
+ placements = list(model.storage.get_storage_locations(image_storage.uuid))
- return Blob.for_image_storage(image_storage,
- storage_path=model.storage.get_layer_path(image_storage),
- placements=placements)
+ return Blob.for_image_storage(
+ image_storage,
+ storage_path=model.storage.get_layer_path(image_storage),
+ placements=placements,
+ )
- def list_parsed_manifest_layers(self, repository_ref, parsed_manifest, storage,
- include_placements=False):
- """ Returns an *ordered list* of the layers found in the parsed manifest, starting at the base
+ def list_parsed_manifest_layers(
+ self, repository_ref, parsed_manifest, storage, include_placements=False
+ ):
+ """ Returns an *ordered list* of the layers found in the parsed manifest, starting at the base
and working towards the leaf, including the associated Blob and its placements
(if specified).
"""
- return self._list_manifest_layers(repository_ref._db_id, parsed_manifest, storage,
- include_placements=include_placements,
- by_manifest=True)
+ return self._list_manifest_layers(
+ repository_ref._db_id,
+ parsed_manifest,
+ storage,
+ include_placements=include_placements,
+ by_manifest=True,
+ )
- def get_manifest_local_blobs(self, manifest, include_placements=False):
- """ Returns the set of local blobs for the given manifest or None if none. """
- try:
- manifest_row = database.Manifest.get(id=manifest._db_id)
- except database.Manifest.DoesNotExist:
- return None
+ def get_manifest_local_blobs(self, manifest, include_placements=False):
+ """ Returns the set of local blobs for the given manifest or None if none. """
+ try:
+ manifest_row = database.Manifest.get(id=manifest._db_id)
+ except database.Manifest.DoesNotExist:
+ return None
- return self._get_manifest_local_blobs(manifest, manifest_row.repository_id, include_placements,
- by_manifest=True)
+ return self._get_manifest_local_blobs(
+ manifest, manifest_row.repository_id, include_placements, by_manifest=True
+ )
- def yield_tags_for_vulnerability_notification(self, layer_id_pairs):
- """ Yields tags that contain one (or more) of the given layer ID pairs, in repositories
+ def yield_tags_for_vulnerability_notification(self, layer_id_pairs):
+ """ Yields tags that contain one (or more) of the given layer ID pairs, in repositories
which have been registered for vulnerability_found notifications. Returns an iterator
of LikelyVulnerableTag instances.
"""
- for docker_image_id, storage_uuid in layer_id_pairs:
- tags = oci.tag.lookup_notifiable_tags_for_legacy_image(docker_image_id, storage_uuid,
- 'vulnerability_found')
- for tag in tags:
- yield LikelyVulnerableTag.for_tag(tag, tag.repository, docker_image_id, storage_uuid)
+ for docker_image_id, storage_uuid in layer_id_pairs:
+ tags = oci.tag.lookup_notifiable_tags_for_legacy_image(
+ docker_image_id, storage_uuid, "vulnerability_found"
+ )
+ for tag in tags:
+ yield LikelyVulnerableTag.for_tag(
+ tag, tag.repository, docker_image_id, storage_uuid
+ )
+
oci_model = OCIModel()
back_compat_oci_model = OCIModel(oci_model_only=False)
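Illustrative sketch (not part of the diff above): reading a tag's manifest and security status through the module-level `oci_model` instance defined just above. Only the three method calls and their signatures appear in this file's diff; `repository_ref` is assumed to come from the usual repository lookup elsewhere in the registry data model, and the `digest` attribute on the returned Manifest datatype is an assumption.

from data.registry_model.registry_oci_model import oci_model


def describe_tag(repository_ref, tag_name):
    # get_repo_tag returns the latest *active* tag with this name, or None.
    tag = oci_model.get_repo_tag(repository_ref, tag_name, include_legacy_image=True)
    if tag is None:
        return None

    manifest = oci_model.get_manifest_for_tag(tag, include_legacy_image=True)
    status = oci_model.get_security_status(manifest)
    return {
        "tag": tag_name,
        "manifest_digest": manifest.digest,  # assumed attribute on the Manifest datatype
        "security_status": status,
    }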
diff --git a/data/registry_model/registry_pre_oci_model.py b/data/registry_model/registry_pre_oci_model.py
index ec69328d5..798455a43 100644
--- a/data/registry_model/registry_pre_oci_model.py
+++ b/data/registry_model/registry_pre_oci_model.py
@@ -9,9 +9,17 @@ from data import database
from data import model
from data.database import db_transaction, IMAGE_NOT_SCANNED_ENGINE_VERSION
from data.registry_model.interface import RegistryDataInterface
-from data.registry_model.datatypes import (Tag, Manifest, LegacyImage, Label, SecurityScanStatus,
- Blob, RepositoryReference, ShallowTag,
- LikelyVulnerableTag)
+from data.registry_model.datatypes import (
+ Tag,
+ Manifest,
+ LegacyImage,
+ Label,
+ SecurityScanStatus,
+ Blob,
+ RepositoryReference,
+ ShallowTag,
+ LikelyVulnerableTag,
+)
from data.registry_model.shared import SharedModel
from data.registry_model.label_handlers import apply_label_to_manifest
from image.docker.schema1 import ManifestException, DockerSchema1Manifest
@@ -22,86 +30,108 @@ logger = logging.getLogger(__name__)
class PreOCIModel(SharedModel, RegistryDataInterface):
- """
+ """
PreOCIModel implements the data model for the registry API using a database schema
before it was changed to support the OCI specification.
"""
- def supports_schema2(self, namespace_name):
- """ Returns whether the implementation of the data interface supports schema 2 format
- manifests. """
- return False
- def get_tag_legacy_image_id(self, repository_ref, tag_name, storage):
- """ Returns the legacy image ID for the tag with a legacy images in
+ def supports_schema2(self, namespace_name):
+ """ Returns whether the implementation of the data interface supports schema 2 format
+ manifests. """
+ return False
+
+ def get_tag_legacy_image_id(self, repository_ref, tag_name, storage):
+ """ Returns the legacy image ID for the tag with a legacy images in
the repository. Returns None if None.
"""
- tag = self.get_repo_tag(repository_ref, tag_name, include_legacy_image=True)
- if tag is None:
- return None
+ tag = self.get_repo_tag(repository_ref, tag_name, include_legacy_image=True)
+ if tag is None:
+ return None
- return tag.legacy_image.docker_image_id
+ return tag.legacy_image.docker_image_id
- def get_legacy_tags_map(self, repository_ref, storage):
- """ Returns a map from tag name to its legacy image, for all tags with legacy images in
+ def get_legacy_tags_map(self, repository_ref, storage):
+ """ Returns a map from tag name to its legacy image, for all tags with legacy images in
the repository.
"""
- tags = self.list_all_active_repository_tags(repository_ref, include_legacy_images=True)
- return {tag.name: tag.legacy_image.docker_image_id for tag in tags}
+ tags = self.list_all_active_repository_tags(
+ repository_ref, include_legacy_images=True
+ )
+ return {tag.name: tag.legacy_image.docker_image_id for tag in tags}
- def find_matching_tag(self, repository_ref, tag_names):
- """ Finds an alive tag in the repository matching one of the given tag names and returns it
+ def find_matching_tag(self, repository_ref, tag_names):
+ """ Finds an alive tag in the repository matching one of the given tag names and returns it
or None if none.
"""
- found_tag = model.tag.find_matching_tag(repository_ref._db_id, tag_names)
- assert found_tag is None or not found_tag.hidden
- return Tag.for_repository_tag(found_tag)
+ found_tag = model.tag.find_matching_tag(repository_ref._db_id, tag_names)
+ assert found_tag is None or not found_tag.hidden
+ return Tag.for_repository_tag(found_tag)
- def get_most_recent_tag(self, repository_ref):
- """ Returns the most recently pushed alive tag in the repository, if any. If none, returns
+ def get_most_recent_tag(self, repository_ref):
+ """ Returns the most recently pushed alive tag in the repository, if any. If none, returns
None.
"""
- found_tag = model.tag.get_most_recent_tag(repository_ref._db_id)
- assert found_tag is None or not found_tag.hidden
- return Tag.for_repository_tag(found_tag)
+ found_tag = model.tag.get_most_recent_tag(repository_ref._db_id)
+ assert found_tag is None or not found_tag.hidden
+ return Tag.for_repository_tag(found_tag)
- def get_manifest_for_tag(self, tag, backfill_if_necessary=False, include_legacy_image=False):
- """ Returns the manifest associated with the given tag. """
- try:
- tag_manifest = database.TagManifest.get(tag_id=tag._db_id)
- except database.TagManifest.DoesNotExist:
- if backfill_if_necessary:
- return self.backfill_manifest_for_tag(tag)
+ def get_manifest_for_tag(
+ self, tag, backfill_if_necessary=False, include_legacy_image=False
+ ):
+ """ Returns the manifest associated with the given tag. """
+ try:
+ tag_manifest = database.TagManifest.get(tag_id=tag._db_id)
+ except database.TagManifest.DoesNotExist:
+ if backfill_if_necessary:
+ return self.backfill_manifest_for_tag(tag)
- return None
+ return None
- return Manifest.for_tag_manifest(tag_manifest)
+ return Manifest.for_tag_manifest(tag_manifest)
- def lookup_manifest_by_digest(self, repository_ref, manifest_digest, allow_dead=False,
- include_legacy_image=False, require_available=False):
- """ Looks up the manifest with the given digest under the given repository and returns it
+ def lookup_manifest_by_digest(
+ self,
+ repository_ref,
+ manifest_digest,
+ allow_dead=False,
+ include_legacy_image=False,
+ require_available=False,
+ ):
+ """ Looks up the manifest with the given digest under the given repository and returns it
or None if none. """
- repo = model.repository.lookup_repository(repository_ref._db_id)
- if repo is None:
- return None
+ repo = model.repository.lookup_repository(repository_ref._db_id)
+ if repo is None:
+ return None
- try:
- tag_manifest = model.tag.load_manifest_by_digest(repo.namespace_user.username,
- repo.name,
- manifest_digest,
- allow_dead=allow_dead)
- except model.tag.InvalidManifestException:
- return None
+ try:
+ tag_manifest = model.tag.load_manifest_by_digest(
+ repo.namespace_user.username,
+ repo.name,
+ manifest_digest,
+ allow_dead=allow_dead,
+ )
+ except model.tag.InvalidManifestException:
+ return None
- legacy_image = None
- if include_legacy_image:
- legacy_image = self.get_legacy_image(repository_ref, tag_manifest.tag.image.docker_image_id,
- include_parents=True)
+ legacy_image = None
+ if include_legacy_image:
+ legacy_image = self.get_legacy_image(
+ repository_ref,
+ tag_manifest.tag.image.docker_image_id,
+ include_parents=True,
+ )
- return Manifest.for_tag_manifest(tag_manifest, legacy_image)
+ return Manifest.for_tag_manifest(tag_manifest, legacy_image)
- def create_manifest_and_retarget_tag(self, repository_ref, manifest_interface_instance, tag_name,
- storage, raise_on_error=False):
- """ Creates a manifest in a repository, adding all of the necessary data in the model.
+ def create_manifest_and_retarget_tag(
+ self,
+ repository_ref,
+ manifest_interface_instance,
+ tag_name,
+ storage,
+ raise_on_error=False,
+ ):
+ """ Creates a manifest in a repository, adding all of the necessary data in the model.
The `manifest_interface_instance` parameter must be an instance of the manifest
interface as returned by the image/docker package.
@@ -111,584 +141,726 @@ class PreOCIModel(SharedModel, RegistryDataInterface):
Returns a reference to the (created manifest, tag) or (None, None) on error.
"""
- # NOTE: Only Schema1 is supported by the pre_oci_model.
- assert isinstance(manifest_interface_instance, DockerSchema1Manifest)
- if not manifest_interface_instance.layers:
- return None, None
+ # NOTE: Only Schema1 is supported by the pre_oci_model.
+ assert isinstance(manifest_interface_instance, DockerSchema1Manifest)
+ if not manifest_interface_instance.layers:
+ return None, None
- # Ensure all the blobs in the manifest exist.
- digests = manifest_interface_instance.checksums
- query = self._lookup_repo_storages_by_content_checksum(repository_ref._db_id, digests)
- blob_map = {s.content_checksum: s for s in query}
- for layer in manifest_interface_instance.layers:
- digest_str = str(layer.digest)
- if digest_str not in blob_map:
- return None, None
-
- # Lookup all the images and their parent images (if any) inside the manifest.
- # This will let us know which v1 images we need to synthesize and which ones are invalid.
- docker_image_ids = list(manifest_interface_instance.legacy_image_ids)
- images_query = model.image.lookup_repository_images(repository_ref._db_id, docker_image_ids)
- image_storage_map = {i.docker_image_id: i.storage for i in images_query}
-
- # Rewrite any v1 image IDs that do not match the checksum in the database.
- try:
- rewritten_images = manifest_interface_instance.rewrite_invalid_image_ids(image_storage_map)
- rewritten_images = list(rewritten_images)
- parent_image_map = {}
-
- for rewritten_image in rewritten_images:
- if not rewritten_image.image_id in image_storage_map:
- parent_image = None
- if rewritten_image.parent_image_id:
- parent_image = parent_image_map.get(rewritten_image.parent_image_id)
- if parent_image is None:
- parent_image = model.image.get_image(repository_ref._db_id,
- rewritten_image.parent_image_id)
- if parent_image is None:
+ # Ensure all the blobs in the manifest exist.
+ digests = manifest_interface_instance.checksums
+ query = self._lookup_repo_storages_by_content_checksum(
+ repository_ref._db_id, digests
+ )
+ blob_map = {s.content_checksum: s for s in query}
+ for layer in manifest_interface_instance.layers:
+ digest_str = str(layer.digest)
+ if digest_str not in blob_map:
return None, None
- synthesized = model.image.synthesize_v1_image(
+ # Lookup all the images and their parent images (if any) inside the manifest.
+ # This will let us know which v1 images we need to synthesize and which ones are invalid.
+ docker_image_ids = list(manifest_interface_instance.legacy_image_ids)
+ images_query = model.image.lookup_repository_images(
+ repository_ref._db_id, docker_image_ids
+ )
+ image_storage_map = {i.docker_image_id: i.storage for i in images_query}
+
+ # Rewrite any v1 image IDs that do not match the checksum in the database.
+ try:
+ rewritten_images = manifest_interface_instance.rewrite_invalid_image_ids(
+ image_storage_map
+ )
+ rewritten_images = list(rewritten_images)
+ parent_image_map = {}
+
+ for rewritten_image in rewritten_images:
+ if not rewritten_image.image_id in image_storage_map:
+ parent_image = None
+ if rewritten_image.parent_image_id:
+ parent_image = parent_image_map.get(
+ rewritten_image.parent_image_id
+ )
+ if parent_image is None:
+ parent_image = model.image.get_image(
+ repository_ref._db_id, rewritten_image.parent_image_id
+ )
+ if parent_image is None:
+ return None, None
+
+ synthesized = model.image.synthesize_v1_image(
+ repository_ref._db_id,
+ blob_map[rewritten_image.content_checksum].id,
+ blob_map[rewritten_image.content_checksum].image_size,
+ rewritten_image.image_id,
+ rewritten_image.created,
+ rewritten_image.comment,
+ rewritten_image.command,
+ rewritten_image.compat_json,
+ parent_image,
+ )
+
+ parent_image_map[rewritten_image.image_id] = synthesized
+ except ManifestException:
+ logger.exception("exception when rewriting v1 metadata")
+ return None, None
+
+ # Store the manifest pointing to the tag.
+ leaf_layer_id = rewritten_images[-1].image_id
+ tag_manifest, newly_created = model.tag.store_tag_manifest_for_repo(
repository_ref._db_id,
- blob_map[rewritten_image.content_checksum].id,
- blob_map[rewritten_image.content_checksum].image_size,
- rewritten_image.image_id,
- rewritten_image.created,
- rewritten_image.comment,
- rewritten_image.command,
- rewritten_image.compat_json,
- parent_image,
- )
+ tag_name,
+ manifest_interface_instance,
+ leaf_layer_id,
+ blob_map,
+ )
- parent_image_map[rewritten_image.image_id] = synthesized
- except ManifestException:
- logger.exception("exception when rewriting v1 metadata")
- return None, None
+ manifest = Manifest.for_tag_manifest(tag_manifest)
- # Store the manifest pointing to the tag.
- leaf_layer_id = rewritten_images[-1].image_id
- tag_manifest, newly_created = model.tag.store_tag_manifest_for_repo(repository_ref._db_id,
- tag_name,
- manifest_interface_instance,
- leaf_layer_id,
- blob_map)
+ # Save the labels on the manifest.
+ repo_tag = tag_manifest.tag
+ if newly_created:
+ has_labels = False
+ with self.batch_create_manifest_labels(manifest) as add_label:
+ if add_label is None:
+ return None, None
- manifest = Manifest.for_tag_manifest(tag_manifest)
+ for key, value in manifest_interface_instance.layers[
+ -1
+ ].v1_metadata.labels.iteritems():
+ media_type = "application/json" if is_json(value) else "text/plain"
+ add_label(key, value, "manifest", media_type)
+ has_labels = True
- # Save the labels on the manifest.
- repo_tag = tag_manifest.tag
- if newly_created:
- has_labels = False
- with self.batch_create_manifest_labels(manifest) as add_label:
- if add_label is None:
- return None, None
+ # Reload the tag in case any updates were applied.
+ if has_labels:
+ repo_tag = database.RepositoryTag.get(id=repo_tag.id)
- for key, value in manifest_interface_instance.layers[-1].v1_metadata.labels.iteritems():
- media_type = 'application/json' if is_json(value) else 'text/plain'
- add_label(key, value, 'manifest', media_type)
- has_labels = True
+ return manifest, Tag.for_repository_tag(repo_tag)
- # Reload the tag in case any updates were applied.
- if has_labels:
- repo_tag = database.RepositoryTag.get(id=repo_tag.id)
+ def create_manifest_label(
+ self, manifest, key, value, source_type_name, media_type_name=None
+ ):
+ """ Creates a label on the manifest with the given key and value. """
+ try:
+ tag_manifest = database.TagManifest.get(id=manifest._db_id)
+ except database.TagManifest.DoesNotExist:
+ return None
- return manifest, Tag.for_repository_tag(repo_tag)
+ label_data = dict(
+ key=key,
+ value=value,
+ source_type_name=source_type_name,
+ media_type_name=media_type_name,
+ )
- def create_manifest_label(self, manifest, key, value, source_type_name, media_type_name=None):
- """ Creates a label on the manifest with the given key and value. """
- try:
- tag_manifest = database.TagManifest.get(id=manifest._db_id)
- except database.TagManifest.DoesNotExist:
- return None
+ with db_transaction():
+ # Create the label itself.
+ label = model.label.create_manifest_label(
+ tag_manifest, key, value, source_type_name, media_type_name
+ )
- label_data = dict(key=key, value=value, source_type_name=source_type_name,
- media_type_name=media_type_name)
+ # Apply any changes to the manifest that the label prescribes.
+ apply_label_to_manifest(label_data, manifest, self)
- with db_transaction():
- # Create the label itself.
- label = model.label.create_manifest_label(tag_manifest, key, value, source_type_name,
- media_type_name)
+ return Label.for_label(label)
- # Apply any changes to the manifest that the label prescribes.
- apply_label_to_manifest(label_data, manifest, self)
-
- return Label.for_label(label)
-
- @contextmanager
- def batch_create_manifest_labels(self, manifest):
- """ Returns a context manager for batch creation of labels on a manifest.
+ @contextmanager
+ def batch_create_manifest_labels(self, manifest):
+ """ Returns a context manager for batch creation of labels on a manifest.
Can raise InvalidLabelKeyException or InvalidMediaTypeException depending
on the validation errors.
"""
- try:
- tag_manifest = database.TagManifest.get(id=manifest._db_id)
- except database.TagManifest.DoesNotExist:
- yield None
- return
+ try:
+ tag_manifest = database.TagManifest.get(id=manifest._db_id)
+ except database.TagManifest.DoesNotExist:
+ yield None
+ return
- labels_to_add = []
- def add_label(key, value, source_type_name, media_type_name=None):
- labels_to_add.append(dict(key=key, value=value, source_type_name=source_type_name,
- media_type_name=media_type_name))
+ labels_to_add = []
- yield add_label
+ def add_label(key, value, source_type_name, media_type_name=None):
+ labels_to_add.append(
+ dict(
+ key=key,
+ value=value,
+ source_type_name=source_type_name,
+ media_type_name=media_type_name,
+ )
+ )
- # TODO: make this truly batch once we've fully transitioned to V2_2 and no longer need
- # the mapping tables.
- for label in labels_to_add:
- with db_transaction():
- # Create the label itself.
- model.label.create_manifest_label(tag_manifest, **label)
+ yield add_label
- # Apply any changes to the manifest that the label prescribes.
- apply_label_to_manifest(label, manifest, self)
+ # TODO: make this truly batch once we've fully transitioned to V2_2 and no longer need
+ # the mapping tables.
+ for label in labels_to_add:
+ with db_transaction():
+ # Create the label itself.
+ model.label.create_manifest_label(tag_manifest, **label)
- def list_manifest_labels(self, manifest, key_prefix=None):
- """ Returns all labels found on the manifest. If specified, the key_prefix will filter the
+ # Apply any changes to the manifest that the label prescribes.
+ apply_label_to_manifest(label, manifest, self)
+
+ def list_manifest_labels(self, manifest, key_prefix=None):
+ """ Returns all labels found on the manifest. If specified, the key_prefix will filter the
labels returned to those keys that start with the given prefix.
"""
- labels = model.label.list_manifest_labels(manifest._db_id, prefix_filter=key_prefix)
- return [Label.for_label(l) for l in labels]
+ labels = model.label.list_manifest_labels(
+ manifest._db_id, prefix_filter=key_prefix
+ )
+ return [Label.for_label(l) for l in labels]
- def get_manifest_label(self, manifest, label_uuid):
- """ Returns the label with the specified UUID on the manifest or None if none. """
- return Label.for_label(model.label.get_manifest_label(label_uuid, manifest._db_id))
+ def get_manifest_label(self, manifest, label_uuid):
+ """ Returns the label with the specified UUID on the manifest or None if none. """
+ return Label.for_label(
+ model.label.get_manifest_label(label_uuid, manifest._db_id)
+ )
- def delete_manifest_label(self, manifest, label_uuid):
- """ Delete the label with the specified UUID on the manifest. Returns the label deleted
+ def delete_manifest_label(self, manifest, label_uuid):
+ """ Delete the label with the specified UUID on the manifest. Returns the label deleted
or None if none.
"""
- return Label.for_label(model.label.delete_manifest_label(label_uuid, manifest._db_id))
+ return Label.for_label(
+ model.label.delete_manifest_label(label_uuid, manifest._db_id)
+ )
- def lookup_active_repository_tags(self, repository_ref, start_pagination_id, limit):
- """
+ def lookup_active_repository_tags(self, repository_ref, start_pagination_id, limit):
+ """
    Returns a page of active tags in a repository. Note that the tags returned by this method
are ShallowTag objects, which only contain the tag name.
"""
- tags = model.tag.list_active_repo_tags(repository_ref._db_id, include_images=False,
- start_id=start_pagination_id, limit=limit)
- return [ShallowTag.for_repository_tag(tag) for tag in tags]
+ tags = model.tag.list_active_repo_tags(
+ repository_ref._db_id,
+ include_images=False,
+ start_id=start_pagination_id,
+ limit=limit,
+ )
+ return [ShallowTag.for_repository_tag(tag) for tag in tags]
- def list_all_active_repository_tags(self, repository_ref, include_legacy_images=False):
- """
+ def list_all_active_repository_tags(
+ self, repository_ref, include_legacy_images=False
+ ):
+ """
Returns a list of all the active tags in the repository. Note that this is a *HEAVY*
operation on repositories with a lot of tags, and should only be used for testing or
where other more specific operations are not possible.
"""
- if not include_legacy_images:
- tags = model.tag.list_active_repo_tags(repository_ref._db_id, include_images=False)
- return [Tag.for_repository_tag(tag) for tag in tags]
+ if not include_legacy_images:
+ tags = model.tag.list_active_repo_tags(
+ repository_ref._db_id, include_images=False
+ )
+ return [Tag.for_repository_tag(tag) for tag in tags]
- tags = model.tag.list_active_repo_tags(repository_ref._db_id)
- return [Tag.for_repository_tag(tag,
- legacy_image=LegacyImage.for_image(tag.image),
- manifest_digest=(tag.tagmanifest.digest
- if hasattr(tag, 'tagmanifest')
- else None))
- for tag in tags]
+ tags = model.tag.list_active_repo_tags(repository_ref._db_id)
+ return [
+ Tag.for_repository_tag(
+ tag,
+ legacy_image=LegacyImage.for_image(tag.image),
+ manifest_digest=(
+ tag.tagmanifest.digest if hasattr(tag, "tagmanifest") else None
+ ),
+ )
+ for tag in tags
+ ]
- def list_repository_tag_history(self, repository_ref, page=1, size=100, specific_tag_name=None,
- active_tags_only=False, since_time_ms=None):
- """
+ def list_repository_tag_history(
+ self,
+ repository_ref,
+ page=1,
+ size=100,
+ specific_tag_name=None,
+ active_tags_only=False,
+ since_time_ms=None,
+ ):
+ """
Returns the history of all tags in the repository (unless filtered). This includes tags that
    have been made inactive due to newer versions of those tags coming into service.
"""
- # Only available on OCI model
- if since_time_ms is not None:
- raise NotImplementedError
+ # Only available on OCI model
+ if since_time_ms is not None:
+ raise NotImplementedError
- tags, manifest_map, has_more = model.tag.list_repository_tag_history(repository_ref._db_id,
- page, size,
- specific_tag_name,
- active_tags_only)
- return [Tag.for_repository_tag(tag, manifest_map.get(tag.id),
- legacy_image=LegacyImage.for_image(tag.image))
- for tag in tags], has_more
+ tags, manifest_map, has_more = model.tag.list_repository_tag_history(
+ repository_ref._db_id, page, size, specific_tag_name, active_tags_only
+ )
+ return (
+ [
+ Tag.for_repository_tag(
+ tag,
+ manifest_map.get(tag.id),
+ legacy_image=LegacyImage.for_image(tag.image),
+ )
+ for tag in tags
+ ],
+ has_more,
+ )
- def has_expired_tag(self, repository_ref, tag_name):
- """
+ def has_expired_tag(self, repository_ref, tag_name):
+ """
Returns true if and only if the repository contains a tag with the given name that is expired.
"""
- try:
- model.tag.get_expired_tag_in_repo(repository_ref._db_id, tag_name)
- return True
- except database.RepositoryTag.DoesNotExist:
- return False
+ try:
+ model.tag.get_expired_tag_in_repo(repository_ref._db_id, tag_name)
+ return True
+ except database.RepositoryTag.DoesNotExist:
+ return False
- def get_most_recent_tag_lifetime_start(self, repository_refs):
- """
+ def get_most_recent_tag_lifetime_start(self, repository_refs):
+ """
Returns a map from repository ID to the last modified time (in s) for each repository in the
given repository reference list.
"""
- if not repository_refs:
- return {}
+ if not repository_refs:
+ return {}
- tuples = (database.RepositoryTag.select(database.RepositoryTag.repository,
- fn.Max(database.RepositoryTag.lifetime_start_ts))
- .where(database.RepositoryTag.repository << [r.id for r in repository_refs])
- .group_by(database.RepositoryTag.repository)
- .tuples())
+ tuples = (
+ database.RepositoryTag.select(
+ database.RepositoryTag.repository,
+ fn.Max(database.RepositoryTag.lifetime_start_ts),
+ )
+ .where(database.RepositoryTag.repository << [r.id for r in repository_refs])
+ .group_by(database.RepositoryTag.repository)
+ .tuples()
+ )
- return {repo_id: seconds for repo_id, seconds in tuples}
+ return {repo_id: seconds for repo_id, seconds in tuples}
- def get_repo_tag(self, repository_ref, tag_name, include_legacy_image=False):
- """
+ def get_repo_tag(self, repository_ref, tag_name, include_legacy_image=False):
+ """
Returns the latest, *active* tag found in the repository, with the matching name
or None if none.
"""
- assert isinstance(tag_name, basestring)
- tag = model.tag.get_active_tag_for_repo(repository_ref._db_id, tag_name)
- if tag is None:
- return None
+ assert isinstance(tag_name, basestring)
+ tag = model.tag.get_active_tag_for_repo(repository_ref._db_id, tag_name)
+ if tag is None:
+ return None
- legacy_image = LegacyImage.for_image(tag.image) if include_legacy_image else None
- tag_manifest = model.tag.get_tag_manifest(tag)
- manifest_digest = tag_manifest.digest if tag_manifest else None
- return Tag.for_repository_tag(tag, legacy_image=legacy_image, manifest_digest=manifest_digest)
+ legacy_image = (
+ LegacyImage.for_image(tag.image) if include_legacy_image else None
+ )
+ tag_manifest = model.tag.get_tag_manifest(tag)
+ manifest_digest = tag_manifest.digest if tag_manifest else None
+ return Tag.for_repository_tag(
+ tag, legacy_image=legacy_image, manifest_digest=manifest_digest
+ )
- def retarget_tag(self, repository_ref, tag_name, manifest_or_legacy_image, storage,
- legacy_manifest_key, is_reversion=False):
- """
+ def retarget_tag(
+ self,
+ repository_ref,
+ tag_name,
+ manifest_or_legacy_image,
+ storage,
+ legacy_manifest_key,
+ is_reversion=False,
+ ):
+ """
Creates, updates or moves a tag to a new entry in history, pointing to the manifest or
legacy image specified. If is_reversion is set to True, this operation is considered a
reversion over a previous tag move operation. Returns the updated Tag or None on error.
"""
- # TODO: unify this.
- assert legacy_manifest_key is not None
- if not is_reversion:
- if isinstance(manifest_or_legacy_image, Manifest):
- raise NotImplementedError('Not yet implemented')
- else:
- model.tag.create_or_update_tag_for_repo(repository_ref._db_id, tag_name,
- manifest_or_legacy_image.docker_image_id)
- else:
- if isinstance(manifest_or_legacy_image, Manifest):
- model.tag.restore_tag_to_manifest(repository_ref._db_id, tag_name,
- manifest_or_legacy_image.digest)
- else:
- model.tag.restore_tag_to_image(repository_ref._db_id, tag_name,
- manifest_or_legacy_image.docker_image_id)
+ # TODO: unify this.
+ assert legacy_manifest_key is not None
+ if not is_reversion:
+ if isinstance(manifest_or_legacy_image, Manifest):
+ raise NotImplementedError("Not yet implemented")
+ else:
+ model.tag.create_or_update_tag_for_repo(
+ repository_ref._db_id,
+ tag_name,
+ manifest_or_legacy_image.docker_image_id,
+ )
+ else:
+ if isinstance(manifest_or_legacy_image, Manifest):
+ model.tag.restore_tag_to_manifest(
+ repository_ref._db_id, tag_name, manifest_or_legacy_image.digest
+ )
+ else:
+ model.tag.restore_tag_to_image(
+ repository_ref._db_id,
+ tag_name,
+ manifest_or_legacy_image.docker_image_id,
+ )
- # Generate a manifest for the tag, if necessary.
- tag = self.get_repo_tag(repository_ref, tag_name, include_legacy_image=True)
- if tag is None:
- return None
+ # Generate a manifest for the tag, if necessary.
+ tag = self.get_repo_tag(repository_ref, tag_name, include_legacy_image=True)
+ if tag is None:
+ return None
- self.backfill_manifest_for_tag(tag)
- return tag
+ self.backfill_manifest_for_tag(tag)
+ return tag
- def delete_tag(self, repository_ref, tag_name):
- """
+ def delete_tag(self, repository_ref, tag_name):
+ """
Deletes the latest, *active* tag with the given name in the repository.
"""
- repo = model.repository.lookup_repository(repository_ref._db_id)
- if repo is None:
- return None
+ repo = model.repository.lookup_repository(repository_ref._db_id)
+ if repo is None:
+ return None
- deleted_tag = model.tag.delete_tag(repo.namespace_user.username, repo.name, tag_name)
- return Tag.for_repository_tag(deleted_tag)
+ deleted_tag = model.tag.delete_tag(
+ repo.namespace_user.username, repo.name, tag_name
+ )
+ return Tag.for_repository_tag(deleted_tag)
- def delete_tags_for_manifest(self, manifest):
- """
+ def delete_tags_for_manifest(self, manifest):
+ """
Deletes all tags pointing to the given manifest, making the manifest inaccessible for pulling.
Returns the tags deleted, if any. Returns None on error.
"""
- try:
- tagmanifest = database.TagManifest.get(id=manifest._db_id)
- except database.TagManifest.DoesNotExist:
- return None
+ try:
+ tagmanifest = database.TagManifest.get(id=manifest._db_id)
+ except database.TagManifest.DoesNotExist:
+ return None
- namespace_name = tagmanifest.tag.repository.namespace_user.username
- repo_name = tagmanifest.tag.repository.name
- tags = model.tag.delete_manifest_by_digest(namespace_name, repo_name, manifest.digest)
- return [Tag.for_repository_tag(tag) for tag in tags]
+ namespace_name = tagmanifest.tag.repository.namespace_user.username
+ repo_name = tagmanifest.tag.repository.name
+ tags = model.tag.delete_manifest_by_digest(
+ namespace_name, repo_name, manifest.digest
+ )
+ return [Tag.for_repository_tag(tag) for tag in tags]
- def change_repository_tag_expiration(self, tag, expiration_date):
- """ Sets the expiration date of the tag under the matching repository to that given. If the
+ def change_repository_tag_expiration(self, tag, expiration_date):
+ """ Sets the expiration date of the tag under the matching repository to that given. If the
expiration date is None, then the tag will not expire. Returns a tuple of the previous
expiration timestamp in seconds (if any), and whether the operation succeeded.
"""
- try:
- tag_obj = database.RepositoryTag.get(id=tag._db_id)
- except database.RepositoryTag.DoesNotExist:
- return (None, False)
+ try:
+ tag_obj = database.RepositoryTag.get(id=tag._db_id)
+ except database.RepositoryTag.DoesNotExist:
+ return (None, False)
- return model.tag.change_tag_expiration(tag_obj, expiration_date)
+ return model.tag.change_tag_expiration(tag_obj, expiration_date)
- def get_legacy_images_owned_by_tag(self, tag):
- """ Returns all legacy images *solely owned and used* by the given tag. """
- try:
- tag_obj = database.RepositoryTag.get(id=tag._db_id)
- except database.RepositoryTag.DoesNotExist:
- return None
+ def get_legacy_images_owned_by_tag(self, tag):
+ """ Returns all legacy images *solely owned and used* by the given tag. """
+ try:
+ tag_obj = database.RepositoryTag.get(id=tag._db_id)
+ except database.RepositoryTag.DoesNotExist:
+ return None
- # Collect the IDs of all images that the tag uses.
- tag_image_ids = set()
- tag_image_ids.add(tag_obj.image.id)
- tag_image_ids.update(tag_obj.image.ancestor_id_list())
+ # Collect the IDs of all images that the tag uses.
+ tag_image_ids = set()
+ tag_image_ids.add(tag_obj.image.id)
+ tag_image_ids.update(tag_obj.image.ancestor_id_list())
- # Remove any images shared by other tags.
- for current_tag in model.tag.list_active_repo_tags(tag_obj.repository_id):
- if current_tag == tag_obj:
- continue
+ # Remove any images shared by other tags.
+ for current_tag in model.tag.list_active_repo_tags(tag_obj.repository_id):
+ if current_tag == tag_obj:
+ continue
- tag_image_ids.discard(current_tag.image.id)
- tag_image_ids = tag_image_ids.difference(current_tag.image.ancestor_id_list())
- if not tag_image_ids:
- return []
+ tag_image_ids.discard(current_tag.image.id)
+ tag_image_ids = tag_image_ids.difference(
+ current_tag.image.ancestor_id_list()
+ )
+ if not tag_image_ids:
+ return []
- if not tag_image_ids:
- return []
+ if not tag_image_ids:
+ return []
- # Load the images we need to return.
- images = database.Image.select().where(database.Image.id << list(tag_image_ids))
- all_image_ids = set()
- for image in images:
- all_image_ids.add(image.id)
- all_image_ids.update(image.ancestor_id_list())
+ # Load the images we need to return.
+ images = database.Image.select().where(database.Image.id << list(tag_image_ids))
+ all_image_ids = set()
+ for image in images:
+ all_image_ids.add(image.id)
+ all_image_ids.update(image.ancestor_id_list())
- # Build a map of all the images and their parents.
- images_map = {}
- all_images = database.Image.select().where(database.Image.id << list(all_image_ids))
- for image in all_images:
- images_map[image.id] = image
+ # Build a map of all the images and their parents.
+ images_map = {}
+ all_images = database.Image.select().where(
+ database.Image.id << list(all_image_ids)
+ )
+ for image in all_images:
+ images_map[image.id] = image
- return [LegacyImage.for_image(image, images_map=images_map) for image in images]
+ return [LegacyImage.for_image(image, images_map=images_map) for image in images]
- def get_security_status(self, manifest_or_legacy_image):
- """ Returns the security status for the given manifest or legacy image or None if none. """
- image = None
+ def get_security_status(self, manifest_or_legacy_image):
+ """ Returns the security status for the given manifest or legacy image or None if none. """
+ image = None
- if isinstance(manifest_or_legacy_image, Manifest):
- try:
- tag_manifest = database.TagManifest.get(id=manifest_or_legacy_image._db_id)
- image = tag_manifest.tag.image
- except database.TagManifest.DoesNotExist:
- return None
- else:
- try:
- image = database.Image.get(id=manifest_or_legacy_image._db_id)
- except database.Image.DoesNotExist:
- return None
+ if isinstance(manifest_or_legacy_image, Manifest):
+ try:
+ tag_manifest = database.TagManifest.get(
+ id=manifest_or_legacy_image._db_id
+ )
+ image = tag_manifest.tag.image
+ except database.TagManifest.DoesNotExist:
+ return None
+ else:
+ try:
+ image = database.Image.get(id=manifest_or_legacy_image._db_id)
+ except database.Image.DoesNotExist:
+ return None
- if image.security_indexed_engine is not None and image.security_indexed_engine >= 0:
- return SecurityScanStatus.SCANNED if image.security_indexed else SecurityScanStatus.FAILED
+ if (
+ image.security_indexed_engine is not None
+ and image.security_indexed_engine >= 0
+ ):
+ return (
+ SecurityScanStatus.SCANNED
+ if image.security_indexed
+ else SecurityScanStatus.FAILED
+ )
- return SecurityScanStatus.QUEUED
+ return SecurityScanStatus.QUEUED
- def reset_security_status(self, manifest_or_legacy_image):
- """ Resets the security status for the given manifest or legacy image, ensuring that it will
+ def reset_security_status(self, manifest_or_legacy_image):
+ """ Resets the security status for the given manifest or legacy image, ensuring that it will
get re-indexed.
"""
- image = None
+ image = None
- if isinstance(manifest_or_legacy_image, Manifest):
- try:
- tag_manifest = database.TagManifest.get(id=manifest_or_legacy_image._db_id)
- image = tag_manifest.tag.image
- except database.TagManifest.DoesNotExist:
- return None
- else:
- try:
- image = database.Image.get(id=manifest_or_legacy_image._db_id)
- except database.Image.DoesNotExist:
- return None
+ if isinstance(manifest_or_legacy_image, Manifest):
+ try:
+ tag_manifest = database.TagManifest.get(
+ id=manifest_or_legacy_image._db_id
+ )
+ image = tag_manifest.tag.image
+ except database.TagManifest.DoesNotExist:
+ return None
+ else:
+ try:
+ image = database.Image.get(id=manifest_or_legacy_image._db_id)
+ except database.Image.DoesNotExist:
+ return None
- assert image
- image.security_indexed = False
- image.security_indexed_engine = IMAGE_NOT_SCANNED_ENGINE_VERSION
- image.save()
+ assert image
+ image.security_indexed = False
+ image.security_indexed_engine = IMAGE_NOT_SCANNED_ENGINE_VERSION
+ image.save()
- def backfill_manifest_for_tag(self, tag):
- """ Backfills a manifest for the V1 tag specified.
+ def backfill_manifest_for_tag(self, tag):
+ """ Backfills a manifest for the V1 tag specified.
If a manifest already exists for the tag, returns that manifest.
NOTE: This method will only be necessary until we've completed the backfill, at which point
it should be removed.
"""
- # Ensure that there isn't already a manifest for the tag.
- tag_manifest = model.tag.get_tag_manifest(tag._db_id)
- if tag_manifest is not None:
- return Manifest.for_tag_manifest(tag_manifest)
+ # Ensure that there isn't already a manifest for the tag.
+ tag_manifest = model.tag.get_tag_manifest(tag._db_id)
+ if tag_manifest is not None:
+ return Manifest.for_tag_manifest(tag_manifest)
- # Create the manifest.
- try:
- tag_obj = database.RepositoryTag.get(id=tag._db_id)
- except database.RepositoryTag.DoesNotExist:
- return None
+ # Create the manifest.
+ try:
+ tag_obj = database.RepositoryTag.get(id=tag._db_id)
+ except database.RepositoryTag.DoesNotExist:
+ return None
- assert not tag_obj.hidden
+ assert not tag_obj.hidden
- repo = tag_obj.repository
+ repo = tag_obj.repository
- # Write the manifest to the DB.
- manifest = self._build_manifest_for_legacy_image(tag_obj.name, tag_obj.image)
- if manifest is None:
- return None
+ # Write the manifest to the DB.
+ manifest = self._build_manifest_for_legacy_image(tag_obj.name, tag_obj.image)
+ if manifest is None:
+ return None
- blob_query = self._lookup_repo_storages_by_content_checksum(repo, manifest.checksums)
- storage_map = {blob.content_checksum: blob.id for blob in blob_query}
- try:
- tag_manifest = model.tag.associate_generated_tag_manifest_with_tag(tag_obj, manifest,
- storage_map)
- assert tag_manifest
- except IntegrityError:
- tag_manifest = model.tag.get_tag_manifest(tag_obj)
+ blob_query = self._lookup_repo_storages_by_content_checksum(
+ repo, manifest.checksums
+ )
+ storage_map = {blob.content_checksum: blob.id for blob in blob_query}
+ try:
+ tag_manifest = model.tag.associate_generated_tag_manifest_with_tag(
+ tag_obj, manifest, storage_map
+ )
+ assert tag_manifest
+ except IntegrityError:
+ tag_manifest = model.tag.get_tag_manifest(tag_obj)
- return Manifest.for_tag_manifest(tag_manifest)
+ return Manifest.for_tag_manifest(tag_manifest)
- def list_manifest_layers(self, manifest, storage, include_placements=False):
- try:
- tag_manifest = database.TagManifest.get(id=manifest._db_id)
- except database.TagManifest.DoesNotExist:
- logger.exception('Could not find tag manifest for manifest `%s`', manifest._db_id)
- return None
+ def list_manifest_layers(self, manifest, storage, include_placements=False):
+ try:
+ tag_manifest = database.TagManifest.get(id=manifest._db_id)
+ except database.TagManifest.DoesNotExist:
+ logger.exception(
+ "Could not find tag manifest for manifest `%s`", manifest._db_id
+ )
+ return None
- try:
- parsed = manifest.get_parsed_manifest()
- except ManifestException:
- logger.exception('Could not parse and validate manifest `%s`', manifest._db_id)
- return None
+ try:
+ parsed = manifest.get_parsed_manifest()
+ except ManifestException:
+ logger.exception(
+ "Could not parse and validate manifest `%s`", manifest._db_id
+ )
+ return None
- repo_ref = RepositoryReference.for_id(tag_manifest.tag.repository_id)
- return self.list_parsed_manifest_layers(repo_ref, parsed, storage, include_placements)
+ repo_ref = RepositoryReference.for_id(tag_manifest.tag.repository_id)
+ return self.list_parsed_manifest_layers(
+ repo_ref, parsed, storage, include_placements
+ )
- def lookup_derived_image(self, manifest, verb, storage, varying_metadata=None,
- include_placements=False):
- """
+ def lookup_derived_image(
+ self, manifest, verb, storage, varying_metadata=None, include_placements=False
+ ):
+ """
Looks up the derived image for the given manifest, verb and optional varying metadata and
returns it or None if none.
"""
- try:
- tag_manifest = database.TagManifest.get(id=manifest._db_id)
- except database.TagManifest.DoesNotExist:
- logger.exception('Could not find tag manifest for manifest `%s`', manifest._db_id)
- return None
+ try:
+ tag_manifest = database.TagManifest.get(id=manifest._db_id)
+ except database.TagManifest.DoesNotExist:
+ logger.exception(
+ "Could not find tag manifest for manifest `%s`", manifest._db_id
+ )
+ return None
- repo_image = tag_manifest.tag.image
- derived = model.image.find_derived_storage_for_image(repo_image, verb, varying_metadata)
- return self._build_derived(derived, verb, varying_metadata, include_placements)
+ repo_image = tag_manifest.tag.image
+ derived = model.image.find_derived_storage_for_image(
+ repo_image, verb, varying_metadata
+ )
+ return self._build_derived(derived, verb, varying_metadata, include_placements)
- def lookup_or_create_derived_image(self, manifest, verb, storage_location, storage,
- varying_metadata=None, include_placements=False):
- """
+ def lookup_or_create_derived_image(
+ self,
+ manifest,
+ verb,
+ storage_location,
+ storage,
+ varying_metadata=None,
+ include_placements=False,
+ ):
+ """
    Looks up the derived image for the given manifest, verb and optional varying metadata
and returns it. If none exists, a new derived image is created.
"""
- try:
- tag_manifest = database.TagManifest.get(id=manifest._db_id)
- except database.TagManifest.DoesNotExist:
- logger.exception('Could not find tag manifest for manifest `%s`', manifest._db_id)
- return None
+ try:
+ tag_manifest = database.TagManifest.get(id=manifest._db_id)
+ except database.TagManifest.DoesNotExist:
+ logger.exception(
+ "Could not find tag manifest for manifest `%s`", manifest._db_id
+ )
+ return None
- repo_image = tag_manifest.tag.image
- derived = model.image.find_or_create_derived_storage(repo_image, verb, storage_location,
- varying_metadata)
- return self._build_derived(derived, verb, varying_metadata, include_placements)
+ repo_image = tag_manifest.tag.image
+ derived = model.image.find_or_create_derived_storage(
+ repo_image, verb, storage_location, varying_metadata
+ )
+ return self._build_derived(derived, verb, varying_metadata, include_placements)
- def set_tags_expiration_for_manifest(self, manifest, expiration_sec):
- """
+ def set_tags_expiration_for_manifest(self, manifest, expiration_sec):
+ """
Sets the expiration on all tags that point to the given manifest to that specified.
"""
- try:
- tag_manifest = database.TagManifest.get(id=manifest._db_id)
- except database.TagManifest.DoesNotExist:
- return
+ try:
+ tag_manifest = database.TagManifest.get(id=manifest._db_id)
+ except database.TagManifest.DoesNotExist:
+ return
- model.tag.set_tag_expiration_for_manifest(tag_manifest, expiration_sec)
+ model.tag.set_tag_expiration_for_manifest(tag_manifest, expiration_sec)
- def get_schema1_parsed_manifest(self, manifest, namespace_name, repo_name, tag_name, storage):
- """ Returns the schema 1 version of this manifest, or None if none. """
- try:
- return manifest.get_parsed_manifest()
- except ManifestException:
- return None
+ def get_schema1_parsed_manifest(
+ self, manifest, namespace_name, repo_name, tag_name, storage
+ ):
+ """ Returns the schema 1 version of this manifest, or None if none. """
+ try:
+ return manifest.get_parsed_manifest()
+ except ManifestException:
+ return None
- def convert_manifest(self, manifest, namespace_name, repo_name, tag_name, allowed_mediatypes,
- storage):
- try:
- parsed = manifest.get_parsed_manifest()
- except ManifestException:
- return None
+ def convert_manifest(
+ self, manifest, namespace_name, repo_name, tag_name, allowed_mediatypes, storage
+ ):
+ try:
+ parsed = manifest.get_parsed_manifest()
+ except ManifestException:
+ return None
- try:
- return parsed.convert_manifest(allowed_mediatypes, namespace_name, repo_name, tag_name, None)
- except ManifestException:
- return None
+ try:
+ return parsed.convert_manifest(
+ allowed_mediatypes, namespace_name, repo_name, tag_name, None
+ )
+ except ManifestException:
+ return None
- def create_manifest_with_temp_tag(self, repository_ref, manifest_interface_instance,
- expiration_sec, storage):
- """ Creates a manifest under the repository and sets a temporary tag to point to it.
+ def create_manifest_with_temp_tag(
+ self, repository_ref, manifest_interface_instance, expiration_sec, storage
+ ):
+ """ Creates a manifest under the repository and sets a temporary tag to point to it.
Returns the manifest object created or None on error.
"""
- raise NotImplementedError('Unsupported in pre OCI model')
+ raise NotImplementedError("Unsupported in pre OCI model")
- def get_repo_blob_by_digest(self, repository_ref, blob_digest, include_placements=False):
- """
+ def get_repo_blob_by_digest(
+ self, repository_ref, blob_digest, include_placements=False
+ ):
+ """
Returns the blob in the repository with the given digest, if any or None if none. Note that
there may be multiple records in the same repository for the same blob digest, so the return
value of this function may change.
"""
- image_storage = self._get_shared_storage(blob_digest)
- if image_storage is None:
- try:
- image_storage = model.blob.get_repository_blob_by_digest(repository_ref._db_id, blob_digest)
- except model.BlobDoesNotExist:
- return None
+ image_storage = self._get_shared_storage(blob_digest)
+ if image_storage is None:
+ try:
+ image_storage = model.blob.get_repository_blob_by_digest(
+ repository_ref._db_id, blob_digest
+ )
+ except model.BlobDoesNotExist:
+ return None
- assert image_storage.cas_path is not None
+ assert image_storage.cas_path is not None
- placements = None
- if include_placements:
- placements = list(model.storage.get_storage_locations(image_storage.uuid))
+ placements = None
+ if include_placements:
+ placements = list(model.storage.get_storage_locations(image_storage.uuid))
- return Blob.for_image_storage(image_storage,
- storage_path=model.storage.get_layer_path(image_storage),
- placements=placements)
+ return Blob.for_image_storage(
+ image_storage,
+ storage_path=model.storage.get_layer_path(image_storage),
+ placements=placements,
+ )
- def list_parsed_manifest_layers(self, repository_ref, parsed_manifest, storage,
- include_placements=False):
- """ Returns an *ordered list* of the layers found in the parsed manifest, starting at the base
+ def list_parsed_manifest_layers(
+ self, repository_ref, parsed_manifest, storage, include_placements=False
+ ):
+ """ Returns an *ordered list* of the layers found in the parsed manifest, starting at the base
and working towards the leaf, including the associated Blob and its placements
(if specified).
"""
- return self._list_manifest_layers(repository_ref._db_id, parsed_manifest, storage,
- include_placements=include_placements)
+ return self._list_manifest_layers(
+ repository_ref._db_id,
+ parsed_manifest,
+ storage,
+ include_placements=include_placements,
+ )
- def get_manifest_local_blobs(self, manifest, include_placements=False):
- """ Returns the set of local blobs for the given manifest or None if none. """
- try:
- tag_manifest = database.TagManifest.get(id=manifest._db_id)
- except database.TagManifest.DoesNotExist:
- return None
+ def get_manifest_local_blobs(self, manifest, include_placements=False):
+ """ Returns the set of local blobs for the given manifest or None if none. """
+ try:
+ tag_manifest = database.TagManifest.get(id=manifest._db_id)
+ except database.TagManifest.DoesNotExist:
+ return None
- return self._get_manifest_local_blobs(manifest, tag_manifest.tag.repository_id,
- include_placements)
+ return self._get_manifest_local_blobs(
+ manifest, tag_manifest.tag.repository_id, include_placements
+ )
- def yield_tags_for_vulnerability_notification(self, layer_id_pairs):
- """ Yields tags that contain one (or more) of the given layer ID pairs, in repositories
+ def yield_tags_for_vulnerability_notification(self, layer_id_pairs):
+ """ Yields tags that contain one (or more) of the given layer ID pairs, in repositories
which have been registered for vulnerability_found notifications. Returns an iterator
of LikelyVulnerableTag instances.
"""
- event = database.ExternalNotificationEvent.get(name='vulnerability_found')
+ event = database.ExternalNotificationEvent.get(name="vulnerability_found")
- def filter_notifying_repos(query):
- return model.tag.filter_has_repository_event(query, event)
+ def filter_notifying_repos(query):
+ return model.tag.filter_has_repository_event(query, event)
- def filter_and_order(query):
- return model.tag.filter_tags_have_repository_event(query, event)
+ def filter_and_order(query):
+ return model.tag.filter_tags_have_repository_event(query, event)
- # Find the matching tags.
- tags = model.tag.get_matching_tags_for_images(layer_id_pairs,
- selections=[database.RepositoryTag,
- database.Image,
- database.ImageStorage],
- filter_images=filter_notifying_repos,
- filter_tags=filter_and_order)
- for tag in tags:
- yield LikelyVulnerableTag.for_repository_tag(tag, tag.repository)
+ # Find the matching tags.
+ tags = model.tag.get_matching_tags_for_images(
+ layer_id_pairs,
+ selections=[database.RepositoryTag, database.Image, database.ImageStorage],
+ filter_images=filter_notifying_repos,
+ filter_tags=filter_and_order,
+ )
+ for tag in tags:
+ yield LikelyVulnerableTag.for_repository_tag(tag, tag.repository)
pre_oci_model = PreOCIModel()
diff --git a/data/registry_model/shared.py b/data/registry_model/shared.py
index 82a01aa67..51df3e22e 100644
--- a/data/registry_model/shared.py
+++ b/data/registry_model/shared.py
@@ -10,500 +10,626 @@ from data.cache import cache_key
from data.model.oci.retriever import RepositoryContentRetriever
from data.model.blob import get_shared_blob
from data.registry_model.datatype import FromDictionaryException
-from data.registry_model.datatypes import (RepositoryReference, Blob, TorrentInfo, BlobUpload,
- LegacyImage, ManifestLayer, DerivedImage, ShallowTag)
+from data.registry_model.datatypes import (
+ RepositoryReference,
+ Blob,
+ TorrentInfo,
+ BlobUpload,
+ LegacyImage,
+ ManifestLayer,
+ DerivedImage,
+ ShallowTag,
+)
from image.docker.schema1 import ManifestException, DockerSchema1ManifestBuilder
from image.docker.schema2 import EMPTY_LAYER_BLOB_DIGEST
logger = logging.getLogger(__name__)
# The maximum size for generated manifest after which we remove extra metadata.
-MAXIMUM_GENERATED_MANIFEST_SIZE = 3 * 1024 * 1024 # 3 MB
+MAXIMUM_GENERATED_MANIFEST_SIZE = 3 * 1024 * 1024 # 3 MB
+
class SharedModel:
- """
+ """
SharedModel implements those data model operations for the registry API that are unchanged
between the old and new data models.
"""
- def lookup_repository(self, namespace_name, repo_name, kind_filter=None):
- """ Looks up and returns a reference to the repository with the given namespace and name,
+
+ def lookup_repository(self, namespace_name, repo_name, kind_filter=None):
+ """ Looks up and returns a reference to the repository with the given namespace and name,
or None if none. """
- repo = model.repository.get_repository(namespace_name, repo_name, kind_filter=kind_filter)
- state = repo.state if repo is not None else None
- return RepositoryReference.for_repo_obj(repo, namespace_name, repo_name,
- repo.namespace_user.stripe_id is None if repo else None,
- state=state)
+ repo = model.repository.get_repository(
+ namespace_name, repo_name, kind_filter=kind_filter
+ )
+ state = repo.state if repo is not None else None
+ return RepositoryReference.for_repo_obj(
+ repo,
+ namespace_name,
+ repo_name,
+ repo.namespace_user.stripe_id is None if repo else None,
+ state=state,
+ )
- def is_existing_disabled_namespace(self, namespace_name):
- """ Returns whether the given namespace exists and is disabled. """
- namespace = model.user.get_namespace_user(namespace_name)
- return namespace is not None and not namespace.enabled
+ def is_existing_disabled_namespace(self, namespace_name):
+ """ Returns whether the given namespace exists and is disabled. """
+ namespace = model.user.get_namespace_user(namespace_name)
+ return namespace is not None and not namespace.enabled
- def is_namespace_enabled(self, namespace_name):
- """ Returns whether the given namespace exists and is enabled. """
- namespace = model.user.get_namespace_user(namespace_name)
- return namespace is not None and namespace.enabled
+ def is_namespace_enabled(self, namespace_name):
+ """ Returns whether the given namespace exists and is enabled. """
+ namespace = model.user.get_namespace_user(namespace_name)
+ return namespace is not None and namespace.enabled
- def get_derived_image_signature(self, derived_image, signer_name):
- """
+ def get_derived_image_signature(self, derived_image, signer_name):
+ """
Returns the signature associated with the derived image and a specific signer or None if none.
"""
- try:
- derived_storage = database.DerivedStorageForImage.get(id=derived_image._db_id)
- except database.DerivedStorageForImage.DoesNotExist:
- return None
+ try:
+ derived_storage = database.DerivedStorageForImage.get(
+ id=derived_image._db_id
+ )
+ except database.DerivedStorageForImage.DoesNotExist:
+ return None
- storage = derived_storage.derivative
- signature_entry = model.storage.lookup_storage_signature(storage, signer_name)
- if signature_entry is None:
- return None
+ storage = derived_storage.derivative
+ signature_entry = model.storage.lookup_storage_signature(storage, signer_name)
+ if signature_entry is None:
+ return None
- return signature_entry.signature
+ return signature_entry.signature
- def set_derived_image_signature(self, derived_image, signer_name, signature):
- """
+ def set_derived_image_signature(self, derived_image, signer_name, signature):
+ """
Sets the calculated signature for the given derived image and signer to that specified.
"""
- try:
- derived_storage = database.DerivedStorageForImage.get(id=derived_image._db_id)
- except database.DerivedStorageForImage.DoesNotExist:
- return None
+ try:
+ derived_storage = database.DerivedStorageForImage.get(
+ id=derived_image._db_id
+ )
+ except database.DerivedStorageForImage.DoesNotExist:
+ return None
- storage = derived_storage.derivative
- signature_entry = model.storage.find_or_create_storage_signature(storage, signer_name)
- signature_entry.signature = signature
- signature_entry.uploading = False
- signature_entry.save()
+ storage = derived_storage.derivative
+ signature_entry = model.storage.find_or_create_storage_signature(
+ storage, signer_name
+ )
+ signature_entry.signature = signature
+ signature_entry.uploading = False
+ signature_entry.save()
- def delete_derived_image(self, derived_image):
- """
+ def delete_derived_image(self, derived_image):
+ """
Deletes a derived image and all of its storage.
"""
- try:
- derived_storage = database.DerivedStorageForImage.get(id=derived_image._db_id)
- except database.DerivedStorageForImage.DoesNotExist:
- return None
+ try:
+ derived_storage = database.DerivedStorageForImage.get(
+ id=derived_image._db_id
+ )
+ except database.DerivedStorageForImage.DoesNotExist:
+ return None
- model.image.delete_derived_storage(derived_storage)
+ model.image.delete_derived_storage(derived_storage)
- def set_derived_image_size(self, derived_image, compressed_size):
- """
+ def set_derived_image_size(self, derived_image, compressed_size):
+ """
Sets the compressed size on the given derived image.
"""
- try:
- derived_storage = database.DerivedStorageForImage.get(id=derived_image._db_id)
- except database.DerivedStorageForImage.DoesNotExist:
- return None
+ try:
+ derived_storage = database.DerivedStorageForImage.get(
+ id=derived_image._db_id
+ )
+ except database.DerivedStorageForImage.DoesNotExist:
+ return None
- storage_entry = derived_storage.derivative
- storage_entry.image_size = compressed_size
- storage_entry.uploading = False
- storage_entry.save()
+ storage_entry = derived_storage.derivative
+ storage_entry.image_size = compressed_size
+ storage_entry.uploading = False
+ storage_entry.save()
- def get_torrent_info(self, blob):
- """
+ def get_torrent_info(self, blob):
+ """
Returns the torrent information associated with the given blob or None if none.
"""
- try:
- image_storage = database.ImageStorage.get(id=blob._db_id)
- except database.ImageStorage.DoesNotExist:
- return None
+ try:
+ image_storage = database.ImageStorage.get(id=blob._db_id)
+ except database.ImageStorage.DoesNotExist:
+ return None
- try:
- torrent_info = model.storage.get_torrent_info(image_storage)
- except model.TorrentInfoDoesNotExist:
- return None
+ try:
+ torrent_info = model.storage.get_torrent_info(image_storage)
+ except model.TorrentInfoDoesNotExist:
+ return None
- return TorrentInfo.for_torrent_info(torrent_info)
+ return TorrentInfo.for_torrent_info(torrent_info)
- def set_torrent_info(self, blob, piece_length, pieces):
- """
+ def set_torrent_info(self, blob, piece_length, pieces):
+ """
    Sets the torrent information associated with the given blob to that specified.
"""
- try:
- image_storage = database.ImageStorage.get(id=blob._db_id)
- except database.ImageStorage.DoesNotExist:
- return None
+ try:
+ image_storage = database.ImageStorage.get(id=blob._db_id)
+ except database.ImageStorage.DoesNotExist:
+ return None
- torrent_info = model.storage.save_torrent_info(image_storage, piece_length, pieces)
- return TorrentInfo.for_torrent_info(torrent_info)
+ torrent_info = model.storage.save_torrent_info(
+ image_storage, piece_length, pieces
+ )
+ return TorrentInfo.for_torrent_info(torrent_info)
- @abstractmethod
- def lookup_active_repository_tags(self, repository_ref, start_pagination_id, limit):
- pass
+ @abstractmethod
+ def lookup_active_repository_tags(self, repository_ref, start_pagination_id, limit):
+ pass
- def lookup_cached_active_repository_tags(self, model_cache, repository_ref, start_pagination_id,
- limit):
- """
+ def lookup_cached_active_repository_tags(
+ self, model_cache, repository_ref, start_pagination_id, limit
+ ):
+ """
Returns a page of active tags in a repository. Note that the tags returned by this method
are ShallowTag objects, which only contain the tag name. This method will automatically cache
the result and check the cache before making a call.
"""
- def load_tags():
- tags = self.lookup_active_repository_tags(repository_ref, start_pagination_id, limit)
- return [tag.asdict() for tag in tags]
- tags_cache_key = cache_key.for_active_repo_tags(repository_ref._db_id, start_pagination_id,
- limit)
- result = model_cache.retrieve(tags_cache_key, load_tags)
+ def load_tags():
+ tags = self.lookup_active_repository_tags(
+ repository_ref, start_pagination_id, limit
+ )
+ return [tag.asdict() for tag in tags]
- try:
- return [ShallowTag.from_dict(tag_dict) for tag_dict in result]
- except FromDictionaryException:
- return self.lookup_active_repository_tags(repository_ref, start_pagination_id, limit)
+ tags_cache_key = cache_key.for_active_repo_tags(
+ repository_ref._db_id, start_pagination_id, limit
+ )
+ result = model_cache.retrieve(tags_cache_key, load_tags)
- def get_cached_namespace_region_blacklist(self, model_cache, namespace_name):
- """ Returns a cached set of ISO country codes blacklisted for pulls for the namespace
+ try:
+ return [ShallowTag.from_dict(tag_dict) for tag_dict in result]
+ except FromDictionaryException:
+ return self.lookup_active_repository_tags(
+ repository_ref, start_pagination_id, limit
+ )
+
+ def get_cached_namespace_region_blacklist(self, model_cache, namespace_name):
+ """ Returns a cached set of ISO country codes blacklisted for pulls for the namespace
or None if the list could not be loaded.
"""
- def load_blacklist():
- restrictions = model.user.list_namespace_geo_restrictions(namespace_name)
- if restrictions is None:
- return None
+ def load_blacklist():
+ restrictions = model.user.list_namespace_geo_restrictions(namespace_name)
+ if restrictions is None:
+ return None
- return [restriction.restricted_region_iso_code for restriction in restrictions]
+ return [
+ restriction.restricted_region_iso_code for restriction in restrictions
+ ]
- blacklist_cache_key = cache_key.for_namespace_geo_restrictions(namespace_name)
- result = model_cache.retrieve(blacklist_cache_key, load_blacklist)
- if result is None:
- return None
+ blacklist_cache_key = cache_key.for_namespace_geo_restrictions(namespace_name)
+ result = model_cache.retrieve(blacklist_cache_key, load_blacklist)
+ if result is None:
+ return None
- return set(result)
+ return set(result)
- def get_cached_repo_blob(self, model_cache, namespace_name, repo_name, blob_digest):
- """
+ def get_cached_repo_blob(self, model_cache, namespace_name, repo_name, blob_digest):
+ """
Returns the blob in the repository with the given digest if any or None if none.
Caches the result in the caching system.
"""
- def load_blob():
- repository_ref = self.lookup_repository(namespace_name, repo_name)
- if repository_ref is None:
- return None
- blob_found = self.get_repo_blob_by_digest(repository_ref, blob_digest,
- include_placements=True)
- if blob_found is None:
- return None
+ def load_blob():
+ repository_ref = self.lookup_repository(namespace_name, repo_name)
+ if repository_ref is None:
+ return None
- return blob_found.asdict()
+ blob_found = self.get_repo_blob_by_digest(
+ repository_ref, blob_digest, include_placements=True
+ )
+ if blob_found is None:
+ return None
- blob_cache_key = cache_key.for_repository_blob(namespace_name, repo_name, blob_digest, 2)
- blob_dict = model_cache.retrieve(blob_cache_key, load_blob)
+ return blob_found.asdict()
- try:
- return Blob.from_dict(blob_dict) if blob_dict is not None else None
- except FromDictionaryException:
- # The data was stale in some way. Simply reload.
- repository_ref = self.lookup_repository(namespace_name, repo_name)
- if repository_ref is None:
- return None
+ blob_cache_key = cache_key.for_repository_blob(
+ namespace_name, repo_name, blob_digest, 2
+ )
+ blob_dict = model_cache.retrieve(blob_cache_key, load_blob)
- return self.get_repo_blob_by_digest(repository_ref, blob_digest, include_placements=True)
+ try:
+ return Blob.from_dict(blob_dict) if blob_dict is not None else None
+ except FromDictionaryException:
+ # The data was stale in some way. Simply reload.
+ repository_ref = self.lookup_repository(namespace_name, repo_name)
+ if repository_ref is None:
+ return None
- @abstractmethod
- def get_repo_blob_by_digest(self, repository_ref, blob_digest, include_placements=False):
- pass
+ return self.get_repo_blob_by_digest(
+ repository_ref, blob_digest, include_placements=True
+ )
- def create_blob_upload(self, repository_ref, new_upload_id, location_name, storage_metadata):
- """ Creates a new blob upload and returns a reference. If the blob upload could not be
+ @abstractmethod
+ def get_repo_blob_by_digest(
+ self, repository_ref, blob_digest, include_placements=False
+ ):
+ pass
+
+ def create_blob_upload(
+ self, repository_ref, new_upload_id, location_name, storage_metadata
+ ):
+ """ Creates a new blob upload and returns a reference. If the blob upload could not be
created, returns None. """
- repo = model.repository.lookup_repository(repository_ref._db_id)
- if repo is None:
- return None
+ repo = model.repository.lookup_repository(repository_ref._db_id)
+ if repo is None:
+ return None
- try:
- upload_record = model.blob.initiate_upload_for_repo(repo, new_upload_id, location_name,
- storage_metadata)
- return BlobUpload.for_upload(upload_record, location_name=location_name)
- except database.Repository.DoesNotExist:
- return None
+ try:
+ upload_record = model.blob.initiate_upload_for_repo(
+ repo, new_upload_id, location_name, storage_metadata
+ )
+ return BlobUpload.for_upload(upload_record, location_name=location_name)
+ except database.Repository.DoesNotExist:
+ return None
- def lookup_blob_upload(self, repository_ref, blob_upload_id):
- """ Looks up the blob upload withn the given ID under the specified repository and returns it
+ def lookup_blob_upload(self, repository_ref, blob_upload_id):
+ """ Looks up the blob upload withn the given ID under the specified repository and returns it
or None if none.
"""
- upload_record = model.blob.get_blob_upload_by_uuid(blob_upload_id)
- if upload_record is None:
- return None
+ upload_record = model.blob.get_blob_upload_by_uuid(blob_upload_id)
+ if upload_record is None:
+ return None
- return BlobUpload.for_upload(upload_record)
+ return BlobUpload.for_upload(upload_record)
- def update_blob_upload(self, blob_upload, uncompressed_byte_count, piece_hashes, piece_sha_state,
- storage_metadata, byte_count, chunk_count, sha_state):
- """ Updates the fields of the blob upload to match those given. Returns the updated blob upload
+ def update_blob_upload(
+ self,
+ blob_upload,
+ uncompressed_byte_count,
+ piece_hashes,
+ piece_sha_state,
+ storage_metadata,
+ byte_count,
+ chunk_count,
+ sha_state,
+ ):
+ """ Updates the fields of the blob upload to match those given. Returns the updated blob upload
or None if the record does not exist.
"""
- upload_record = model.blob.get_blob_upload_by_uuid(blob_upload.upload_id)
- if upload_record is None:
- return None
+ upload_record = model.blob.get_blob_upload_by_uuid(blob_upload.upload_id)
+ if upload_record is None:
+ return None
- upload_record.uncompressed_byte_count = uncompressed_byte_count
- upload_record.piece_hashes = piece_hashes
- upload_record.piece_sha_state = piece_sha_state
- upload_record.storage_metadata = storage_metadata
- upload_record.byte_count = byte_count
- upload_record.chunk_count = chunk_count
- upload_record.sha_state = sha_state
- upload_record.save()
- return BlobUpload.for_upload(upload_record)
+ upload_record.uncompressed_byte_count = uncompressed_byte_count
+ upload_record.piece_hashes = piece_hashes
+ upload_record.piece_sha_state = piece_sha_state
+ upload_record.storage_metadata = storage_metadata
+ upload_record.byte_count = byte_count
+ upload_record.chunk_count = chunk_count
+ upload_record.sha_state = sha_state
+ upload_record.save()
+ return BlobUpload.for_upload(upload_record)
- def delete_blob_upload(self, blob_upload):
- """ Deletes a blob upload record. """
- upload_record = model.blob.get_blob_upload_by_uuid(blob_upload.upload_id)
- if upload_record is not None:
- upload_record.delete_instance()
+ def delete_blob_upload(self, blob_upload):
+ """ Deletes a blob upload record. """
+ upload_record = model.blob.get_blob_upload_by_uuid(blob_upload.upload_id)
+ if upload_record is not None:
+ upload_record.delete_instance()
- def commit_blob_upload(self, blob_upload, blob_digest_str, blob_expiration_seconds):
- """ Commits the blob upload into a blob and sets an expiration before that blob will be GCed.
+ def commit_blob_upload(self, blob_upload, blob_digest_str, blob_expiration_seconds):
+ """ Commits the blob upload into a blob and sets an expiration before that blob will be GCed.
"""
- upload_record = model.blob.get_blob_upload_by_uuid(blob_upload.upload_id)
- if upload_record is None:
- return None
+ upload_record = model.blob.get_blob_upload_by_uuid(blob_upload.upload_id)
+ if upload_record is None:
+ return None
- repository_id = upload_record.repository_id
+ repository_id = upload_record.repository_id
- # Create the blob and temporarily tag it.
- location_obj = model.storage.get_image_location_for_name(blob_upload.location_name)
- blob_record = model.blob.store_blob_record_and_temp_link_in_repo(
- repository_id, blob_digest_str, location_obj.id, blob_upload.byte_count,
- blob_expiration_seconds, blob_upload.uncompressed_byte_count)
+ # Create the blob and temporarily tag it.
+ location_obj = model.storage.get_image_location_for_name(
+ blob_upload.location_name
+ )
+ blob_record = model.blob.store_blob_record_and_temp_link_in_repo(
+ repository_id,
+ blob_digest_str,
+ location_obj.id,
+ blob_upload.byte_count,
+ blob_expiration_seconds,
+ blob_upload.uncompressed_byte_count,
+ )
- # Delete the blob upload.
- upload_record.delete_instance()
- return Blob.for_image_storage(blob_record,
- storage_path=model.storage.get_layer_path(blob_record))
+ # Delete the blob upload.
+ upload_record.delete_instance()
+ return Blob.for_image_storage(
+ blob_record, storage_path=model.storage.get_layer_path(blob_record)
+ )
- def mount_blob_into_repository(self, blob, target_repository_ref, expiration_sec):
- """
+ def mount_blob_into_repository(self, blob, target_repository_ref, expiration_sec):
+ """
Mounts the blob from another repository into the specified target repository, and adds an
expiration before that blob is automatically GCed. This function is useful during push
operations if an existing blob from another repository is being pushed. Returns False if
the mounting fails.
"""
- storage = model.blob.temp_link_blob(target_repository_ref._db_id, blob.digest, expiration_sec)
- return bool(storage)
+ storage = model.blob.temp_link_blob(
+ target_repository_ref._db_id, blob.digest, expiration_sec
+ )
+ return bool(storage)
- def get_legacy_images(self, repository_ref):
- """
+ def get_legacy_images(self, repository_ref):
+ """
Returns an iterator of all the LegacyImage's defined in the matching repository.
"""
- repo = model.repository.lookup_repository(repository_ref._db_id)
- if repo is None:
- return None
+ repo = model.repository.lookup_repository(repository_ref._db_id)
+ if repo is None:
+ return None
- all_images = model.image.get_repository_images_without_placements(repo)
- all_images_map = {image.id: image for image in all_images}
+ all_images = model.image.get_repository_images_without_placements(repo)
+ all_images_map = {image.id: image for image in all_images}
- all_tags = model.tag.list_repository_tags(repo.namespace_user.username, repo.name)
- tags_by_image_id = defaultdict(list)
- for tag in all_tags:
- tags_by_image_id[tag.image_id].append(tag)
+ all_tags = model.tag.list_repository_tags(
+ repo.namespace_user.username, repo.name
+ )
+ tags_by_image_id = defaultdict(list)
+ for tag in all_tags:
+ tags_by_image_id[tag.image_id].append(tag)
- return [LegacyImage.for_image(image, images_map=all_images_map, tags_map=tags_by_image_id)
- for image in all_images]
+ return [
+ LegacyImage.for_image(
+ image, images_map=all_images_map, tags_map=tags_by_image_id
+ )
+ for image in all_images
+ ]
- def get_legacy_image(self, repository_ref, docker_image_id, include_parents=False,
- include_blob=False):
- """
+ def get_legacy_image(
+ self, repository_ref, docker_image_id, include_parents=False, include_blob=False
+ ):
+ """
Returns the matching LegacyImages under the matching repository, if any. If none,
returns None.
"""
- repo = model.repository.lookup_repository(repository_ref._db_id)
- if repo is None:
- return None
+ repo = model.repository.lookup_repository(repository_ref._db_id)
+ if repo is None:
+ return None
- image = model.image.get_image(repository_ref._db_id, docker_image_id)
- if image is None:
- return None
+ image = model.image.get_image(repository_ref._db_id, docker_image_id)
+ if image is None:
+ return None
- parent_images_map = None
- if include_parents:
- parent_images = model.image.get_parent_images(repo.namespace_user.username, repo.name, image)
- parent_images_map = {image.id: image for image in parent_images}
+ parent_images_map = None
+ if include_parents:
+ parent_images = model.image.get_parent_images(
+ repo.namespace_user.username, repo.name, image
+ )
+ parent_images_map = {image.id: image for image in parent_images}
- blob = None
- if include_blob:
- placements = list(model.storage.get_storage_locations(image.storage.uuid))
- blob = Blob.for_image_storage(image.storage,
- storage_path=model.storage.get_layer_path(image.storage),
- placements=placements)
+ blob = None
+ if include_blob:
+ placements = list(model.storage.get_storage_locations(image.storage.uuid))
+ blob = Blob.for_image_storage(
+ image.storage,
+ storage_path=model.storage.get_layer_path(image.storage),
+ placements=placements,
+ )
- return LegacyImage.for_image(image, images_map=parent_images_map, blob=blob)
+ return LegacyImage.for_image(image, images_map=parent_images_map, blob=blob)
- def _get_manifest_local_blobs(self, manifest, repo_id, include_placements=False,
- by_manifest=False):
- parsed = manifest.get_parsed_manifest()
- if parsed is None:
- return None
+ def _get_manifest_local_blobs(
+ self, manifest, repo_id, include_placements=False, by_manifest=False
+ ):
+ parsed = manifest.get_parsed_manifest()
+ if parsed is None:
+ return None
- local_blob_digests = list(set(parsed.local_blob_digests))
- if not len(local_blob_digests):
- return []
+ local_blob_digests = list(set(parsed.local_blob_digests))
+ if not len(local_blob_digests):
+ return []
- blob_query = self._lookup_repo_storages_by_content_checksum(repo_id, local_blob_digests,
- by_manifest=by_manifest)
- blobs = []
- for image_storage in blob_query:
- placements = None
- if include_placements:
- placements = list(model.storage.get_storage_locations(image_storage.uuid))
+ blob_query = self._lookup_repo_storages_by_content_checksum(
+ repo_id, local_blob_digests, by_manifest=by_manifest
+ )
+ blobs = []
+ for image_storage in blob_query:
+ placements = None
+ if include_placements:
+ placements = list(
+ model.storage.get_storage_locations(image_storage.uuid)
+ )
- blob = Blob.for_image_storage(image_storage,
- storage_path=model.storage.get_layer_path(image_storage),
- placements=placements)
- blobs.append(blob)
+ blob = Blob.for_image_storage(
+ image_storage,
+ storage_path=model.storage.get_layer_path(image_storage),
+ placements=placements,
+ )
+ blobs.append(blob)
- return blobs
+ return blobs
- def _list_manifest_layers(self, repo_id, parsed, storage, include_placements=False,
- by_manifest=False):
- """ Returns an *ordered list* of the layers found in the manifest, starting at the base and
+ def _list_manifest_layers(
+ self, repo_id, parsed, storage, include_placements=False, by_manifest=False
+ ):
+ """ Returns an *ordered list* of the layers found in the manifest, starting at the base and
working towards the leaf, including the associated Blob and its placements (if specified).
Returns None if the manifest could not be parsed and validated.
"""
- assert not parsed.is_manifest_list
+ assert not parsed.is_manifest_list
- retriever = RepositoryContentRetriever(repo_id, storage)
- requires_empty_blob = parsed.get_requires_empty_layer_blob(retriever)
+ retriever = RepositoryContentRetriever(repo_id, storage)
+ requires_empty_blob = parsed.get_requires_empty_layer_blob(retriever)
- storage_map = {}
- blob_digests = list(parsed.local_blob_digests)
- if requires_empty_blob:
- blob_digests.append(EMPTY_LAYER_BLOB_DIGEST)
+ storage_map = {}
+ blob_digests = list(parsed.local_blob_digests)
+ if requires_empty_blob:
+ blob_digests.append(EMPTY_LAYER_BLOB_DIGEST)
- if blob_digests:
- blob_query = self._lookup_repo_storages_by_content_checksum(repo_id, blob_digests,
- by_manifest=by_manifest)
- storage_map = {blob.content_checksum: blob for blob in blob_query}
+ if blob_digests:
+ blob_query = self._lookup_repo_storages_by_content_checksum(
+ repo_id, blob_digests, by_manifest=by_manifest
+ )
+ storage_map = {blob.content_checksum: blob for blob in blob_query}
+ layers = parsed.get_layers(retriever)
+ if layers is None:
+ logger.error("Could not load layers for manifest `%s`", parsed.digest)
+ return None
- layers = parsed.get_layers(retriever)
- if layers is None:
- logger.error('Could not load layers for manifest `%s`', parsed.digest)
- return None
+ manifest_layers = []
+ for layer in layers:
+ if layer.is_remote:
+ manifest_layers.append(ManifestLayer(layer, None))
+ continue
- manifest_layers = []
- for layer in layers:
- if layer.is_remote:
- manifest_layers.append(ManifestLayer(layer, None))
- continue
+ digest_str = str(layer.blob_digest)
+ if digest_str not in storage_map:
+ logger.error(
+ "Missing digest `%s` for manifest `%s`",
+ layer.blob_digest,
+ parsed.digest,
+ )
+ return None
+
+ image_storage = storage_map[digest_str]
+ assert image_storage.cas_path is not None
+ assert image_storage.image_size is not None
+
+ placements = None
+ if include_placements:
+ placements = list(
+ model.storage.get_storage_locations(image_storage.uuid)
+ )
+
+ blob = Blob.for_image_storage(
+ image_storage,
+ storage_path=model.storage.get_layer_path(image_storage),
+ placements=placements,
+ )
+ manifest_layers.append(ManifestLayer(layer, blob))
+
+ return manifest_layers
+
+ def _build_derived(self, derived, verb, varying_metadata, include_placements):
+ if derived is None:
+ return None
+
+ derived_storage = derived.derivative
+ placements = None
+ if include_placements:
+ placements = list(model.storage.get_storage_locations(derived_storage.uuid))
+
+ blob = Blob.for_image_storage(
+ derived_storage,
+ storage_path=model.storage.get_layer_path(derived_storage),
+ placements=placements,
+ )
+
+ return DerivedImage.for_derived_storage(derived, verb, varying_metadata, blob)
+
+ def _build_manifest_for_legacy_image(self, tag_name, legacy_image_row):
+ import features
+
+ from app import app, docker_v2_signing_key
+
+ repo = legacy_image_row.repository
+ namespace_name = repo.namespace_user.username
+ repo_name = repo.name
+
+ # Find the v1 metadata for this image and its parents.
+ try:
+ parents = model.image.get_parent_images(
+ namespace_name, repo_name, legacy_image_row
+ )
+ except model.DataModelException:
+ logger.exception(
+ "Could not load parent images for legacy image %s", legacy_image_row.id
+ )
+ return None
+
+ # If the manifest is being generated under the library namespace, then we make its namespace
+ # empty.
+ manifest_namespace = namespace_name
+ if (
+ features.LIBRARY_SUPPORT
+ and namespace_name == app.config["LIBRARY_NAMESPACE"]
+ ):
+ manifest_namespace = ""
+
+ # Create and populate the manifest builder
+ builder = DockerSchema1ManifestBuilder(manifest_namespace, repo_name, tag_name)
+
+ # Add the leaf layer
+ builder.add_layer(
+ legacy_image_row.storage.content_checksum, legacy_image_row.v1_json_metadata
+ )
+ if legacy_image_row.storage.uploading:
+ logger.error(
+ "Cannot add an uploading storage row: %s", legacy_image_row.storage.id
+ )
+ return None
+
+ for parent_image in parents:
+ if parent_image.storage.uploading:
+ logger.error(
+ "Cannot add an uploading storage row: %s",
+ legacy_image_row.storage.id,
+ )
+ return None
+
+ builder.add_layer(
+ parent_image.storage.content_checksum, parent_image.v1_json_metadata
+ )
+
+ try:
+ built_manifest = builder.build(docker_v2_signing_key)
+
+ # If the generated manifest is greater than the maximum size, regenerate it with
+ # intermediate metadata layers stripped down to their bare essentials.
+ if (
+ len(built_manifest.bytes.as_encoded_str())
+ > MAXIMUM_GENERATED_MANIFEST_SIZE
+ ):
+ built_manifest = builder.with_metadata_removed().build(
+ docker_v2_signing_key
+ )
+
+ if (
+ len(built_manifest.bytes.as_encoded_str())
+ > MAXIMUM_GENERATED_MANIFEST_SIZE
+ ):
+ logger.error("Legacy image is too large to generate manifest")
+ return None
+
+ return built_manifest
+ except ManifestException as me:
+ logger.exception(
+ "Got exception when trying to build manifest for legacy image %s",
+ legacy_image_row,
+ )
+ return None
+
+ def _get_shared_storage(self, blob_digest):
+ """ Returns an ImageStorage row for the blob digest if it is a globally shared storage. """
+ # If the EMPTY_LAYER_BLOB_DIGEST is in the checksums, look it up directly. Since we have
+ # so many duplicate copies in the database currently, looking it up bound to a repository
+ # can be incredibly slow, and, since it is defined as a globally shared layer, this is extra
+ # work we don't need to do.
+ if blob_digest == EMPTY_LAYER_BLOB_DIGEST:
+ return get_shared_blob(EMPTY_LAYER_BLOB_DIGEST)
- digest_str = str(layer.blob_digest)
- if digest_str not in storage_map:
- logger.error('Missing digest `%s` for manifest `%s`', layer.blob_digest, parsed.digest)
return None
- image_storage = storage_map[digest_str]
- assert image_storage.cas_path is not None
- assert image_storage.image_size is not None
+ def _lookup_repo_storages_by_content_checksum(
+ self, repo, checksums, by_manifest=False
+ ):
+ checksums = set(checksums)
- placements = None
- if include_placements:
- placements = list(model.storage.get_storage_locations(image_storage.uuid))
+ # Load any shared storages first.
+ extra_storages = []
+ for checksum in list(checksums):
+ shared_storage = self._get_shared_storage(checksum)
+ if shared_storage is not None:
+ extra_storages.append(shared_storage)
+ checksums.remove(checksum)
- blob = Blob.for_image_storage(image_storage,
- storage_path=model.storage.get_layer_path(image_storage),
- placements=placements)
- manifest_layers.append(ManifestLayer(layer, blob))
-
- return manifest_layers
-
- def _build_derived(self, derived, verb, varying_metadata, include_placements):
- if derived is None:
- return None
-
- derived_storage = derived.derivative
- placements = None
- if include_placements:
- placements = list(model.storage.get_storage_locations(derived_storage.uuid))
-
- blob = Blob.for_image_storage(derived_storage,
- storage_path=model.storage.get_layer_path(derived_storage),
- placements=placements)
-
- return DerivedImage.for_derived_storage(derived, verb, varying_metadata, blob)
-
- def _build_manifest_for_legacy_image(self, tag_name, legacy_image_row):
- import features
-
- from app import app, docker_v2_signing_key
-
- repo = legacy_image_row.repository
- namespace_name = repo.namespace_user.username
- repo_name = repo.name
-
- # Find the v1 metadata for this image and its parents.
- try:
- parents = model.image.get_parent_images(namespace_name, repo_name, legacy_image_row)
- except model.DataModelException:
- logger.exception('Could not load parent images for legacy image %s', legacy_image_row.id)
- return None
-
- # If the manifest is being generated under the library namespace, then we make its namespace
- # empty.
- manifest_namespace = namespace_name
- if features.LIBRARY_SUPPORT and namespace_name == app.config['LIBRARY_NAMESPACE']:
- manifest_namespace = ''
-
- # Create and populate the manifest builder
- builder = DockerSchema1ManifestBuilder(manifest_namespace, repo_name, tag_name)
-
- # Add the leaf layer
- builder.add_layer(legacy_image_row.storage.content_checksum, legacy_image_row.v1_json_metadata)
- if legacy_image_row.storage.uploading:
- logger.error('Cannot add an uploading storage row: %s', legacy_image_row.storage.id)
- return None
-
- for parent_image in parents:
- if parent_image.storage.uploading:
- logger.error('Cannot add an uploading storage row: %s', legacy_image_row.storage.id)
- return None
-
- builder.add_layer(parent_image.storage.content_checksum, parent_image.v1_json_metadata)
-
- try:
- built_manifest = builder.build(docker_v2_signing_key)
-
- # If the generated manifest is greater than the maximum size, regenerate it with
- # intermediate metadata layers stripped down to their bare essentials.
- if len(built_manifest.bytes.as_encoded_str()) > MAXIMUM_GENERATED_MANIFEST_SIZE:
- built_manifest = builder.with_metadata_removed().build(docker_v2_signing_key)
-
- if len(built_manifest.bytes.as_encoded_str()) > MAXIMUM_GENERATED_MANIFEST_SIZE:
- logger.error('Legacy image is too large to generate manifest')
- return None
-
- return built_manifest
- except ManifestException as me:
- logger.exception('Got exception when trying to build manifest for legacy image %s',
- legacy_image_row)
- return None
-
- def _get_shared_storage(self, blob_digest):
- """ Returns an ImageStorage row for the blob digest if it is a globally shared storage. """
- # If the EMPTY_LAYER_BLOB_DIGEST is in the checksums, look it up directly. Since we have
- # so many duplicate copies in the database currently, looking it up bound to a repository
- # can be incredibly slow, and, since it is defined as a globally shared layer, this is extra
- # work we don't need to do.
- if blob_digest == EMPTY_LAYER_BLOB_DIGEST:
- return get_shared_blob(EMPTY_LAYER_BLOB_DIGEST)
-
- return None
-
- def _lookup_repo_storages_by_content_checksum(self, repo, checksums, by_manifest=False):
- checksums = set(checksums)
-
- # Load any shared storages first.
- extra_storages = []
- for checksum in list(checksums):
- shared_storage = self._get_shared_storage(checksum)
- if shared_storage is not None:
- extra_storages.append(shared_storage)
- checksums.remove(checksum)
-
- found = []
- if checksums:
- found = list(model.storage.lookup_repo_storages_by_content_checksum(repo, checksums,
- by_manifest=by_manifest))
- return found + extra_storages
+ found = []
+ if checksums:
+ found = list(
+ model.storage.lookup_repo_storages_by_content_checksum(
+ repo, checksums, by_manifest=by_manifest
+ )
+ )
+ return found + extra_storages
diff --git a/data/registry_model/test/test_blobuploader.py b/data/registry_model/test/test_blobuploader.py
index 8b539c617..d60c2e247 100644
--- a/data/registry_model/test/test_blobuploader.py
+++ b/data/registry_model/test/test_blobuploader.py
@@ -7,139 +7,144 @@ from contextlib import closing
import pytest
-from data.registry_model.blobuploader import (retrieve_blob_upload_manager,
- upload_blob, BlobUploadException,
- BlobDigestMismatchException, BlobTooLargeException,
- BlobUploadSettings)
+from data.registry_model.blobuploader import (
+ retrieve_blob_upload_manager,
+ upload_blob,
+ BlobUploadException,
+ BlobDigestMismatchException,
+ BlobTooLargeException,
+ BlobUploadSettings,
+)
from data.registry_model.registry_pre_oci_model import PreOCIModel
from storage.distributedstorage import DistributedStorage
from storage.fakestorage import FakeStorage
from test.fixtures import *
+
@pytest.fixture()
def pre_oci_model(initialized_db):
- return PreOCIModel()
+ return PreOCIModel()
-@pytest.mark.parametrize('chunk_count', [
- 0,
- 1,
- 2,
- 10,
-])
-@pytest.mark.parametrize('subchunk', [
- True,
- False,
-])
+
+@pytest.mark.parametrize("chunk_count", [0, 1, 2, 10])
+@pytest.mark.parametrize("subchunk", [True, False])
def test_basic_upload_blob(chunk_count, subchunk, pre_oci_model):
- repository_ref = pre_oci_model.lookup_repository('devtable', 'complex')
- storage = DistributedStorage({'local_us': FakeStorage(None)}, ['local_us'])
- settings = BlobUploadSettings('2M', 512 * 1024, 3600)
- app_config = {'TESTING': True}
+ repository_ref = pre_oci_model.lookup_repository("devtable", "complex")
+ storage = DistributedStorage({"local_us": FakeStorage(None)}, ["local_us"])
+ settings = BlobUploadSettings("2M", 512 * 1024, 3600)
+ app_config = {"TESTING": True}
- data = ''
- with upload_blob(repository_ref, storage, settings) as manager:
- assert manager
- assert manager.blob_upload_id
+ data = ""
+ with upload_blob(repository_ref, storage, settings) as manager:
+ assert manager
+ assert manager.blob_upload_id
- for index in range(0, chunk_count):
- chunk_data = os.urandom(100)
- data += chunk_data
+ for index in range(0, chunk_count):
+ chunk_data = os.urandom(100)
+ data += chunk_data
- if subchunk:
- manager.upload_chunk(app_config, BytesIO(chunk_data))
- manager.upload_chunk(app_config, BytesIO(chunk_data), (index * 100) + 50)
- else:
- manager.upload_chunk(app_config, BytesIO(chunk_data))
+ if subchunk:
+ manager.upload_chunk(app_config, BytesIO(chunk_data))
+ manager.upload_chunk(
+ app_config, BytesIO(chunk_data), (index * 100) + 50
+ )
+ else:
+ manager.upload_chunk(app_config, BytesIO(chunk_data))
- blob = manager.commit_to_blob(app_config)
+ blob = manager.commit_to_blob(app_config)
- # Check the blob.
- assert blob.compressed_size == len(data)
- assert not blob.uploading
- assert blob.digest == 'sha256:' + hashlib.sha256(data).hexdigest()
+ # Check the blob.
+ assert blob.compressed_size == len(data)
+ assert not blob.uploading
+ assert blob.digest == "sha256:" + hashlib.sha256(data).hexdigest()
- # Ensure the blob exists in storage and has the expected data.
- assert storage.get_content(['local_us'], blob.storage_path) == data
+ # Ensure the blob exists in storage and has the expected data.
+ assert storage.get_content(["local_us"], blob.storage_path) == data
def test_cancel_upload(pre_oci_model):
- repository_ref = pre_oci_model.lookup_repository('devtable', 'complex')
- storage = DistributedStorage({'local_us': FakeStorage(None)}, ['local_us'])
- settings = BlobUploadSettings('2M', 512 * 1024, 3600)
- app_config = {'TESTING': True}
+ repository_ref = pre_oci_model.lookup_repository("devtable", "complex")
+ storage = DistributedStorage({"local_us": FakeStorage(None)}, ["local_us"])
+ settings = BlobUploadSettings("2M", 512 * 1024, 3600)
+ app_config = {"TESTING": True}
- blob_upload_id = None
- with upload_blob(repository_ref, storage, settings) as manager:
- blob_upload_id = manager.blob_upload_id
- assert pre_oci_model.lookup_blob_upload(repository_ref, blob_upload_id) is not None
+ blob_upload_id = None
+ with upload_blob(repository_ref, storage, settings) as manager:
+ blob_upload_id = manager.blob_upload_id
+ assert (
+ pre_oci_model.lookup_blob_upload(repository_ref, blob_upload_id) is not None
+ )
- manager.upload_chunk(app_config, BytesIO('hello world'))
+ manager.upload_chunk(app_config, BytesIO("hello world"))
- # Since the blob was not comitted, the upload should be deleted.
- assert blob_upload_id
- assert pre_oci_model.lookup_blob_upload(repository_ref, blob_upload_id) is None
+ # Since the blob was not committed, the upload should be deleted.
+ assert blob_upload_id
+ assert pre_oci_model.lookup_blob_upload(repository_ref, blob_upload_id) is None
def test_too_large(pre_oci_model):
- repository_ref = pre_oci_model.lookup_repository('devtable', 'complex')
- storage = DistributedStorage({'local_us': FakeStorage(None)}, ['local_us'])
- settings = BlobUploadSettings('1K', 512 * 1024, 3600)
- app_config = {'TESTING': True}
+ repository_ref = pre_oci_model.lookup_repository("devtable", "complex")
+ storage = DistributedStorage({"local_us": FakeStorage(None)}, ["local_us"])
+ settings = BlobUploadSettings("1K", 512 * 1024, 3600)
+ app_config = {"TESTING": True}
- with upload_blob(repository_ref, storage, settings) as manager:
- with pytest.raises(BlobTooLargeException):
- manager.upload_chunk(app_config, BytesIO(os.urandom(1024 * 1024 * 2)))
+ with upload_blob(repository_ref, storage, settings) as manager:
+ with pytest.raises(BlobTooLargeException):
+ manager.upload_chunk(app_config, BytesIO(os.urandom(1024 * 1024 * 2)))
def test_extra_blob_stream_handlers(pre_oci_model):
- handler1_result = []
- handler2_result = []
+ handler1_result = []
+ handler2_result = []
- def handler1(bytes):
- handler1_result.append(bytes)
+ def handler1(bytes):
+ handler1_result.append(bytes)
- def handler2(bytes):
- handler2_result.append(bytes)
+ def handler2(bytes):
+ handler2_result.append(bytes)
- repository_ref = pre_oci_model.lookup_repository('devtable', 'complex')
- storage = DistributedStorage({'local_us': FakeStorage(None)}, ['local_us'])
- settings = BlobUploadSettings('1K', 512 * 1024, 3600)
- app_config = {'TESTING': True}
+ repository_ref = pre_oci_model.lookup_repository("devtable", "complex")
+ storage = DistributedStorage({"local_us": FakeStorage(None)}, ["local_us"])
+ settings = BlobUploadSettings("1K", 512 * 1024, 3600)
+ app_config = {"TESTING": True}
- with upload_blob(repository_ref, storage, settings,
- extra_blob_stream_handlers=[handler1, handler2]) as manager:
- manager.upload_chunk(app_config, BytesIO('hello '))
- manager.upload_chunk(app_config, BytesIO('world'))
+ with upload_blob(
+ repository_ref,
+ storage,
+ settings,
+ extra_blob_stream_handlers=[handler1, handler2],
+ ) as manager:
+ manager.upload_chunk(app_config, BytesIO("hello "))
+ manager.upload_chunk(app_config, BytesIO("world"))
- assert ''.join(handler1_result) == 'hello world'
- assert ''.join(handler2_result) == 'hello world'
+ assert "".join(handler1_result) == "hello world"
+ assert "".join(handler2_result) == "hello world"
def valid_tar_gz(contents):
- with closing(BytesIO()) as layer_data:
- with closing(tarfile.open(fileobj=layer_data, mode='w|gz')) as tar_file:
- tar_file_info = tarfile.TarInfo(name='somefile')
- tar_file_info.type = tarfile.REGTYPE
- tar_file_info.size = len(contents)
- tar_file_info.mtime = 1
- tar_file.addfile(tar_file_info, BytesIO(contents))
+ with closing(BytesIO()) as layer_data:
+ with closing(tarfile.open(fileobj=layer_data, mode="w|gz")) as tar_file:
+ tar_file_info = tarfile.TarInfo(name="somefile")
+ tar_file_info.type = tarfile.REGTYPE
+ tar_file_info.size = len(contents)
+ tar_file_info.mtime = 1
+ tar_file.addfile(tar_file_info, BytesIO(contents))
- layer_bytes = layer_data.getvalue()
- return layer_bytes
+ layer_bytes = layer_data.getvalue()
+ return layer_bytes
def test_uncompressed_size(pre_oci_model):
- repository_ref = pre_oci_model.lookup_repository('devtable', 'complex')
- storage = DistributedStorage({'local_us': FakeStorage(None)}, ['local_us'])
- settings = BlobUploadSettings('1K', 512 * 1024, 3600)
- app_config = {'TESTING': True}
+ repository_ref = pre_oci_model.lookup_repository("devtable", "complex")
+ storage = DistributedStorage({"local_us": FakeStorage(None)}, ["local_us"])
+ settings = BlobUploadSettings("1K", 512 * 1024, 3600)
+ app_config = {"TESTING": True}
- with upload_blob(repository_ref, storage, settings) as manager:
- manager.upload_chunk(app_config, BytesIO(valid_tar_gz('hello world')))
+ with upload_blob(repository_ref, storage, settings) as manager:
+ manager.upload_chunk(app_config, BytesIO(valid_tar_gz("hello world")))
- blob = manager.commit_to_blob(app_config)
-
- assert blob.compressed_size is not None
- assert blob.uncompressed_size is not None
+ blob = manager.commit_to_blob(app_config)
+ assert blob.compressed_size is not None
+ assert blob.uncompressed_size is not None
diff --git a/data/registry_model/test/test_interface.py b/data/registry_model/test/test_interface.py
index 8255ade6d..a0092f578 100644
--- a/data/registry_model/test/test_interface.py
+++ b/data/registry_model/test/test_interface.py
@@ -14,10 +14,21 @@ from playhouse.test_utils import assert_query_count
from app import docker_v2_signing_key, storage
from data import model
-from data.database import (TagManifestLabelMap, TagManifestToManifest, Manifest, ManifestBlob,
- ManifestLegacyImage, ManifestLabel,
- TagManifest, TagManifestLabel, DerivedStorageForImage,
- TorrentInfo, Tag, TagToRepositoryTag, ImageStorageLocation)
+from data.database import (
+ TagManifestLabelMap,
+ TagManifestToManifest,
+ Manifest,
+ ManifestBlob,
+ ManifestLegacyImage,
+ ManifestLabel,
+ TagManifest,
+ TagManifestLabel,
+ DerivedStorageForImage,
+ TorrentInfo,
+ Tag,
+ TagToRepositoryTag,
+ ImageStorageLocation,
+)
from data.cache.impl import InMemoryDataModelCache
from data.registry_model.registry_pre_oci_model import PreOCIModel
from data.registry_model.registry_oci_model import OCIModel
@@ -26,8 +37,11 @@ from data.registry_model.blobuploader import upload_blob, BlobUploadSettings
from data.registry_model.modelsplitter import SplitModel
from data.model.blob import store_blob_record_and_temp_link
from image.docker.types import ManifestImageLayer
-from image.docker.schema1 import (DockerSchema1ManifestBuilder, DOCKER_SCHEMA1_CONTENT_TYPES,
- DockerSchema1Manifest)
+from image.docker.schema1 import (
+ DockerSchema1ManifestBuilder,
+ DOCKER_SCHEMA1_CONTENT_TYPES,
+ DockerSchema1Manifest,
+)
from image.docker.schema2.manifest import DockerSchema2ManifestBuilder
from image.docker.schema2.list import DockerSchema2ManifestListBuilder
from util.bytes import Bytes
@@ -35,1061 +49,1256 @@ from util.bytes import Bytes
from test.fixtures import *
-@pytest.fixture(params=[PreOCIModel(), OCIModel(), OCIModel(oci_model_only=False),
- SplitModel(0, {'devtable'}, {'buynlarge'}, False),
- SplitModel(1.0, {'devtable'}, {'buynlarge'}, False),
- SplitModel(1.0, {'devtable'}, {'buynlarge'}, True)])
+@pytest.fixture(
+ params=[
+ PreOCIModel(),
+ OCIModel(),
+ OCIModel(oci_model_only=False),
+ SplitModel(0, {"devtable"}, {"buynlarge"}, False),
+ SplitModel(1.0, {"devtable"}, {"buynlarge"}, False),
+ SplitModel(1.0, {"devtable"}, {"buynlarge"}, True),
+ ]
+)
def registry_model(request, initialized_db):
- return request.param
+ return request.param
+
@pytest.fixture()
def pre_oci_model(initialized_db):
- return PreOCIModel()
+ return PreOCIModel()
+
@pytest.fixture()
def oci_model(initialized_db):
- return OCIModel()
+ return OCIModel()
-@pytest.mark.parametrize('names, expected', [
- (['unknown'], None),
- (['latest'], {'latest'}),
- (['latest', 'prod'], {'latest', 'prod'}),
- (['latest', 'prod', 'another'], {'latest', 'prod'}),
- (['foo', 'prod'], {'prod'}),
-])
+@pytest.mark.parametrize(
+ "names, expected",
+ [
+ (["unknown"], None),
+ (["latest"], {"latest"}),
+ (["latest", "prod"], {"latest", "prod"}),
+ (["latest", "prod", "another"], {"latest", "prod"}),
+ (["foo", "prod"], {"prod"}),
+ ],
+)
def test_find_matching_tag(names, expected, registry_model):
- repo = model.repository.get_repository('devtable', 'simple')
- repository_ref = RepositoryReference.for_repo_obj(repo)
- found = registry_model.find_matching_tag(repository_ref, names)
- if expected is None:
- assert found is None
- else:
- assert found.name in expected
- assert found.repository.namespace_name == 'devtable'
- assert found.repository.name == 'simple'
+ repo = model.repository.get_repository("devtable", "simple")
+ repository_ref = RepositoryReference.for_repo_obj(repo)
+ found = registry_model.find_matching_tag(repository_ref, names)
+ if expected is None:
+ assert found is None
+ else:
+ assert found.name in expected
+ assert found.repository.namespace_name == "devtable"
+ assert found.repository.name == "simple"
-@pytest.mark.parametrize('repo_namespace, repo_name, expected', [
- ('devtable', 'simple', {'latest', 'prod'}),
- ('buynlarge', 'orgrepo', {'latest', 'prod'}),
-])
+@pytest.mark.parametrize(
+ "repo_namespace, repo_name, expected",
+ [
+ ("devtable", "simple", {"latest", "prod"}),
+ ("buynlarge", "orgrepo", {"latest", "prod"}),
+ ],
+)
def test_get_most_recent_tag(repo_namespace, repo_name, expected, registry_model):
- repo = model.repository.get_repository(repo_namespace, repo_name)
- repository_ref = RepositoryReference.for_repo_obj(repo)
- found = registry_model.get_most_recent_tag(repository_ref)
- if expected is None:
- assert found is None
- else:
- assert found.name in expected
+ repo = model.repository.get_repository(repo_namespace, repo_name)
+ repository_ref = RepositoryReference.for_repo_obj(repo)
+ found = registry_model.get_most_recent_tag(repository_ref)
+ if expected is None:
+ assert found is None
+ else:
+ assert found.name in expected
-@pytest.mark.parametrize('repo_namespace, repo_name, expected', [
- ('devtable', 'simple', True),
- ('buynlarge', 'orgrepo', True),
- ('buynlarge', 'unknownrepo', False),
-])
+@pytest.mark.parametrize(
+ "repo_namespace, repo_name, expected",
+ [
+ ("devtable", "simple", True),
+ ("buynlarge", "orgrepo", True),
+ ("buynlarge", "unknownrepo", False),
+ ],
+)
def test_lookup_repository(repo_namespace, repo_name, expected, registry_model):
- repo_ref = registry_model.lookup_repository(repo_namespace, repo_name)
- if expected:
- assert repo_ref
- else:
- assert repo_ref is None
+ repo_ref = registry_model.lookup_repository(repo_namespace, repo_name)
+ if expected:
+ assert repo_ref
+ else:
+ assert repo_ref is None
-@pytest.mark.parametrize('repo_namespace, repo_name', [
- ('devtable', 'simple'),
- ('buynlarge', 'orgrepo'),
-])
+@pytest.mark.parametrize(
+ "repo_namespace, repo_name", [("devtable", "simple"), ("buynlarge", "orgrepo")]
+)
def test_lookup_manifests(repo_namespace, repo_name, registry_model):
- repo = model.repository.get_repository(repo_namespace, repo_name)
- repository_ref = RepositoryReference.for_repo_obj(repo)
- found_tag = registry_model.find_matching_tag(repository_ref, ['latest'])
- found_manifest = registry_model.get_manifest_for_tag(found_tag)
- found = registry_model.lookup_manifest_by_digest(repository_ref, found_manifest.digest,
- include_legacy_image=True)
- assert found._db_id == found_manifest._db_id
- assert found.digest == found_manifest.digest
- assert found.legacy_image
- assert found.legacy_image.parents
+ repo = model.repository.get_repository(repo_namespace, repo_name)
+ repository_ref = RepositoryReference.for_repo_obj(repo)
+ found_tag = registry_model.find_matching_tag(repository_ref, ["latest"])
+ found_manifest = registry_model.get_manifest_for_tag(found_tag)
+ found = registry_model.lookup_manifest_by_digest(
+ repository_ref, found_manifest.digest, include_legacy_image=True
+ )
+ assert found._db_id == found_manifest._db_id
+ assert found.digest == found_manifest.digest
+ assert found.legacy_image
+ assert found.legacy_image.parents
- schema1_parsed = registry_model.get_schema1_parsed_manifest(found, 'foo', 'bar', 'baz', storage)
- assert schema1_parsed is not None
+ schema1_parsed = registry_model.get_schema1_parsed_manifest(
+ found, "foo", "bar", "baz", storage
+ )
+ assert schema1_parsed is not None
def test_lookup_unknown_manifest(registry_model):
- repo = model.repository.get_repository('devtable', 'simple')
- repository_ref = RepositoryReference.for_repo_obj(repo)
- found = registry_model.lookup_manifest_by_digest(repository_ref, 'sha256:deadbeef')
- assert found is None
+ repo = model.repository.get_repository("devtable", "simple")
+ repository_ref = RepositoryReference.for_repo_obj(repo)
+ found = registry_model.lookup_manifest_by_digest(repository_ref, "sha256:deadbeef")
+ assert found is None
-@pytest.mark.parametrize('repo_namespace, repo_name', [
- ('devtable', 'simple'),
- ('devtable', 'complex'),
- ('devtable', 'history'),
- ('buynlarge', 'orgrepo'),
-])
+@pytest.mark.parametrize(
+ "repo_namespace, repo_name",
+ [
+ ("devtable", "simple"),
+ ("devtable", "complex"),
+ ("devtable", "history"),
+ ("buynlarge", "orgrepo"),
+ ],
+)
def test_legacy_images(repo_namespace, repo_name, registry_model):
- repository_ref = registry_model.lookup_repository(repo_namespace, repo_name)
- legacy_images = registry_model.get_legacy_images(repository_ref)
- assert len(legacy_images)
+ repository_ref = registry_model.lookup_repository(repo_namespace, repo_name)
+ legacy_images = registry_model.get_legacy_images(repository_ref)
+ assert len(legacy_images)
- found_tags = set()
- for image in legacy_images:
- found_image = registry_model.get_legacy_image(repository_ref, image.docker_image_id,
- include_parents=True)
+ found_tags = set()
+ for image in legacy_images:
+ found_image = registry_model.get_legacy_image(
+ repository_ref, image.docker_image_id, include_parents=True
+ )
- with assert_query_count(5 if found_image.parents else 4):
- found_image = registry_model.get_legacy_image(repository_ref, image.docker_image_id,
- include_parents=True, include_blob=True)
- assert found_image.docker_image_id == image.docker_image_id
- assert found_image.parents == image.parents
- assert found_image.blob
- assert found_image.blob.placements
+ with assert_query_count(5 if found_image.parents else 4):
+ found_image = registry_model.get_legacy_image(
+ repository_ref,
+ image.docker_image_id,
+ include_parents=True,
+ include_blob=True,
+ )
+ assert found_image.docker_image_id == image.docker_image_id
+ assert found_image.parents == image.parents
+ assert found_image.blob
+ assert found_image.blob.placements
- # Check that the tags list can be retrieved.
- assert image.tags is not None
- found_tags.update({tag.name for tag in image.tags})
+ # Check that the tags list can be retrieved.
+ assert image.tags is not None
+ found_tags.update({tag.name for tag in image.tags})
- # Check against the actual DB row.
- model_image = model.image.get_image(repository_ref._db_id, found_image.docker_image_id)
- assert model_image.id == found_image._db_id
- assert ([pid for pid in reversed(model_image.ancestor_id_list())] ==
- [p._db_id for p in found_image.parents])
+ # Check against the actual DB row.
+ model_image = model.image.get_image(
+ repository_ref._db_id, found_image.docker_image_id
+ )
+ assert model_image.id == found_image._db_id
+ assert [pid for pid in reversed(model_image.ancestor_id_list())] == [
+ p._db_id for p in found_image.parents
+ ]
- # Try without parents and ensure it raises an exception.
- found_image = registry_model.get_legacy_image(repository_ref, image.docker_image_id,
- include_parents=False)
- with pytest.raises(Exception):
- assert not found_image.parents
+ # Try without parents and ensure it raises an exception.
+ found_image = registry_model.get_legacy_image(
+ repository_ref, image.docker_image_id, include_parents=False
+ )
+ with pytest.raises(Exception):
+ assert not found_image.parents
- assert found_tags
+ assert found_tags
- unknown = registry_model.get_legacy_image(repository_ref, 'unknown', include_parents=True)
- assert unknown is None
+ unknown = registry_model.get_legacy_image(
+ repository_ref, "unknown", include_parents=True
+ )
+ assert unknown is None
def test_manifest_labels(registry_model):
- repo = model.repository.get_repository('devtable', 'simple')
- repository_ref = RepositoryReference.for_repo_obj(repo)
- found_tag = registry_model.find_matching_tag(repository_ref, ['latest'])
- found_manifest = registry_model.get_manifest_for_tag(found_tag)
+ repo = model.repository.get_repository("devtable", "simple")
+ repository_ref = RepositoryReference.for_repo_obj(repo)
+ found_tag = registry_model.find_matching_tag(repository_ref, ["latest"])
+ found_manifest = registry_model.get_manifest_for_tag(found_tag)
- # Create a new label.
- created = registry_model.create_manifest_label(found_manifest, 'foo', 'bar', 'api')
- assert created.key == 'foo'
- assert created.value == 'bar'
- assert created.source_type_name == 'api'
- assert created.media_type_name == 'text/plain'
+ # Create a new label.
+ created = registry_model.create_manifest_label(found_manifest, "foo", "bar", "api")
+ assert created.key == "foo"
+ assert created.value == "bar"
+ assert created.source_type_name == "api"
+ assert created.media_type_name == "text/plain"
- # Ensure we can look it up.
- assert registry_model.get_manifest_label(found_manifest, created.uuid) == created
+ # Ensure we can look it up.
+ assert registry_model.get_manifest_label(found_manifest, created.uuid) == created
- # Ensure it is in our list of labels.
- assert created in registry_model.list_manifest_labels(found_manifest)
- assert created in registry_model.list_manifest_labels(found_manifest, key_prefix='fo')
+ # Ensure it is in our list of labels.
+ assert created in registry_model.list_manifest_labels(found_manifest)
+ assert created in registry_model.list_manifest_labels(
+ found_manifest, key_prefix="fo"
+ )
- # Ensure it is *not* in our filtered list.
- assert created not in registry_model.list_manifest_labels(found_manifest, key_prefix='ba')
+ # Ensure it is *not* in our filtered list.
+ assert created not in registry_model.list_manifest_labels(
+ found_manifest, key_prefix="ba"
+ )
- # Delete the label and ensure it is gone.
- assert registry_model.delete_manifest_label(found_manifest, created.uuid)
- assert registry_model.get_manifest_label(found_manifest, created.uuid) is None
- assert created not in registry_model.list_manifest_labels(found_manifest)
+ # Delete the label and ensure it is gone.
+ assert registry_model.delete_manifest_label(found_manifest, created.uuid)
+ assert registry_model.get_manifest_label(found_manifest, created.uuid) is None
+ assert created not in registry_model.list_manifest_labels(found_manifest)
def test_manifest_label_handlers(registry_model):
- repo = model.repository.get_repository('devtable', 'simple')
- repository_ref = RepositoryReference.for_repo_obj(repo)
- found_tag = registry_model.get_repo_tag(repository_ref, 'latest')
- found_manifest = registry_model.get_manifest_for_tag(found_tag)
+ repo = model.repository.get_repository("devtable", "simple")
+ repository_ref = RepositoryReference.for_repo_obj(repo)
+ found_tag = registry_model.get_repo_tag(repository_ref, "latest")
+ found_manifest = registry_model.get_manifest_for_tag(found_tag)
- # Ensure the tag has no expiration.
- assert found_tag.lifetime_end_ts is None
+ # Ensure the tag has no expiration.
+ assert found_tag.lifetime_end_ts is None
- # Create a new label with an expires-after.
- registry_model.create_manifest_label(found_manifest, 'quay.expires-after', '2h', 'api')
+ # Create a new label with an expires-after.
+ registry_model.create_manifest_label(
+ found_manifest, "quay.expires-after", "2h", "api"
+ )
- # Ensure the tag now has an expiration.
- updated_tag = registry_model.get_repo_tag(repository_ref, 'latest')
- assert updated_tag.lifetime_end_ts == (updated_tag.lifetime_start_ts + (60 * 60 * 2))
+ # Ensure the tag now has an expiration.
+ updated_tag = registry_model.get_repo_tag(repository_ref, "latest")
+ assert updated_tag.lifetime_end_ts == (
+ updated_tag.lifetime_start_ts + (60 * 60 * 2)
+ )
def test_batch_labels(registry_model):
- repo = model.repository.get_repository('devtable', 'history')
- repository_ref = RepositoryReference.for_repo_obj(repo)
- found_tag = registry_model.find_matching_tag(repository_ref, ['latest'])
- found_manifest = registry_model.get_manifest_for_tag(found_tag)
+ repo = model.repository.get_repository("devtable", "history")
+ repository_ref = RepositoryReference.for_repo_obj(repo)
+ found_tag = registry_model.find_matching_tag(repository_ref, ["latest"])
+ found_manifest = registry_model.get_manifest_for_tag(found_tag)
- with registry_model.batch_create_manifest_labels(found_manifest) as add_label:
- add_label('foo', '1', 'api')
- add_label('bar', '2', 'api')
- add_label('baz', '3', 'api')
+ with registry_model.batch_create_manifest_labels(found_manifest) as add_label:
+ add_label("foo", "1", "api")
+ add_label("bar", "2", "api")
+ add_label("baz", "3", "api")
- # Ensure we can look them up.
- assert len(registry_model.list_manifest_labels(found_manifest)) == 3
+ # Ensure we can look them up.
+ assert len(registry_model.list_manifest_labels(found_manifest)) == 3
-@pytest.mark.parametrize('repo_namespace, repo_name', [
- ('devtable', 'simple'),
- ('devtable', 'complex'),
- ('devtable', 'history'),
- ('buynlarge', 'orgrepo'),
-])
+@pytest.mark.parametrize(
+ "repo_namespace, repo_name",
+ [
+ ("devtable", "simple"),
+ ("devtable", "complex"),
+ ("devtable", "history"),
+ ("buynlarge", "orgrepo"),
+ ],
+)
def test_repository_tags(repo_namespace, repo_name, registry_model):
- repository_ref = registry_model.lookup_repository(repo_namespace, repo_name)
- tags = registry_model.list_all_active_repository_tags(repository_ref, include_legacy_images=True)
- assert len(tags)
+ repository_ref = registry_model.lookup_repository(repo_namespace, repo_name)
+ tags = registry_model.list_all_active_repository_tags(
+ repository_ref, include_legacy_images=True
+ )
+ assert len(tags)
- tags_map = registry_model.get_legacy_tags_map(repository_ref, storage)
+ tags_map = registry_model.get_legacy_tags_map(repository_ref, storage)
- for tag in tags:
- found_tag = registry_model.get_repo_tag(repository_ref, tag.name, include_legacy_image=True)
- assert found_tag == tag
+ for tag in tags:
+ found_tag = registry_model.get_repo_tag(
+ repository_ref, tag.name, include_legacy_image=True
+ )
+ assert found_tag == tag
- if found_tag.legacy_image is None:
- continue
+ if found_tag.legacy_image is None:
+ continue
- found_image = registry_model.get_legacy_image(repository_ref,
- found_tag.legacy_image.docker_image_id)
- assert found_image == found_tag.legacy_image
- assert tag.name in tags_map
- assert tags_map[tag.name] == found_image.docker_image_id
+ found_image = registry_model.get_legacy_image(
+ repository_ref, found_tag.legacy_image.docker_image_id
+ )
+ assert found_image == found_tag.legacy_image
+ assert tag.name in tags_map
+ assert tags_map[tag.name] == found_image.docker_image_id
-@pytest.mark.parametrize('namespace, name, expected_tag_count, has_expired', [
- ('devtable', 'simple', 2, False),
- ('devtable', 'history', 2, True),
- ('devtable', 'gargantuan', 8, False),
- ('public', 'publicrepo', 1, False),
-])
-def test_repository_tag_history(namespace, name, expected_tag_count, has_expired, registry_model):
- # Pre-cache media type loads to ensure consistent query count.
- Manifest.media_type.get_name(1)
+@pytest.mark.parametrize(
+ "namespace, name, expected_tag_count, has_expired",
+ [
+ ("devtable", "simple", 2, False),
+ ("devtable", "history", 2, True),
+ ("devtable", "gargantuan", 8, False),
+ ("public", "publicrepo", 1, False),
+ ],
+)
+def test_repository_tag_history(
+ namespace, name, expected_tag_count, has_expired, registry_model
+):
+ # Pre-cache media type loads to ensure consistent query count.
+ Manifest.media_type.get_name(1)
- repository_ref = registry_model.lookup_repository(namespace, name)
- with assert_query_count(2):
- history, has_more = registry_model.list_repository_tag_history(repository_ref)
- assert not has_more
- assert len(history) == expected_tag_count
+ repository_ref = registry_model.lookup_repository(namespace, name)
+ with assert_query_count(2):
+ history, has_more = registry_model.list_repository_tag_history(repository_ref)
+ assert not has_more
+ assert len(history) == expected_tag_count
- for tag in history:
- # Retrieve the manifest to ensure it doesn't issue extra queries.
- tag.manifest
+ for tag in history:
+ # Retrieve the manifest to ensure it doesn't issue extra queries.
+ tag.manifest
- if has_expired:
- # Ensure the latest tag is marked expired, since there is an expired one.
- with assert_query_count(1):
- assert registry_model.has_expired_tag(repository_ref, 'latest')
+ if has_expired:
+ # Ensure the latest tag is marked expired, since there is an expired one.
+ with assert_query_count(1):
+ assert registry_model.has_expired_tag(repository_ref, "latest")
-@pytest.mark.parametrize('repositories, expected_tag_count', [
- ([], 0),
- ([('devtable', 'simple'), ('devtable', 'building')], 1),
-])
-def test_get_most_recent_tag_lifetime_start(repositories, expected_tag_count, registry_model):
- last_modified_map = registry_model.get_most_recent_tag_lifetime_start(
- [registry_model.lookup_repository(name, namespace) for name, namespace in repositories]
- )
+@pytest.mark.parametrize(
+ "repositories, expected_tag_count",
+ [([], 0), ([("devtable", "simple"), ("devtable", "building")], 1)],
+)
+def test_get_most_recent_tag_lifetime_start(
+ repositories, expected_tag_count, registry_model
+):
+ last_modified_map = registry_model.get_most_recent_tag_lifetime_start(
+ [
+ registry_model.lookup_repository(name, namespace)
+ for name, namespace in repositories
+ ]
+ )
- assert len(last_modified_map) == expected_tag_count
- for repo_id, last_modified in last_modified_map.items():
- tag = registry_model.get_most_recent_tag(RepositoryReference.for_id(repo_id))
- assert last_modified == tag.lifetime_start_ms / 1000
+ assert len(last_modified_map) == expected_tag_count
+ for repo_id, last_modified in last_modified_map.items():
+ tag = registry_model.get_most_recent_tag(RepositoryReference.for_id(repo_id))
+ assert last_modified == tag.lifetime_start_ms / 1000
-@pytest.mark.parametrize('repo_namespace, repo_name', [
- ('devtable', 'simple'),
- ('devtable', 'complex'),
- ('devtable', 'history'),
- ('buynlarge', 'orgrepo'),
-])
-@pytest.mark.parametrize('via_manifest', [
- False,
- True,
-])
+@pytest.mark.parametrize(
+ "repo_namespace, repo_name",
+ [
+ ("devtable", "simple"),
+ ("devtable", "complex"),
+ ("devtable", "history"),
+ ("buynlarge", "orgrepo"),
+ ],
+)
+@pytest.mark.parametrize("via_manifest", [False, True])
def test_delete_tags(repo_namespace, repo_name, via_manifest, registry_model):
- repository_ref = registry_model.lookup_repository(repo_namespace, repo_name)
- tags = registry_model.list_all_active_repository_tags(repository_ref)
- assert len(tags)
+ repository_ref = registry_model.lookup_repository(repo_namespace, repo_name)
+ tags = registry_model.list_all_active_repository_tags(repository_ref)
+ assert len(tags)
- # Save history before the deletions.
- previous_history, _ = registry_model.list_repository_tag_history(repository_ref, size=1000)
- assert len(previous_history) >= len(tags)
+ # Save history before the deletions.
+ previous_history, _ = registry_model.list_repository_tag_history(
+ repository_ref, size=1000
+ )
+ assert len(previous_history) >= len(tags)
- # Delete every tag in the repository.
- for tag in tags:
- if via_manifest:
- assert registry_model.delete_tag(repository_ref, tag.name)
- else:
- manifest = registry_model.get_manifest_for_tag(tag)
- if manifest is not None:
- assert registry_model.delete_tags_for_manifest(manifest)
+ # Delete every tag in the repository.
+ for tag in tags:
+ if via_manifest:
+ assert registry_model.delete_tag(repository_ref, tag.name)
+ else:
+ manifest = registry_model.get_manifest_for_tag(tag)
+ if manifest is not None:
+ assert registry_model.delete_tags_for_manifest(manifest)
- # Make sure the tag is no longer found.
- # TODO: Uncomment once we're done with the SplitModel.
- #with assert_query_count(1):
- found_tag = registry_model.get_repo_tag(repository_ref, tag.name, include_legacy_image=True)
- assert found_tag is None
+ # Make sure the tag is no longer found.
+ # TODO: Uncomment once we're done with the SplitModel.
+ # with assert_query_count(1):
+ found_tag = registry_model.get_repo_tag(
+ repository_ref, tag.name, include_legacy_image=True
+ )
+ assert found_tag is None
- # Ensure all tags have been deleted.
- tags = registry_model.list_all_active_repository_tags(repository_ref)
- assert not len(tags)
+ # Ensure all tags have been deleted.
+ tags = registry_model.list_all_active_repository_tags(repository_ref)
+ assert not len(tags)
- # Ensure that the tags all live in history.
- history, _ = registry_model.list_repository_tag_history(repository_ref, size=1000)
- assert len(history) == len(previous_history)
+ # Ensure that the tags all live in history.
+ history, _ = registry_model.list_repository_tag_history(repository_ref, size=1000)
+ assert len(history) == len(previous_history)
-@pytest.mark.parametrize('use_manifest', [
- True,
- False,
-])
+@pytest.mark.parametrize("use_manifest", [True, False])
def test_retarget_tag_history(use_manifest, registry_model):
- repository_ref = registry_model.lookup_repository('devtable', 'history')
- history, _ = registry_model.list_repository_tag_history(repository_ref)
+ repository_ref = registry_model.lookup_repository("devtable", "history")
+ history, _ = registry_model.list_repository_tag_history(repository_ref)
- if use_manifest:
- manifest_or_legacy_image = registry_model.lookup_manifest_by_digest(repository_ref,
- history[0].manifest_digest,
- allow_dead=True)
- else:
- manifest_or_legacy_image = history[0].legacy_image
+ if use_manifest:
+ manifest_or_legacy_image = registry_model.lookup_manifest_by_digest(
+ repository_ref, history[0].manifest_digest, allow_dead=True
+ )
+ else:
+ manifest_or_legacy_image = history[0].legacy_image
- # Retarget the tag.
- assert manifest_or_legacy_image
- updated_tag = registry_model.retarget_tag(repository_ref, 'latest', manifest_or_legacy_image,
- storage, docker_v2_signing_key, is_reversion=True)
+ # Retarget the tag.
+ assert manifest_or_legacy_image
+ updated_tag = registry_model.retarget_tag(
+ repository_ref,
+ "latest",
+ manifest_or_legacy_image,
+ storage,
+ docker_v2_signing_key,
+ is_reversion=True,
+ )
- # Ensure the tag has changed targets.
- if use_manifest:
- assert updated_tag.manifest_digest == manifest_or_legacy_image.digest
- else:
- assert updated_tag.legacy_image == manifest_or_legacy_image
+ # Ensure the tag has changed targets.
+ if use_manifest:
+ assert updated_tag.manifest_digest == manifest_or_legacy_image.digest
+ else:
+ assert updated_tag.legacy_image == manifest_or_legacy_image
- # Ensure history has been updated.
- new_history, _ = registry_model.list_repository_tag_history(repository_ref)
- assert len(new_history) == len(history) + 1
+ # Ensure history has been updated.
+ new_history, _ = registry_model.list_repository_tag_history(repository_ref)
+ assert len(new_history) == len(history) + 1
def test_retarget_tag_schema1(oci_model):
- repository_ref = oci_model.lookup_repository('devtable', 'simple')
- latest_tag = oci_model.get_repo_tag(repository_ref, 'latest')
- manifest = oci_model.get_manifest_for_tag(latest_tag)
+ repository_ref = oci_model.lookup_repository("devtable", "simple")
+ latest_tag = oci_model.get_repo_tag(repository_ref, "latest")
+ manifest = oci_model.get_manifest_for_tag(latest_tag)
- existing_parsed = manifest.get_parsed_manifest()
+ existing_parsed = manifest.get_parsed_manifest()
- # Retarget a new tag to the manifest.
- updated_tag = oci_model.retarget_tag(repository_ref, 'somenewtag', manifest, storage,
- docker_v2_signing_key)
- assert updated_tag
- assert updated_tag.name == 'somenewtag'
+ # Retarget a new tag to the manifest.
+ updated_tag = oci_model.retarget_tag(
+ repository_ref, "somenewtag", manifest, storage, docker_v2_signing_key
+ )
+ assert updated_tag
+ assert updated_tag.name == "somenewtag"
- updated_manifest = oci_model.get_manifest_for_tag(updated_tag)
- parsed = updated_manifest.get_parsed_manifest()
- assert parsed.namespace == 'devtable'
- assert parsed.repo_name == 'simple'
- assert parsed.tag == 'somenewtag'
+ updated_manifest = oci_model.get_manifest_for_tag(updated_tag)
+ parsed = updated_manifest.get_parsed_manifest()
+ assert parsed.namespace == "devtable"
+ assert parsed.repo_name == "simple"
+ assert parsed.tag == "somenewtag"
- assert parsed.layers == existing_parsed.layers
+ assert parsed.layers == existing_parsed.layers
- # Ensure the tag has changed targets.
- assert oci_model.get_repo_tag(repository_ref, 'somenewtag') == updated_tag
+ # Ensure the tag has changed targets.
+ assert oci_model.get_repo_tag(repository_ref, "somenewtag") == updated_tag
def test_change_repository_tag_expiration(registry_model):
- repository_ref = registry_model.lookup_repository('devtable', 'simple')
- tag = registry_model.get_repo_tag(repository_ref, 'latest')
- assert tag.lifetime_end_ts is None
+ repository_ref = registry_model.lookup_repository("devtable", "simple")
+ tag = registry_model.get_repo_tag(repository_ref, "latest")
+ assert tag.lifetime_end_ts is None
- new_datetime = datetime.utcnow() + timedelta(days=2)
- previous, okay = registry_model.change_repository_tag_expiration(tag, new_datetime)
+ new_datetime = datetime.utcnow() + timedelta(days=2)
+ previous, okay = registry_model.change_repository_tag_expiration(tag, new_datetime)
- assert okay
- assert previous is None
+ assert okay
+ assert previous is None
- tag = registry_model.get_repo_tag(repository_ref, 'latest')
- assert tag.lifetime_end_ts is not None
+ tag = registry_model.get_repo_tag(repository_ref, "latest")
+ assert tag.lifetime_end_ts is not None
-@pytest.mark.parametrize('repo_namespace, repo_name, expected_non_empty', [
- ('devtable', 'simple', []),
- ('devtable', 'complex', ['prod', 'v2.0']),
- ('devtable', 'history', ['latest']),
- ('buynlarge', 'orgrepo', []),
- ('devtable', 'gargantuan', ['v2.0', 'v3.0', 'v4.0', 'v5.0', 'v6.0']),
-])
-def test_get_legacy_images_owned_by_tag(repo_namespace, repo_name, expected_non_empty,
- registry_model):
- repository_ref = registry_model.lookup_repository(repo_namespace, repo_name)
- tags = registry_model.list_all_active_repository_tags(repository_ref)
- assert len(tags)
+@pytest.mark.parametrize(
+ "repo_namespace, repo_name, expected_non_empty",
+ [
+ ("devtable", "simple", []),
+ ("devtable", "complex", ["prod", "v2.0"]),
+ ("devtable", "history", ["latest"]),
+ ("buynlarge", "orgrepo", []),
+ ("devtable", "gargantuan", ["v2.0", "v3.0", "v4.0", "v5.0", "v6.0"]),
+ ],
+)
+def test_get_legacy_images_owned_by_tag(
+ repo_namespace, repo_name, expected_non_empty, registry_model
+):
+ repository_ref = registry_model.lookup_repository(repo_namespace, repo_name)
+ tags = registry_model.list_all_active_repository_tags(repository_ref)
+ assert len(tags)
- non_empty = set()
- for tag in tags:
- if registry_model.get_legacy_images_owned_by_tag(tag):
- non_empty.add(tag.name)
+ non_empty = set()
+ for tag in tags:
+ if registry_model.get_legacy_images_owned_by_tag(tag):
+ non_empty.add(tag.name)
- assert non_empty == set(expected_non_empty)
+ assert non_empty == set(expected_non_empty)
def test_get_security_status(registry_model):
- repository_ref = registry_model.lookup_repository('devtable', 'simple')
- tags = registry_model.list_all_active_repository_tags(repository_ref, include_legacy_images=True)
- assert len(tags)
+ repository_ref = registry_model.lookup_repository("devtable", "simple")
+ tags = registry_model.list_all_active_repository_tags(
+ repository_ref, include_legacy_images=True
+ )
+ assert len(tags)
- for tag in tags:
- assert registry_model.get_security_status(tag.legacy_image)
- registry_model.reset_security_status(tag.legacy_image)
- assert registry_model.get_security_status(tag.legacy_image)
+ for tag in tags:
+ assert registry_model.get_security_status(tag.legacy_image)
+ registry_model.reset_security_status(tag.legacy_image)
+ assert registry_model.get_security_status(tag.legacy_image)
@pytest.fixture()
def clear_rows(initialized_db):
- # Remove all new-style rows so we can backfill.
- TagToRepositoryTag.delete().execute()
- Tag.delete().execute()
- TagManifestLabelMap.delete().execute()
- ManifestLabel.delete().execute()
- ManifestBlob.delete().execute()
- ManifestLegacyImage.delete().execute()
- TagManifestToManifest.delete().execute()
- Manifest.delete().execute()
- TagManifestLabel.delete().execute()
- TagManifest.delete().execute()
+ # Remove all new-style rows so we can backfill.
+ TagToRepositoryTag.delete().execute()
+ Tag.delete().execute()
+ TagManifestLabelMap.delete().execute()
+ ManifestLabel.delete().execute()
+ ManifestBlob.delete().execute()
+ ManifestLegacyImage.delete().execute()
+ TagManifestToManifest.delete().execute()
+ Manifest.delete().execute()
+ TagManifestLabel.delete().execute()
+ TagManifest.delete().execute()
-@pytest.mark.parametrize('repo_namespace, repo_name', [
- ('devtable', 'simple'),
- ('devtable', 'complex'),
- ('devtable', 'history'),
- ('buynlarge', 'orgrepo'),
-])
-def test_backfill_manifest_for_tag(repo_namespace, repo_name, clear_rows, pre_oci_model):
- repository_ref = pre_oci_model.lookup_repository(repo_namespace, repo_name)
- tags, has_more = pre_oci_model.list_repository_tag_history(repository_ref, size=2500)
- assert tags
- assert not has_more
+@pytest.mark.parametrize(
+ "repo_namespace, repo_name",
+ [
+ ("devtable", "simple"),
+ ("devtable", "complex"),
+ ("devtable", "history"),
+ ("buynlarge", "orgrepo"),
+ ],
+)
+def test_backfill_manifest_for_tag(
+ repo_namespace, repo_name, clear_rows, pre_oci_model
+):
+ repository_ref = pre_oci_model.lookup_repository(repo_namespace, repo_name)
+ tags, has_more = pre_oci_model.list_repository_tag_history(
+ repository_ref, size=2500
+ )
+ assert tags
+ assert not has_more
- for tag in tags:
- assert not tag.manifest_digest
- assert pre_oci_model.backfill_manifest_for_tag(tag)
+ for tag in tags:
+ assert not tag.manifest_digest
+ assert pre_oci_model.backfill_manifest_for_tag(tag)
- tags, _ = pre_oci_model.list_repository_tag_history(repository_ref)
- assert tags
- for tag in tags:
- assert tag.manifest_digest
+ tags, _ = pre_oci_model.list_repository_tag_history(repository_ref)
+ assert tags
+ for tag in tags:
+ assert tag.manifest_digest
- manifest = pre_oci_model.get_manifest_for_tag(tag)
- assert manifest
+ manifest = pre_oci_model.get_manifest_for_tag(tag)
+ assert manifest
- legacy_image = pre_oci_model.get_legacy_image(repository_ref, tag.legacy_image.docker_image_id,
- include_parents=True)
+ legacy_image = pre_oci_model.get_legacy_image(
+ repository_ref, tag.legacy_image.docker_image_id, include_parents=True
+ )
- parsed_manifest = manifest.get_parsed_manifest()
- assert parsed_manifest.leaf_layer_v1_image_id == legacy_image.docker_image_id
- assert parsed_manifest.parent_image_ids == {p.docker_image_id for p in legacy_image.parents}
+ parsed_manifest = manifest.get_parsed_manifest()
+ assert parsed_manifest.leaf_layer_v1_image_id == legacy_image.docker_image_id
+ assert parsed_manifest.parent_image_ids == {
+ p.docker_image_id for p in legacy_image.parents
+ }
-@pytest.mark.parametrize('repo_namespace, repo_name', [
- ('devtable', 'simple'),
- ('devtable', 'complex'),
- ('devtable', 'history'),
- ('buynlarge', 'orgrepo'),
-])
-def test_backfill_manifest_on_lookup(repo_namespace, repo_name, clear_rows, pre_oci_model):
- repository_ref = pre_oci_model.lookup_repository(repo_namespace, repo_name)
- tags = pre_oci_model.list_all_active_repository_tags(repository_ref)
- assert tags
+@pytest.mark.parametrize(
+ "repo_namespace, repo_name",
+ [
+ ("devtable", "simple"),
+ ("devtable", "complex"),
+ ("devtable", "history"),
+ ("buynlarge", "orgrepo"),
+ ],
+)
+def test_backfill_manifest_on_lookup(
+ repo_namespace, repo_name, clear_rows, pre_oci_model
+):
+ repository_ref = pre_oci_model.lookup_repository(repo_namespace, repo_name)
+ tags = pre_oci_model.list_all_active_repository_tags(repository_ref)
+ assert tags
- for tag in tags:
- assert not tag.manifest_digest
- assert not pre_oci_model.get_manifest_for_tag(tag)
+ for tag in tags:
+ assert not tag.manifest_digest
+ assert not pre_oci_model.get_manifest_for_tag(tag)
- manifest = pre_oci_model.get_manifest_for_tag(tag, backfill_if_necessary=True)
- assert manifest
+ manifest = pre_oci_model.get_manifest_for_tag(tag, backfill_if_necessary=True)
+ assert manifest
- updated_tag = pre_oci_model.get_repo_tag(repository_ref, tag.name)
- assert updated_tag.manifest_digest == manifest.digest
+ updated_tag = pre_oci_model.get_repo_tag(repository_ref, tag.name)
+ assert updated_tag.manifest_digest == manifest.digest
-@pytest.mark.parametrize('namespace, expect_enabled', [
- ('devtable', True),
- ('buynlarge', True),
-
- ('disabled', False),
-])
+@pytest.mark.parametrize(
+ "namespace, expect_enabled",
+ [("devtable", True), ("buynlarge", True), ("disabled", False)],
+)
def test_is_namespace_enabled(namespace, expect_enabled, registry_model):
- assert registry_model.is_namespace_enabled(namespace) == expect_enabled
+ assert registry_model.is_namespace_enabled(namespace) == expect_enabled
-@pytest.mark.parametrize('repo_namespace, repo_name', [
- ('devtable', 'simple'),
- ('devtable', 'complex'),
- ('devtable', 'history'),
- ('buynlarge', 'orgrepo'),
-])
+@pytest.mark.parametrize(
+ "repo_namespace, repo_name",
+ [
+ ("devtable", "simple"),
+ ("devtable", "complex"),
+ ("devtable", "history"),
+ ("buynlarge", "orgrepo"),
+ ],
+)
def test_layers_and_blobs(repo_namespace, repo_name, registry_model):
- repository_ref = registry_model.lookup_repository(repo_namespace, repo_name)
- tags = registry_model.list_all_active_repository_tags(repository_ref)
- assert tags
+ repository_ref = registry_model.lookup_repository(repo_namespace, repo_name)
+ tags = registry_model.list_all_active_repository_tags(repository_ref)
+ assert tags
- for tag in tags:
- manifest = registry_model.get_manifest_for_tag(tag)
- assert manifest
+ for tag in tags:
+ manifest = registry_model.get_manifest_for_tag(tag)
+ assert manifest
- parsed = manifest.get_parsed_manifest()
- assert parsed
+ parsed = manifest.get_parsed_manifest()
+ assert parsed
- layers = registry_model.list_parsed_manifest_layers(repository_ref, parsed, storage)
- assert layers
+ layers = registry_model.list_parsed_manifest_layers(
+ repository_ref, parsed, storage
+ )
+ assert layers
- layers = registry_model.list_parsed_manifest_layers(repository_ref, parsed, storage,
- include_placements=True)
- assert layers
+ layers = registry_model.list_parsed_manifest_layers(
+ repository_ref, parsed, storage, include_placements=True
+ )
+ assert layers
- for index, manifest_layer in enumerate(layers):
- assert manifest_layer.blob.storage_path
- assert manifest_layer.blob.placements
+ for index, manifest_layer in enumerate(layers):
+ assert manifest_layer.blob.storage_path
+ assert manifest_layer.blob.placements
- repo_blob = registry_model.get_repo_blob_by_digest(repository_ref, manifest_layer.blob.digest)
- assert repo_blob.digest == manifest_layer.blob.digest
+ repo_blob = registry_model.get_repo_blob_by_digest(
+ repository_ref, manifest_layer.blob.digest
+ )
+ assert repo_blob.digest == manifest_layer.blob.digest
- assert manifest_layer.estimated_size(1) is not None
- assert isinstance(manifest_layer.layer_info, ManifestImageLayer)
+ assert manifest_layer.estimated_size(1) is not None
+ assert isinstance(manifest_layer.layer_info, ManifestImageLayer)
- blobs = registry_model.get_manifest_local_blobs(manifest, include_placements=True)
- assert {b.digest for b in blobs} == set(parsed.local_blob_digests)
+ blobs = registry_model.get_manifest_local_blobs(
+ manifest, include_placements=True
+ )
+ assert {b.digest for b in blobs} == set(parsed.local_blob_digests)
def test_manifest_remote_layers(oci_model):
- # Create a config blob for testing.
- config_json = json.dumps({
- 'config': {},
- "rootfs": {
- "type": "layers",
- "diff_ids": []
- },
- "history": [
- {
- "created": "2018-04-03T18:37:09.284840891Z",
- "created_by": "do something",
- },
- ],
- })
+ # Create a config blob for testing.
+ config_json = json.dumps(
+ {
+ "config": {},
+ "rootfs": {"type": "layers", "diff_ids": []},
+ "history": [
+ {
+ "created": "2018-04-03T18:37:09.284840891Z",
+ "created_by": "do something",
+ }
+ ],
+ }
+ )
- app_config = {'TESTING': True}
- repository_ref = oci_model.lookup_repository('devtable', 'simple')
- with upload_blob(repository_ref, storage, BlobUploadSettings(500, 500, 500)) as upload:
- upload.upload_chunk(app_config, BytesIO(config_json))
- blob = upload.commit_to_blob(app_config)
+ app_config = {"TESTING": True}
+ repository_ref = oci_model.lookup_repository("devtable", "simple")
+ with upload_blob(
+ repository_ref, storage, BlobUploadSettings(500, 500, 500)
+ ) as upload:
+ upload.upload_chunk(app_config, BytesIO(config_json))
+ blob = upload.commit_to_blob(app_config)
- # Create the manifest in the repo.
- builder = DockerSchema2ManifestBuilder()
- builder.set_config_digest(blob.digest, blob.compressed_size)
- builder.add_layer('sha256:abcd', 1234, urls=['http://hello/world'])
- manifest = builder.build()
+ # Create the manifest in the repo.
+ builder = DockerSchema2ManifestBuilder()
+ builder.set_config_digest(blob.digest, blob.compressed_size)
+ builder.add_layer("sha256:abcd", 1234, urls=["http://hello/world"])
+ manifest = builder.build()
- created_manifest, _ = oci_model.create_manifest_and_retarget_tag(repository_ref, manifest,
- 'sometag', storage)
- assert created_manifest
+ created_manifest, _ = oci_model.create_manifest_and_retarget_tag(
+ repository_ref, manifest, "sometag", storage
+ )
+ assert created_manifest
- layers = oci_model.list_parsed_manifest_layers(repository_ref,
- created_manifest.get_parsed_manifest(),
- storage)
- assert len(layers) == 1
- assert layers[0].layer_info.is_remote
- assert layers[0].layer_info.urls == ['http://hello/world']
- assert layers[0].blob is None
+ layers = oci_model.list_parsed_manifest_layers(
+ repository_ref, created_manifest.get_parsed_manifest(), storage
+ )
+ assert len(layers) == 1
+ assert layers[0].layer_info.is_remote
+ assert layers[0].layer_info.urls == ["http://hello/world"]
+ assert layers[0].blob is None
def test_derived_image(registry_model):
- # Clear all existing derived storage.
- DerivedStorageForImage.delete().execute()
+ # Clear all existing derived storage.
+ DerivedStorageForImage.delete().execute()
- repository_ref = registry_model.lookup_repository('devtable', 'simple')
- tag = registry_model.get_repo_tag(repository_ref, 'latest')
- manifest = registry_model.get_manifest_for_tag(tag)
+ repository_ref = registry_model.lookup_repository("devtable", "simple")
+ tag = registry_model.get_repo_tag(repository_ref, "latest")
+ manifest = registry_model.get_manifest_for_tag(tag)
- # Ensure the squashed image doesn't exist.
- assert registry_model.lookup_derived_image(manifest, 'squash', storage, {}) is None
+ # Ensure the squashed image doesn't exist.
+ assert registry_model.lookup_derived_image(manifest, "squash", storage, {}) is None
- # Create a new one.
- squashed = registry_model.lookup_or_create_derived_image(manifest, 'squash',
- 'local_us', storage, {})
- assert registry_model.lookup_or_create_derived_image(manifest, 'squash',
- 'local_us', storage, {}) == squashed
- assert squashed.unique_id
+ # Create a new one.
+ squashed = registry_model.lookup_or_create_derived_image(
+ manifest, "squash", "local_us", storage, {}
+ )
+ assert (
+ registry_model.lookup_or_create_derived_image(
+ manifest, "squash", "local_us", storage, {}
+ )
+ == squashed
+ )
+ assert squashed.unique_id
- # Check and set the size.
- assert squashed.blob.compressed_size is None
- registry_model.set_derived_image_size(squashed, 1234)
+ # Check and set the size.
+ assert squashed.blob.compressed_size is None
+ registry_model.set_derived_image_size(squashed, 1234)
- found = registry_model.lookup_derived_image(manifest, 'squash', storage, {})
- assert found.blob.compressed_size == 1234
- assert found.unique_id == squashed.unique_id
+ found = registry_model.lookup_derived_image(manifest, "squash", storage, {})
+ assert found.blob.compressed_size == 1234
+ assert found.unique_id == squashed.unique_id
- # Ensure its returned now.
- assert found == squashed
+    # Ensure it's returned now.
+ assert found == squashed
- # Ensure different metadata results in a different derived image.
- found = registry_model.lookup_derived_image(manifest, 'squash', storage, {'foo': 'bar'})
- assert found is None
+ # Ensure different metadata results in a different derived image.
+ found = registry_model.lookup_derived_image(
+ manifest, "squash", storage, {"foo": "bar"}
+ )
+ assert found is None
- squashed_foo = registry_model.lookup_or_create_derived_image(manifest, 'squash', 'local_us',
- storage, {'foo': 'bar'})
- assert squashed_foo != squashed
+ squashed_foo = registry_model.lookup_or_create_derived_image(
+ manifest, "squash", "local_us", storage, {"foo": "bar"}
+ )
+ assert squashed_foo != squashed
- found = registry_model.lookup_derived_image(manifest, 'squash', storage, {'foo': 'bar'})
- assert found == squashed_foo
+ found = registry_model.lookup_derived_image(
+ manifest, "squash", storage, {"foo": "bar"}
+ )
+ assert found == squashed_foo
- assert squashed.unique_id != squashed_foo.unique_id
+ assert squashed.unique_id != squashed_foo.unique_id
- # Lookup with placements.
- squashed = registry_model.lookup_or_create_derived_image(manifest, 'squash', 'local_us',
- storage, {}, include_placements=True)
- assert squashed.blob.placements
+ # Lookup with placements.
+ squashed = registry_model.lookup_or_create_derived_image(
+ manifest, "squash", "local_us", storage, {}, include_placements=True
+ )
+ assert squashed.blob.placements
- # Delete the derived image.
- registry_model.delete_derived_image(squashed)
- assert registry_model.lookup_derived_image(manifest, 'squash', storage, {}) is None
+ # Delete the derived image.
+ registry_model.delete_derived_image(squashed)
+ assert registry_model.lookup_derived_image(manifest, "squash", storage, {}) is None
def test_derived_image_signatures(registry_model):
- repository_ref = registry_model.lookup_repository('devtable', 'simple')
- tag = registry_model.get_repo_tag(repository_ref, 'latest')
- manifest = registry_model.get_manifest_for_tag(tag)
+ repository_ref = registry_model.lookup_repository("devtable", "simple")
+ tag = registry_model.get_repo_tag(repository_ref, "latest")
+ manifest = registry_model.get_manifest_for_tag(tag)
- derived = registry_model.lookup_derived_image(manifest, 'squash', storage, {})
- assert derived
+ derived = registry_model.lookup_derived_image(manifest, "squash", storage, {})
+ assert derived
- signature = registry_model.get_derived_image_signature(derived, 'gpg2')
- assert signature is None
+ signature = registry_model.get_derived_image_signature(derived, "gpg2")
+ assert signature is None
- registry_model.set_derived_image_signature(derived, 'gpg2', 'foo')
- assert registry_model.get_derived_image_signature(derived, 'gpg2') == 'foo'
+ registry_model.set_derived_image_signature(derived, "gpg2", "foo")
+ assert registry_model.get_derived_image_signature(derived, "gpg2") == "foo"
def test_derived_image_for_manifest_list(oci_model):
- # Clear all existing derived storage.
- DerivedStorageForImage.delete().execute()
+ # Clear all existing derived storage.
+ DerivedStorageForImage.delete().execute()
- # Create a config blob for testing.
- config_json = json.dumps({
- 'config': {},
- "rootfs": {
- "type": "layers",
- "diff_ids": []
- },
- "history": [
- {
- "created": "2018-04-03T18:37:09.284840891Z",
- "created_by": "do something",
- },
- ],
- })
+ # Create a config blob for testing.
+ config_json = json.dumps(
+ {
+ "config": {},
+ "rootfs": {"type": "layers", "diff_ids": []},
+ "history": [
+ {
+ "created": "2018-04-03T18:37:09.284840891Z",
+ "created_by": "do something",
+ }
+ ],
+ }
+ )
- app_config = {'TESTING': True}
- repository_ref = oci_model.lookup_repository('devtable', 'simple')
- with upload_blob(repository_ref, storage, BlobUploadSettings(500, 500, 500)) as upload:
- upload.upload_chunk(app_config, BytesIO(config_json))
- blob = upload.commit_to_blob(app_config)
+ app_config = {"TESTING": True}
+ repository_ref = oci_model.lookup_repository("devtable", "simple")
+ with upload_blob(
+ repository_ref, storage, BlobUploadSettings(500, 500, 500)
+ ) as upload:
+ upload.upload_chunk(app_config, BytesIO(config_json))
+ blob = upload.commit_to_blob(app_config)
- # Create the manifest in the repo.
- builder = DockerSchema2ManifestBuilder()
- builder.set_config_digest(blob.digest, blob.compressed_size)
- builder.add_layer(blob.digest, blob.compressed_size)
- amd64_manifest = builder.build()
+ # Create the manifest in the repo.
+ builder = DockerSchema2ManifestBuilder()
+ builder.set_config_digest(blob.digest, blob.compressed_size)
+ builder.add_layer(blob.digest, blob.compressed_size)
+ amd64_manifest = builder.build()
- oci_model.create_manifest_and_retarget_tag(repository_ref, amd64_manifest, 'submanifest', storage)
+ oci_model.create_manifest_and_retarget_tag(
+ repository_ref, amd64_manifest, "submanifest", storage
+ )
- # Create a manifest list, pointing to at least one amd64+linux manifest.
- builder = DockerSchema2ManifestListBuilder()
- builder.add_manifest(amd64_manifest, 'amd64', 'linux')
- manifestlist = builder.build()
+ # Create a manifest list, pointing to at least one amd64+linux manifest.
+ builder = DockerSchema2ManifestListBuilder()
+ builder.add_manifest(amd64_manifest, "amd64", "linux")
+ manifestlist = builder.build()
- oci_model.create_manifest_and_retarget_tag(repository_ref, manifestlist, 'listtag', storage)
- manifest = oci_model.get_manifest_for_tag(oci_model.get_repo_tag(repository_ref, 'listtag'))
- assert manifest
- assert manifest.get_parsed_manifest().is_manifest_list
+ oci_model.create_manifest_and_retarget_tag(
+ repository_ref, manifestlist, "listtag", storage
+ )
+ manifest = oci_model.get_manifest_for_tag(
+ oci_model.get_repo_tag(repository_ref, "listtag")
+ )
+ assert manifest
+ assert manifest.get_parsed_manifest().is_manifest_list
- # Ensure the squashed image doesn't exist.
- assert oci_model.lookup_derived_image(manifest, 'squash', storage, {}) is None
+ # Ensure the squashed image doesn't exist.
+ assert oci_model.lookup_derived_image(manifest, "squash", storage, {}) is None
- # Create a new one.
- squashed = oci_model.lookup_or_create_derived_image(manifest, 'squash', 'local_us', storage, {})
- assert squashed.unique_id
- assert oci_model.lookup_or_create_derived_image(manifest, 'squash',
- 'local_us', storage, {}) == squashed
+ # Create a new one.
+ squashed = oci_model.lookup_or_create_derived_image(
+ manifest, "squash", "local_us", storage, {}
+ )
+ assert squashed.unique_id
+ assert (
+ oci_model.lookup_or_create_derived_image(
+ manifest, "squash", "local_us", storage, {}
+ )
+ == squashed
+ )
- # Perform lookup.
- assert oci_model.lookup_derived_image(manifest, 'squash', storage, {}) == squashed
+ # Perform lookup.
+ assert oci_model.lookup_derived_image(manifest, "squash", storage, {}) == squashed
def test_torrent_info(registry_model):
- # Remove all existing info.
- TorrentInfo.delete().execute()
+ # Remove all existing info.
+ TorrentInfo.delete().execute()
- repository_ref = registry_model.lookup_repository('devtable', 'simple')
- tag = registry_model.get_repo_tag(repository_ref, 'latest')
- manifest = registry_model.get_manifest_for_tag(tag)
+ repository_ref = registry_model.lookup_repository("devtable", "simple")
+ tag = registry_model.get_repo_tag(repository_ref, "latest")
+ manifest = registry_model.get_manifest_for_tag(tag)
- blobs = registry_model.get_manifest_local_blobs(manifest)
- assert blobs
+ blobs = registry_model.get_manifest_local_blobs(manifest)
+ assert blobs
- assert registry_model.get_torrent_info(blobs[0]) is None
- registry_model.set_torrent_info(blobs[0], 2, 'foo')
+ assert registry_model.get_torrent_info(blobs[0]) is None
+ registry_model.set_torrent_info(blobs[0], 2, "foo")
- # Set it again exactly, which should be a no-op.
- registry_model.set_torrent_info(blobs[0], 2, 'foo')
+ # Set it again exactly, which should be a no-op.
+ registry_model.set_torrent_info(blobs[0], 2, "foo")
- # Check the information we've set.
- torrent_info = registry_model.get_torrent_info(blobs[0])
- assert torrent_info is not None
- assert torrent_info.piece_length == 2
- assert torrent_info.pieces == 'foo'
+ # Check the information we've set.
+ torrent_info = registry_model.get_torrent_info(blobs[0])
+ assert torrent_info is not None
+ assert torrent_info.piece_length == 2
+ assert torrent_info.pieces == "foo"
- # Try setting it again. Nothing should happen.
- registry_model.set_torrent_info(blobs[0], 3, 'bar')
+ # Try setting it again. Nothing should happen.
+ registry_model.set_torrent_info(blobs[0], 3, "bar")
- torrent_info = registry_model.get_torrent_info(blobs[0])
- assert torrent_info is not None
- assert torrent_info.piece_length == 2
- assert torrent_info.pieces == 'foo'
+ torrent_info = registry_model.get_torrent_info(blobs[0])
+ assert torrent_info is not None
+ assert torrent_info.piece_length == 2
+ assert torrent_info.pieces == "foo"
def test_blob_uploads(registry_model):
- repository_ref = registry_model.lookup_repository('devtable', 'simple')
+ repository_ref = registry_model.lookup_repository("devtable", "simple")
- blob_upload = registry_model.create_blob_upload(repository_ref, str(uuid.uuid4()),
- 'local_us', {'some': 'metadata'})
- assert blob_upload
- assert blob_upload.storage_metadata == {'some': 'metadata'}
- assert blob_upload.location_name == 'local_us'
+ blob_upload = registry_model.create_blob_upload(
+ repository_ref, str(uuid.uuid4()), "local_us", {"some": "metadata"}
+ )
+ assert blob_upload
+ assert blob_upload.storage_metadata == {"some": "metadata"}
+ assert blob_upload.location_name == "local_us"
- # Ensure we can find the blob upload.
- assert registry_model.lookup_blob_upload(repository_ref, blob_upload.upload_id) == blob_upload
+ # Ensure we can find the blob upload.
+ assert (
+ registry_model.lookup_blob_upload(repository_ref, blob_upload.upload_id)
+ == blob_upload
+ )
- # Update and ensure the changes are saved.
- assert registry_model.update_blob_upload(blob_upload, 1, 'the-pieces_hash',
- blob_upload.piece_sha_state,
- {'new': 'metadata'}, 2, 3,
- blob_upload.sha_state)
+ # Update and ensure the changes are saved.
+ assert registry_model.update_blob_upload(
+ blob_upload,
+ 1,
+ "the-pieces_hash",
+ blob_upload.piece_sha_state,
+ {"new": "metadata"},
+ 2,
+ 3,
+ blob_upload.sha_state,
+ )
- updated = registry_model.lookup_blob_upload(repository_ref, blob_upload.upload_id)
- assert updated
- assert updated.uncompressed_byte_count == 1
- assert updated.piece_hashes == 'the-pieces_hash'
- assert updated.storage_metadata == {'new': 'metadata'}
- assert updated.byte_count == 2
- assert updated.chunk_count == 3
+ updated = registry_model.lookup_blob_upload(repository_ref, blob_upload.upload_id)
+ assert updated
+ assert updated.uncompressed_byte_count == 1
+ assert updated.piece_hashes == "the-pieces_hash"
+ assert updated.storage_metadata == {"new": "metadata"}
+ assert updated.byte_count == 2
+ assert updated.chunk_count == 3
- # Delete the upload.
- registry_model.delete_blob_upload(blob_upload)
+ # Delete the upload.
+ registry_model.delete_blob_upload(blob_upload)
- # Ensure it can no longer be found.
- assert not registry_model.lookup_blob_upload(repository_ref, blob_upload.upload_id)
+ # Ensure it can no longer be found.
+ assert not registry_model.lookup_blob_upload(repository_ref, blob_upload.upload_id)
def test_commit_blob_upload(registry_model):
- repository_ref = registry_model.lookup_repository('devtable', 'simple')
- blob_upload = registry_model.create_blob_upload(repository_ref, str(uuid.uuid4()),
- 'local_us', {'some': 'metadata'})
+ repository_ref = registry_model.lookup_repository("devtable", "simple")
+ blob_upload = registry_model.create_blob_upload(
+ repository_ref, str(uuid.uuid4()), "local_us", {"some": "metadata"}
+ )
- # Commit the blob upload and make sure it is written as a blob.
- digest = 'sha256:' + hashlib.sha256('hello').hexdigest()
- blob = registry_model.commit_blob_upload(blob_upload, digest, 60)
- assert blob.digest == digest
+ # Commit the blob upload and make sure it is written as a blob.
+ digest = "sha256:" + hashlib.sha256("hello").hexdigest()
+ blob = registry_model.commit_blob_upload(blob_upload, digest, 60)
+ assert blob.digest == digest
- # Ensure the upload can no longer be found.
- assert not registry_model.lookup_blob_upload(repository_ref, blob_upload.upload_id)
+ # Ensure the upload can no longer be found.
+ assert not registry_model.lookup_blob_upload(repository_ref, blob_upload.upload_id)
# TODO: Re-enable for OCI model once we have a new table for temporary blobs.
def test_mount_blob_into_repository(pre_oci_model):
- repository_ref = pre_oci_model.lookup_repository('devtable', 'simple')
- latest_tag = pre_oci_model.get_repo_tag(repository_ref, 'latest')
- manifest = pre_oci_model.get_manifest_for_tag(latest_tag)
+ repository_ref = pre_oci_model.lookup_repository("devtable", "simple")
+ latest_tag = pre_oci_model.get_repo_tag(repository_ref, "latest")
+ manifest = pre_oci_model.get_manifest_for_tag(latest_tag)
- target_repository_ref = pre_oci_model.lookup_repository('devtable', 'complex')
+ target_repository_ref = pre_oci_model.lookup_repository("devtable", "complex")
- blobs = pre_oci_model.get_manifest_local_blobs(manifest, include_placements=True)
- assert blobs
+ blobs = pre_oci_model.get_manifest_local_blobs(manifest, include_placements=True)
+ assert blobs
- for blob in blobs:
- # Ensure the blob doesn't exist under the repository.
- assert not pre_oci_model.get_repo_blob_by_digest(target_repository_ref, blob.digest)
+ for blob in blobs:
+ # Ensure the blob doesn't exist under the repository.
+ assert not pre_oci_model.get_repo_blob_by_digest(
+ target_repository_ref, blob.digest
+ )
- # Mount the blob into the repository.
- assert pre_oci_model.mount_blob_into_repository(blob, target_repository_ref, 60)
+ # Mount the blob into the repository.
+ assert pre_oci_model.mount_blob_into_repository(blob, target_repository_ref, 60)
- # Ensure it now exists.
- found = pre_oci_model.get_repo_blob_by_digest(target_repository_ref, blob.digest)
- assert found == blob
+ # Ensure it now exists.
+ found = pre_oci_model.get_repo_blob_by_digest(
+ target_repository_ref, blob.digest
+ )
+ assert found == blob
class SomeException(Exception):
- pass
+ pass
def test_get_cached_repo_blob(registry_model):
- model_cache = InMemoryDataModelCache()
+ model_cache = InMemoryDataModelCache()
- repository_ref = registry_model.lookup_repository('devtable', 'simple')
- latest_tag = registry_model.get_repo_tag(repository_ref, 'latest')
- manifest = registry_model.get_manifest_for_tag(latest_tag)
+ repository_ref = registry_model.lookup_repository("devtable", "simple")
+ latest_tag = registry_model.get_repo_tag(repository_ref, "latest")
+ manifest = registry_model.get_manifest_for_tag(latest_tag)
- blobs = registry_model.get_manifest_local_blobs(manifest, include_placements=True)
- assert blobs
+ blobs = registry_model.get_manifest_local_blobs(manifest, include_placements=True)
+ assert blobs
- blob = blobs[0]
+ blob = blobs[0]
- # Load a blob to add it to the cache.
- found = registry_model.get_cached_repo_blob(model_cache, 'devtable', 'simple', blob.digest)
- assert found.digest == blob.digest
- assert found.uuid == blob.uuid
- assert found.compressed_size == blob.compressed_size
- assert found.uncompressed_size == blob.uncompressed_size
- assert found.uploading == blob.uploading
- assert found.placements == blob.placements
+ # Load a blob to add it to the cache.
+ found = registry_model.get_cached_repo_blob(
+ model_cache, "devtable", "simple", blob.digest
+ )
+ assert found.digest == blob.digest
+ assert found.uuid == blob.uuid
+ assert found.compressed_size == blob.compressed_size
+ assert found.uncompressed_size == blob.uncompressed_size
+ assert found.uploading == blob.uploading
+ assert found.placements == blob.placements
- # Disconnect from the database by overwriting the connection.
- def fail(x, y):
- raise SomeException('Not connected!')
+    # Simulate a database disconnect by making the blob lookups raise.
+ def fail(x, y):
+ raise SomeException("Not connected!")
- with patch('data.registry_model.registry_pre_oci_model.model.blob.get_repository_blob_by_digest',
- fail):
- with patch('data.registry_model.registry_oci_model.model.oci.blob.get_repository_blob_by_digest',
- fail):
- # Make sure we can load again, which should hit the cache.
- cached = registry_model.get_cached_repo_blob(model_cache, 'devtable', 'simple', blob.digest)
- assert cached.digest == blob.digest
- assert cached.uuid == blob.uuid
- assert cached.compressed_size == blob.compressed_size
- assert cached.uncompressed_size == blob.uncompressed_size
- assert cached.uploading == blob.uploading
- assert cached.placements == blob.placements
+ with patch(
+ "data.registry_model.registry_pre_oci_model.model.blob.get_repository_blob_by_digest",
+ fail,
+ ):
+ with patch(
+ "data.registry_model.registry_oci_model.model.oci.blob.get_repository_blob_by_digest",
+ fail,
+ ):
+ # Make sure we can load again, which should hit the cache.
+ cached = registry_model.get_cached_repo_blob(
+ model_cache, "devtable", "simple", blob.digest
+ )
+ assert cached.digest == blob.digest
+ assert cached.uuid == blob.uuid
+ assert cached.compressed_size == blob.compressed_size
+ assert cached.uncompressed_size == blob.uncompressed_size
+ assert cached.uploading == blob.uploading
+ assert cached.placements == blob.placements
- # Try another blob, which should fail since the DB is not connected and the cache
- # does not contain the blob.
- with pytest.raises(SomeException):
- registry_model.get_cached_repo_blob(model_cache, 'devtable', 'simple', 'some other digest')
+ # Try another blob, which should fail since the DB is not connected and the cache
+ # does not contain the blob.
+ with pytest.raises(SomeException):
+ registry_model.get_cached_repo_blob(
+ model_cache, "devtable", "simple", "some other digest"
+ )
def test_create_manifest_and_retarget_tag(registry_model):
- repository_ref = registry_model.lookup_repository('devtable', 'simple')
- latest_tag = registry_model.get_repo_tag(repository_ref, 'latest', include_legacy_image=True)
- manifest = registry_model.get_manifest_for_tag(latest_tag).get_parsed_manifest()
+ repository_ref = registry_model.lookup_repository("devtable", "simple")
+ latest_tag = registry_model.get_repo_tag(
+ repository_ref, "latest", include_legacy_image=True
+ )
+ manifest = registry_model.get_manifest_for_tag(latest_tag).get_parsed_manifest()
- builder = DockerSchema1ManifestBuilder('devtable', 'simple', 'anothertag')
- builder.add_layer(manifest.blob_digests[0],
- '{"id": "%s"}' % latest_tag.legacy_image.docker_image_id)
- sample_manifest = builder.build(docker_v2_signing_key)
- assert sample_manifest is not None
+ builder = DockerSchema1ManifestBuilder("devtable", "simple", "anothertag")
+ builder.add_layer(
+ manifest.blob_digests[0],
+ '{"id": "%s"}' % latest_tag.legacy_image.docker_image_id,
+ )
+ sample_manifest = builder.build(docker_v2_signing_key)
+ assert sample_manifest is not None
- another_manifest, tag = registry_model.create_manifest_and_retarget_tag(repository_ref,
- sample_manifest,
- 'anothertag',
- storage)
- assert another_manifest is not None
- assert tag is not None
+ another_manifest, tag = registry_model.create_manifest_and_retarget_tag(
+ repository_ref, sample_manifest, "anothertag", storage
+ )
+ assert another_manifest is not None
+ assert tag is not None
- assert tag.name == 'anothertag'
- assert another_manifest.get_parsed_manifest().manifest_dict == sample_manifest.manifest_dict
+ assert tag.name == "anothertag"
+ assert (
+ another_manifest.get_parsed_manifest().manifest_dict
+ == sample_manifest.manifest_dict
+ )
def test_get_schema1_parsed_manifest(registry_model):
- repository_ref = registry_model.lookup_repository('devtable', 'simple')
- latest_tag = registry_model.get_repo_tag(repository_ref, 'latest', include_legacy_image=True)
- manifest = registry_model.get_manifest_for_tag(latest_tag)
- assert registry_model.get_schema1_parsed_manifest(manifest, '', '', '', storage)
+ repository_ref = registry_model.lookup_repository("devtable", "simple")
+ latest_tag = registry_model.get_repo_tag(
+ repository_ref, "latest", include_legacy_image=True
+ )
+ manifest = registry_model.get_manifest_for_tag(latest_tag)
+ assert registry_model.get_schema1_parsed_manifest(manifest, "", "", "", storage)
def test_convert_manifest(registry_model):
- repository_ref = registry_model.lookup_repository('devtable', 'simple')
- latest_tag = registry_model.get_repo_tag(repository_ref, 'latest', include_legacy_image=True)
- manifest = registry_model.get_manifest_for_tag(latest_tag)
+ repository_ref = registry_model.lookup_repository("devtable", "simple")
+ latest_tag = registry_model.get_repo_tag(
+ repository_ref, "latest", include_legacy_image=True
+ )
+ manifest = registry_model.get_manifest_for_tag(latest_tag)
- mediatypes = DOCKER_SCHEMA1_CONTENT_TYPES
- assert registry_model.convert_manifest(manifest, '', '', '', mediatypes, storage)
+ mediatypes = DOCKER_SCHEMA1_CONTENT_TYPES
+ assert registry_model.convert_manifest(manifest, "", "", "", mediatypes, storage)
- mediatypes = []
- assert registry_model.convert_manifest(manifest, '', '', '', mediatypes, storage) is None
+ mediatypes = []
+ assert (
+ registry_model.convert_manifest(manifest, "", "", "", mediatypes, storage)
+ is None
+ )
def test_create_manifest_and_retarget_tag_with_labels(registry_model):
- repository_ref = registry_model.lookup_repository('devtable', 'simple')
- latest_tag = registry_model.get_repo_tag(repository_ref, 'latest', include_legacy_image=True)
- manifest = registry_model.get_manifest_for_tag(latest_tag).get_parsed_manifest()
+ repository_ref = registry_model.lookup_repository("devtable", "simple")
+ latest_tag = registry_model.get_repo_tag(
+ repository_ref, "latest", include_legacy_image=True
+ )
+ manifest = registry_model.get_manifest_for_tag(latest_tag).get_parsed_manifest()
- json_metadata = {
- 'id': latest_tag.legacy_image.docker_image_id,
- 'config': {
- 'Labels': {
- 'quay.expires-after': '2w',
- },
- },
- }
+ json_metadata = {
+ "id": latest_tag.legacy_image.docker_image_id,
+ "config": {"Labels": {"quay.expires-after": "2w"}},
+ }
- builder = DockerSchema1ManifestBuilder('devtable', 'simple', 'anothertag')
- builder.add_layer(manifest.blob_digests[0], json.dumps(json_metadata))
- sample_manifest = builder.build(docker_v2_signing_key)
- assert sample_manifest is not None
+ builder = DockerSchema1ManifestBuilder("devtable", "simple", "anothertag")
+ builder.add_layer(manifest.blob_digests[0], json.dumps(json_metadata))
+ sample_manifest = builder.build(docker_v2_signing_key)
+ assert sample_manifest is not None
- another_manifest, tag = registry_model.create_manifest_and_retarget_tag(repository_ref,
- sample_manifest,
- 'anothertag',
- storage)
- assert another_manifest is not None
- assert tag is not None
+ another_manifest, tag = registry_model.create_manifest_and_retarget_tag(
+ repository_ref, sample_manifest, "anothertag", storage
+ )
+ assert another_manifest is not None
+ assert tag is not None
- assert tag.name == 'anothertag'
- assert another_manifest.get_parsed_manifest().manifest_dict == sample_manifest.manifest_dict
-
- # Ensure the labels were applied.
- assert tag.lifetime_end_ms is not None
+ assert tag.name == "anothertag"
+ assert (
+ another_manifest.get_parsed_manifest().manifest_dict
+ == sample_manifest.manifest_dict
+ )
+ # Ensure the labels were applied.
+ assert tag.lifetime_end_ms is not None
def _populate_blob(digest):
- location = ImageStorageLocation.get(name='local_us')
- store_blob_record_and_temp_link('devtable', 'simple', digest, location, 1, 120)
+ location = ImageStorageLocation.get(name="local_us")
+ store_blob_record_and_temp_link("devtable", "simple", digest, location, 1, 120)
def test_known_issue_schema1(registry_model):
- test_dir = os.path.dirname(os.path.abspath(__file__))
- path = os.path.join(test_dir, '../../../image/docker/test/validate_manifest_known_issue.json')
- with open(path, 'r') as f:
- manifest_bytes = f.read()
+ test_dir = os.path.dirname(os.path.abspath(__file__))
+ path = os.path.join(
+ test_dir, "../../../image/docker/test/validate_manifest_known_issue.json"
+ )
+ with open(path, "r") as f:
+ manifest_bytes = f.read()
- manifest = DockerSchema1Manifest(Bytes.for_string_or_unicode(manifest_bytes))
+ manifest = DockerSchema1Manifest(Bytes.for_string_or_unicode(manifest_bytes))
- for blob_digest in manifest.local_blob_digests:
- _populate_blob(blob_digest)
+ for blob_digest in manifest.local_blob_digests:
+ _populate_blob(blob_digest)
- digest = manifest.digest
- assert digest == 'sha256:44518f5a4d1cb5b7a6347763116fb6e10f6a8563b6c40bb389a0a982f0a9f47a'
+ digest = manifest.digest
+ assert (
+ digest
+ == "sha256:44518f5a4d1cb5b7a6347763116fb6e10f6a8563b6c40bb389a0a982f0a9f47a"
+ )
- # Create the manifest in the database.
- repository_ref = registry_model.lookup_repository('devtable', 'simple')
- created_manifest, _ = registry_model.create_manifest_and_retarget_tag(repository_ref, manifest,
- 'latest', storage)
- assert created_manifest
- assert created_manifest.digest == manifest.digest
- assert (created_manifest.internal_manifest_bytes.as_encoded_str() ==
- manifest.bytes.as_encoded_str())
+ # Create the manifest in the database.
+ repository_ref = registry_model.lookup_repository("devtable", "simple")
+ created_manifest, _ = registry_model.create_manifest_and_retarget_tag(
+ repository_ref, manifest, "latest", storage
+ )
+ assert created_manifest
+ assert created_manifest.digest == manifest.digest
+ assert (
+ created_manifest.internal_manifest_bytes.as_encoded_str()
+ == manifest.bytes.as_encoded_str()
+ )
- # Look it up again and validate.
- found = registry_model.lookup_manifest_by_digest(repository_ref, manifest.digest, allow_dead=True)
- assert found
- assert found.digest == digest
- assert found.internal_manifest_bytes.as_encoded_str() == manifest.bytes.as_encoded_str()
- assert found.get_parsed_manifest().digest == digest
+ # Look it up again and validate.
+ found = registry_model.lookup_manifest_by_digest(
+ repository_ref, manifest.digest, allow_dead=True
+ )
+ assert found
+ assert found.digest == digest
+ assert (
+ found.internal_manifest_bytes.as_encoded_str()
+ == manifest.bytes.as_encoded_str()
+ )
+ assert found.get_parsed_manifest().digest == digest
def test_unicode_emoji(registry_model):
- builder = DockerSchema1ManifestBuilder('devtable', 'simple', 'latest')
- builder.add_layer('sha256:abcde', json.dumps({
- 'id': 'someid',
- 'author': u'😱',
- }, ensure_ascii=False))
+ builder = DockerSchema1ManifestBuilder("devtable", "simple", "latest")
+ builder.add_layer(
+ "sha256:abcde", json.dumps({"id": "someid", "author": u"😱"}, ensure_ascii=False)
+ )
- manifest = builder.build(ensure_ascii=False)
- manifest._validate()
+ manifest = builder.build(ensure_ascii=False)
+ manifest._validate()
- for blob_digest in manifest.local_blob_digests:
- _populate_blob(blob_digest)
+ for blob_digest in manifest.local_blob_digests:
+ _populate_blob(blob_digest)
- # Create the manifest in the database.
- repository_ref = registry_model.lookup_repository('devtable', 'simple')
- created_manifest, _ = registry_model.create_manifest_and_retarget_tag(repository_ref, manifest,
- 'latest', storage)
- assert created_manifest
- assert created_manifest.digest == manifest.digest
- assert (created_manifest.internal_manifest_bytes.as_encoded_str() ==
- manifest.bytes.as_encoded_str())
+ # Create the manifest in the database.
+ repository_ref = registry_model.lookup_repository("devtable", "simple")
+ created_manifest, _ = registry_model.create_manifest_and_retarget_tag(
+ repository_ref, manifest, "latest", storage
+ )
+ assert created_manifest
+ assert created_manifest.digest == manifest.digest
+ assert (
+ created_manifest.internal_manifest_bytes.as_encoded_str()
+ == manifest.bytes.as_encoded_str()
+ )
- # Look it up again and validate.
- found = registry_model.lookup_manifest_by_digest(repository_ref, manifest.digest, allow_dead=True)
- assert found
- assert found.digest == manifest.digest
- assert found.internal_manifest_bytes.as_encoded_str() == manifest.bytes.as_encoded_str()
- assert found.get_parsed_manifest().digest == manifest.digest
+ # Look it up again and validate.
+ found = registry_model.lookup_manifest_by_digest(
+ repository_ref, manifest.digest, allow_dead=True
+ )
+ assert found
+ assert found.digest == manifest.digest
+ assert (
+ found.internal_manifest_bytes.as_encoded_str()
+ == manifest.bytes.as_encoded_str()
+ )
+ assert found.get_parsed_manifest().digest == manifest.digest
def test_lookup_active_repository_tags(oci_model):
- repository_ref = oci_model.lookup_repository('devtable', 'simple')
- latest_tag = oci_model.get_repo_tag(repository_ref, 'latest')
- manifest = oci_model.get_manifest_for_tag(latest_tag)
+ repository_ref = oci_model.lookup_repository("devtable", "simple")
+ latest_tag = oci_model.get_repo_tag(repository_ref, "latest")
+ manifest = oci_model.get_manifest_for_tag(latest_tag)
- tag_count = 500
+ tag_count = 500
- # Create a bunch of tags.
- tags_expected = set()
- for index in range(0, tag_count):
- tags_expected.add('somenewtag%s' % index)
- oci_model.retarget_tag(repository_ref, 'somenewtag%s' % index, manifest, storage,
- docker_v2_signing_key)
+ # Create a bunch of tags.
+ tags_expected = set()
+ for index in range(0, tag_count):
+ tags_expected.add("somenewtag%s" % index)
+ oci_model.retarget_tag(
+ repository_ref,
+ "somenewtag%s" % index,
+ manifest,
+ storage,
+ docker_v2_signing_key,
+ )
- assert tags_expected
+ assert tags_expected
- # List the tags.
- tags_found = set()
- tag_id = None
- while True:
- tags = oci_model.lookup_active_repository_tags(repository_ref, tag_id, 11)
- assert len(tags) <= 11
- for tag in tags[0:10]:
- assert tag.name not in tags_found
- if tag.name in tags_expected:
- tags_found.add(tag.name)
- tags_expected.remove(tag.name)
+ # List the tags.
+ tags_found = set()
+ tag_id = None
+ while True:
+ tags = oci_model.lookup_active_repository_tags(repository_ref, tag_id, 11)
+ assert len(tags) <= 11
+ for tag in tags[0:10]:
+ assert tag.name not in tags_found
+ if tag.name in tags_expected:
+ tags_found.add(tag.name)
+ tags_expected.remove(tag.name)
- if len(tags) < 11:
- break
+ if len(tags) < 11:
+ break
- tag_id = tags[10].id
+ tag_id = tags[10].id
- # Make sure we've found all the tags.
- assert tags_found
- assert not tags_expected
+ # Make sure we've found all the tags.
+ assert tags_found
+ assert not tags_expected
def test_yield_tags_for_vulnerability_notification(registry_model):
- repository_ref = registry_model.lookup_repository('devtable', 'complex')
+ repository_ref = registry_model.lookup_repository("devtable", "complex")
- # Check for all legacy images under the tags and ensure not raised because
- # no notification is yet registered.
- for tag in registry_model.list_all_active_repository_tags(repository_ref,
- include_legacy_images=True):
- image = registry_model.get_legacy_image(repository_ref, tag.legacy_image.docker_image_id,
- include_blob=True)
- pairs = [(image.docker_image_id, image.blob.uuid)]
- results = list(registry_model.yield_tags_for_vulnerability_notification(pairs))
- assert not len(results)
+    # Check all legacy images under the tags and ensure nothing is returned,
+    # because no notification is yet registered.
+ for tag in registry_model.list_all_active_repository_tags(
+ repository_ref, include_legacy_images=True
+ ):
+ image = registry_model.get_legacy_image(
+ repository_ref, tag.legacy_image.docker_image_id, include_blob=True
+ )
+ pairs = [(image.docker_image_id, image.blob.uuid)]
+ results = list(registry_model.yield_tags_for_vulnerability_notification(pairs))
+ assert not len(results)
- # Register a notification.
- model.notification.create_repo_notification(repository_ref.id, 'vulnerability_found', 'email',
- {}, {})
+ # Register a notification.
+ model.notification.create_repo_notification(
+ repository_ref.id, "vulnerability_found", "email", {}, {}
+ )
- # Check again.
- for tag in registry_model.list_all_active_repository_tags(repository_ref,
- include_legacy_images=True):
- image = registry_model.get_legacy_image(repository_ref, tag.legacy_image.docker_image_id,
- include_blob=True, include_parents=True)
+ # Check again.
+ for tag in registry_model.list_all_active_repository_tags(
+ repository_ref, include_legacy_images=True
+ ):
+ image = registry_model.get_legacy_image(
+ repository_ref,
+ tag.legacy_image.docker_image_id,
+ include_blob=True,
+ include_parents=True,
+ )
- # Check for every parent of the image.
- for current in image.parents:
- img = registry_model.get_legacy_image(repository_ref, current.docker_image_id,
- include_blob=True)
- pairs = [(img.docker_image_id, img.blob.uuid)]
- results = list(registry_model.yield_tags_for_vulnerability_notification(pairs))
- assert len(results) > 0
- assert tag.name in {t.name for t in results}
+ # Check for every parent of the image.
+ for current in image.parents:
+ img = registry_model.get_legacy_image(
+ repository_ref, current.docker_image_id, include_blob=True
+ )
+ pairs = [(img.docker_image_id, img.blob.uuid)]
+ results = list(
+ registry_model.yield_tags_for_vulnerability_notification(pairs)
+ )
+ assert len(results) > 0
+ assert tag.name in {t.name for t in results}
- # Check for the image itself.
- pairs = [(image.docker_image_id, image.blob.uuid)]
- results = list(registry_model.yield_tags_for_vulnerability_notification(pairs))
- assert len(results) > 0
- assert tag.name in {t.name for t in results}
+ # Check for the image itself.
+ pairs = [(image.docker_image_id, image.blob.uuid)]
+ results = list(registry_model.yield_tags_for_vulnerability_notification(pairs))
+ assert len(results) > 0
+ assert tag.name in {t.name for t in results}
diff --git a/data/registry_model/test/test_manifestbuilder.py b/data/registry_model/test/test_manifestbuilder.py
index 538731b8d..663699454 100644
--- a/data/registry_model/test/test_manifestbuilder.py
+++ b/data/registry_model/test/test_manifestbuilder.py
@@ -10,7 +10,10 @@ from mock import patch
from app import docker_v2_signing_key
from data.registry_model.blobuploader import BlobUploadSettings, upload_blob
-from data.registry_model.manifestbuilder import create_manifest_builder, lookup_manifest_builder
+from data.registry_model.manifestbuilder import (
+ create_manifest_builder,
+ lookup_manifest_builder,
+)
from data.registry_model.registry_pre_oci_model import PreOCIModel
from data.registry_model.registry_oci_model import OCIModel
@@ -21,84 +24,116 @@ from test.fixtures import *
@pytest.fixture(params=[PreOCIModel, OCIModel])
def registry_model(request, initialized_db):
- return request.param()
+ return request.param()
@pytest.fixture()
def fake_session():
- with patch('data.registry_model.manifestbuilder.session', {}):
- yield
+ with patch("data.registry_model.manifestbuilder.session", {}):
+ yield
-@pytest.mark.parametrize('layers', [
- pytest.param([('someid', None, 'some data')], id='Single layer'),
- pytest.param([('parentid', None, 'some parent data'),
- ('someid', 'parentid', 'some data')],
- id='Multi layer'),
-])
+@pytest.mark.parametrize(
+ "layers",
+ [
+ pytest.param([("someid", None, "some data")], id="Single layer"),
+ pytest.param(
+ [
+ ("parentid", None, "some parent data"),
+ ("someid", "parentid", "some data"),
+ ],
+ id="Multi layer",
+ ),
+ ],
+)
def test_build_manifest(layers, fake_session, registry_model):
- repository_ref = registry_model.lookup_repository('devtable', 'complex')
- storage = DistributedStorage({'local_us': FakeStorage(None)}, ['local_us'])
- settings = BlobUploadSettings('2M', 512 * 1024, 3600)
- app_config = {'TESTING': True}
+ repository_ref = registry_model.lookup_repository("devtable", "complex")
+ storage = DistributedStorage({"local_us": FakeStorage(None)}, ["local_us"])
+ settings = BlobUploadSettings("2M", 512 * 1024, 3600)
+ app_config = {"TESTING": True}
- builder = create_manifest_builder(repository_ref, storage, docker_v2_signing_key)
- assert lookup_manifest_builder(repository_ref, 'anotherid', storage,
- docker_v2_signing_key) is None
- assert lookup_manifest_builder(repository_ref, builder.builder_id, storage,
- docker_v2_signing_key) is not None
+ builder = create_manifest_builder(repository_ref, storage, docker_v2_signing_key)
+ assert (
+ lookup_manifest_builder(
+ repository_ref, "anotherid", storage, docker_v2_signing_key
+ )
+ is None
+ )
+ assert (
+ lookup_manifest_builder(
+ repository_ref, builder.builder_id, storage, docker_v2_signing_key
+ )
+ is not None
+ )
- blobs_by_layer = {}
- for layer_id, parent_id, layer_bytes in layers:
- # Start a new layer.
- assert builder.start_layer(layer_id, json.dumps({'id': layer_id, 'parent': parent_id}),
- 'local_us', None, 60)
+ blobs_by_layer = {}
+ for layer_id, parent_id, layer_bytes in layers:
+ # Start a new layer.
+ assert builder.start_layer(
+ layer_id,
+ json.dumps({"id": layer_id, "parent": parent_id}),
+ "local_us",
+ None,
+ 60,
+ )
- checksum = hashlib.sha1(layer_bytes).hexdigest()
+ checksum = hashlib.sha1(layer_bytes).hexdigest()
- # Assign it a blob.
- with upload_blob(repository_ref, storage, settings) as uploader:
- uploader.upload_chunk(app_config, BytesIO(layer_bytes))
- blob = uploader.commit_to_blob(app_config)
- blobs_by_layer[layer_id] = blob
- builder.assign_layer_blob(builder.lookup_layer(layer_id), blob, [checksum])
+ # Assign it a blob.
+ with upload_blob(repository_ref, storage, settings) as uploader:
+ uploader.upload_chunk(app_config, BytesIO(layer_bytes))
+ blob = uploader.commit_to_blob(app_config)
+ blobs_by_layer[layer_id] = blob
+ builder.assign_layer_blob(builder.lookup_layer(layer_id), blob, [checksum])
- # Validate the checksum.
- assert builder.validate_layer_checksum(builder.lookup_layer(layer_id), checksum)
+ # Validate the checksum.
+ assert builder.validate_layer_checksum(builder.lookup_layer(layer_id), checksum)
- # Commit the manifest to a tag.
- tag = builder.commit_tag_and_manifest('somenewtag', builder.lookup_layer(layers[-1][0]))
- assert tag
- assert tag in builder.committed_tags
+ # Commit the manifest to a tag.
+ tag = builder.commit_tag_and_manifest(
+ "somenewtag", builder.lookup_layer(layers[-1][0])
+ )
+ assert tag
+ assert tag in builder.committed_tags
- # Mark the builder as done.
- builder.done()
+ # Mark the builder as done.
+ builder.done()
- # Verify the legacy image for the tag.
- found = registry_model.get_repo_tag(repository_ref, 'somenewtag', include_legacy_image=True)
- assert found
- assert found.name == 'somenewtag'
- assert found.legacy_image.docker_image_id == layers[-1][0]
+ # Verify the legacy image for the tag.
+ found = registry_model.get_repo_tag(
+ repository_ref, "somenewtag", include_legacy_image=True
+ )
+ assert found
+ assert found.name == "somenewtag"
+ assert found.legacy_image.docker_image_id == layers[-1][0]
- # Verify the blob and manifest.
- manifest = registry_model.get_manifest_for_tag(found)
- assert manifest
+ # Verify the blob and manifest.
+ manifest = registry_model.get_manifest_for_tag(found)
+ assert manifest
- parsed = manifest.get_parsed_manifest()
- assert len(list(parsed.layers)) == len(layers)
+ parsed = manifest.get_parsed_manifest()
+ assert len(list(parsed.layers)) == len(layers)
- for index, (layer_id, parent_id, layer_bytes) in enumerate(layers):
- assert list(parsed.blob_digests)[index] == blobs_by_layer[layer_id].digest
- assert list(parsed.layers)[index].v1_metadata.image_id == layer_id
- assert list(parsed.layers)[index].v1_metadata.parent_image_id == parent_id
+ for index, (layer_id, parent_id, layer_bytes) in enumerate(layers):
+ assert list(parsed.blob_digests)[index] == blobs_by_layer[layer_id].digest
+ assert list(parsed.layers)[index].v1_metadata.image_id == layer_id
+ assert list(parsed.layers)[index].v1_metadata.parent_image_id == parent_id
- assert parsed.leaf_layer_v1_image_id == layers[-1][0]
+ assert parsed.leaf_layer_v1_image_id == layers[-1][0]
def test_build_manifest_missing_parent(fake_session, registry_model):
- storage = DistributedStorage({'local_us': FakeStorage(None)}, ['local_us'])
- repository_ref = registry_model.lookup_repository('devtable', 'complex')
- builder = create_manifest_builder(repository_ref, storage, docker_v2_signing_key)
+ storage = DistributedStorage({"local_us": FakeStorage(None)}, ["local_us"])
+ repository_ref = registry_model.lookup_repository("devtable", "complex")
+ builder = create_manifest_builder(repository_ref, storage, docker_v2_signing_key)
- assert builder.start_layer('somelayer', json.dumps({'id': 'somelayer', 'parent': 'someparent'}),
- 'local_us', None, 60) is None
+ assert (
+ builder.start_layer(
+ "somelayer",
+ json.dumps({"id": "somelayer", "parent": "someparent"}),
+ "local_us",
+ None,
+ 60,
+ )
+ is None
+ )
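For readers skimming this reformatting diff, the test above is also the clearest end-to-end illustration of the manifest-builder API. Below is a minimal sketch of that flow, assuming `repository_ref`, `storage`, `app_config`, and `docker_v2_signing_key` are prepared exactly as in the fixtures of test_build_manifest; the layer id and tag name are arbitrary.

# Condensed sketch of the builder lifecycle exercised by test_build_manifest.
# Assumes repository_ref, storage, app_config and docker_v2_signing_key are
# set up as in the test fixtures above.
import hashlib
import json
from io import BytesIO

from data.registry_model.blobuploader import BlobUploadSettings, upload_blob
from data.registry_model.manifestbuilder import create_manifest_builder

settings = BlobUploadSettings("2M", 512 * 1024, 3600)
builder = create_manifest_builder(repository_ref, storage, docker_v2_signing_key)

layer_bytes = b"some data"
# Register the layer's v1 metadata, then upload its blob and attach it.
builder.start_layer("someid", json.dumps({"id": "someid", "parent": None}), "local_us", None, 60)
with upload_blob(repository_ref, storage, settings) as uploader:
    uploader.upload_chunk(app_config, BytesIO(layer_bytes))
    blob = uploader.commit_to_blob(app_config)
builder.assign_layer_blob(builder.lookup_layer("someid"), blob, [hashlib.sha1(layer_bytes).hexdigest()])

# Point a tag at the leaf layer and close out the build session.
builder.commit_tag_and_manifest("sometag", builder.lookup_layer("someid"))
builder.done()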
diff --git a/data/runmigration.py b/data/runmigration.py
index f4126aba1..31eeb86e9 100644
--- a/data/runmigration.py
+++ b/data/runmigration.py
@@ -5,23 +5,24 @@ from alembic.script import ScriptDirectory
from alembic.environment import EnvironmentContext
from alembic.migration import __name__ as migration_name
+
def run_alembic_migration(db_uri, log_handler=None, setup_app=True):
- if log_handler:
- logging.getLogger(migration_name).addHandler(log_handler)
+ if log_handler:
+ logging.getLogger(migration_name).addHandler(log_handler)
- config = Config()
- config.set_main_option("script_location", "data:migrations")
- config.set_main_option("db_uri", db_uri)
+ config = Config()
+ config.set_main_option("script_location", "data:migrations")
+ config.set_main_option("db_uri", db_uri)
- if setup_app:
- config.set_main_option('alembic_setup_app', 'True')
- else:
- config.set_main_option('alembic_setup_app', '')
+ if setup_app:
+ config.set_main_option("alembic_setup_app", "True")
+ else:
+ config.set_main_option("alembic_setup_app", "")
- script = ScriptDirectory.from_config(config)
+ script = ScriptDirectory.from_config(config)
- def fn(rev, context):
- return script._upgrade_revs('head', rev)
+ def fn(rev, context):
+ return script._upgrade_revs("head", rev)
- with EnvironmentContext(config, script, fn=fn, destination_rev='head'):
- script.run_env()
\ No newline at end of file
+ with EnvironmentContext(config, script, fn=fn, destination_rev="head"):
+ script.run_env()
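As a usage note for the helper reformatted above: run_alembic_migration only needs a database URI, plus an optional logging handler. A minimal hedged sketch follows; the sqlite URI is a placeholder.

# Sketch: driving the Alembic upgrade-to-head helper defined above.
# The sqlite URI is a placeholder.
import logging

from data.runmigration import run_alembic_migration

handler = logging.StreamHandler()
run_alembic_migration("sqlite:///quay.db", log_handler=handler, setup_app=False)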
diff --git a/data/test/test_encryption.py b/data/test/test_encryption.py
index f6ec8a94b..6bc9fefc6 100644
--- a/data/test/test_encryption.py
+++ b/data/test/test_encryption.py
@@ -4,44 +4,48 @@ import pytest
from data.encryption import FieldEncrypter, _VERSIONS, DecryptionFailureException
-@pytest.mark.parametrize('test_data', [
- '',
- 'hello world',
- 'wassup?!',
- 'IGZ2Y8KUN3EFWAZZXR3D7U4V5NXDVYZI5VGU6STPB6KM83PAB8WRGM32RD9FW0C0',
- 'JLRFBYS1EHKUE73S99HWOQWNPGLUZTBRF5HQEFUJS5BK3XVB54RNXYV4AUMJXCMC',
- 'a' * 3,
- 'a' * 4,
- 'a' * 5,
- 'a' * 31,
- 'a' * 32,
- 'a' * 33,
- 'a' * 150,
- u'😇',
-])
-@pytest.mark.parametrize('version', _VERSIONS.keys())
-@pytest.mark.parametrize('secret_key', [
- u'test1234',
- 'test1234',
- 'thisisanothercoolsecretkeyhere',
- '107383705745765174750346070528443780244192102846031525796571939503548634055845',
-])
-@pytest.mark.parametrize('use_valid_key', [
- True,
- False,
-])
+
+@pytest.mark.parametrize(
+ "test_data",
+ [
+ "",
+ "hello world",
+ "wassup?!",
+ "IGZ2Y8KUN3EFWAZZXR3D7U4V5NXDVYZI5VGU6STPB6KM83PAB8WRGM32RD9FW0C0",
+ "JLRFBYS1EHKUE73S99HWOQWNPGLUZTBRF5HQEFUJS5BK3XVB54RNXYV4AUMJXCMC",
+ "a" * 3,
+ "a" * 4,
+ "a" * 5,
+ "a" * 31,
+ "a" * 32,
+ "a" * 33,
+ "a" * 150,
+ u"😇",
+ ],
+)
+@pytest.mark.parametrize("version", _VERSIONS.keys())
+@pytest.mark.parametrize(
+ "secret_key",
+ [
+ u"test1234",
+ "test1234",
+ "thisisanothercoolsecretkeyhere",
+ "107383705745765174750346070528443780244192102846031525796571939503548634055845",
+ ],
+)
+@pytest.mark.parametrize("use_valid_key", [True, False])
def test_encryption(test_data, version, secret_key, use_valid_key):
- encrypter = FieldEncrypter(secret_key, version)
- encrypted = encrypter.encrypt_value(test_data, field_max_length=255)
- assert encrypted != test_data
+ encrypter = FieldEncrypter(secret_key, version)
+ encrypted = encrypter.encrypt_value(test_data, field_max_length=255)
+ assert encrypted != test_data
- if use_valid_key:
- decrypted = encrypter.decrypt_value(encrypted)
- assert decrypted == test_data
+ if use_valid_key:
+ decrypted = encrypter.decrypt_value(encrypted)
+ assert decrypted == test_data
- with pytest.raises(DecryptionFailureException):
- encrypter.decrypt_value('somerandomvalue')
- else:
- decrypter = FieldEncrypter('some other key', version)
- with pytest.raises(DecryptionFailureException):
- decrypter.decrypt_value(encrypted)
+ with pytest.raises(DecryptionFailureException):
+ encrypter.decrypt_value("somerandomvalue")
+ else:
+ decrypter = FieldEncrypter("some other key", version)
+ with pytest.raises(DecryptionFailureException):
+ decrypter.decrypt_value(encrypted)
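The parametrized matrix above reduces to a simple round-trip contract for FieldEncrypter; here is that contract in isolation, using one of the registered _VERSIONS keys and values taken from the test parameters.

# Sketch: the FieldEncrypter round-trip contract asserted by the test above.
from data.encryption import FieldEncrypter, _VERSIONS, DecryptionFailureException

version = list(_VERSIONS.keys())[0]  # any registered encryption scheme version
encrypter = FieldEncrypter("test1234", version)

encrypted = encrypter.encrypt_value("hello world", field_max_length=255)
assert encrypted != "hello world"
assert encrypter.decrypt_value(encrypted) == "hello world"

# Decrypting with a different key (or decrypting garbage) must fail loudly.
try:
    FieldEncrypter("some other key", version).decrypt_value(encrypted)
except DecryptionFailureException:
    pass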
diff --git a/data/test/test_queue.py b/data/test/test_queue.py
index 36f61b502..0ed829213 100644
--- a/data/test/test_queue.py
+++ b/data/test/test_queue.py
@@ -12,409 +12,454 @@ from data.queue import WorkQueue, MINIMUM_EXTENSION
from test.fixtures import *
-QUEUE_NAME = 'testqueuename'
+QUEUE_NAME = "testqueuename"
class SaveLastCountReporter(object):
- def __init__(self):
- self.currently_processing = None
- self.running_count = None
- self.total = None
+ def __init__(self):
+ self.currently_processing = None
+ self.running_count = None
+ self.total = None
- def __call__(self, currently_processing, running_count, total_jobs):
- self.currently_processing = currently_processing
- self.running_count = running_count
- self.total = total_jobs
+ def __call__(self, currently_processing, running_count, total_jobs):
+ self.currently_processing = currently_processing
+ self.running_count = running_count
+ self.total = total_jobs
class AutoUpdatingQueue(object):
- def __init__(self, queue_to_wrap):
- self._queue = queue_to_wrap
+ def __init__(self, queue_to_wrap):
+ self._queue = queue_to_wrap
- def _wrapper(self, func):
- @wraps(func)
- def wrapper(*args, **kwargs):
- to_return = func(*args, **kwargs)
- self._queue.update_metrics()
- return to_return
- return wrapper
+ def _wrapper(self, func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ to_return = func(*args, **kwargs)
+ self._queue.update_metrics()
+ return to_return
- def __getattr__(self, attr_name):
- method_or_attr = getattr(self._queue, attr_name)
- if callable(method_or_attr):
- return self._wrapper(method_or_attr)
- else:
- return method_or_attr
+ return wrapper
+
+ def __getattr__(self, attr_name):
+ method_or_attr = getattr(self._queue, attr_name)
+ if callable(method_or_attr):
+ return self._wrapper(method_or_attr)
+ else:
+ return method_or_attr
-TEST_MESSAGE_1 = json.dumps({'data': 1})
-TEST_MESSAGE_2 = json.dumps({'data': 2})
-TEST_MESSAGES = [json.dumps({'data': str(i)}) for i in range(1, 101)]
+TEST_MESSAGE_1 = json.dumps({"data": 1})
+TEST_MESSAGE_2 = json.dumps({"data": 2})
+TEST_MESSAGES = [json.dumps({"data": str(i)}) for i in range(1, 101)]
@contextmanager
def fake_transaction(arg):
- yield
+ yield
+
@pytest.fixture()
def reporter():
- return SaveLastCountReporter()
+ return SaveLastCountReporter()
@pytest.fixture()
def transaction_factory():
- return fake_transaction
+ return fake_transaction
@pytest.fixture()
def queue(reporter, transaction_factory, initialized_db):
- return AutoUpdatingQueue(WorkQueue(QUEUE_NAME, transaction_factory, reporter=reporter))
+ return AutoUpdatingQueue(
+ WorkQueue(QUEUE_NAME, transaction_factory, reporter=reporter)
+ )
def test_get_single_item(queue, reporter, transaction_factory):
- # Add a single item to the queue.
- queue.put(['abc', 'def'], TEST_MESSAGE_1, available_after=-1)
+ # Add a single item to the queue.
+ queue.put(["abc", "def"], TEST_MESSAGE_1, available_after=-1)
- # Have two "instances" retrieve an item to claim. Since there is only one, both calls should
- # return the same item.
- now = datetime.utcnow()
- first_item = queue._select_available_item(False, now)
- second_item = queue._select_available_item(False, now)
+ # Have two "instances" retrieve an item to claim. Since there is only one, both calls should
+ # return the same item.
+ now = datetime.utcnow()
+ first_item = queue._select_available_item(False, now)
+ second_item = queue._select_available_item(False, now)
- assert first_item.id == second_item.id
- assert first_item.state_id == second_item.state_id
+ assert first_item.id == second_item.id
+ assert first_item.state_id == second_item.state_id
- # Have both "instances" now try to claim the item. Only one should succeed.
- first_claimed = queue._attempt_to_claim_item(first_item, now, 300)
- second_claimed = queue._attempt_to_claim_item(first_item, now, 300)
+ # Have both "instances" now try to claim the item. Only one should succeed.
+ first_claimed = queue._attempt_to_claim_item(first_item, now, 300)
+ second_claimed = queue._attempt_to_claim_item(first_item, now, 300)
- assert first_claimed
- assert not second_claimed
+ assert first_claimed
+ assert not second_claimed
- # Ensure the item is no longer available.
- assert queue.get() is None
+ # Ensure the item is no longer available.
+ assert queue.get() is None
+
+ # Ensure the item's state ID has changed.
+ assert first_item.state_id != QueueItem.get().state_id
- # Ensure the item's state ID has changed.
- assert first_item.state_id != QueueItem.get().state_id
def test_extend_processing(queue, reporter, transaction_factory):
- # Add and retrieve a queue item.
- queue.put(['abc', 'def'], TEST_MESSAGE_1, available_after=-1)
- queue_item = queue.get(processing_time=10)
- assert queue_item is not None
+ # Add and retrieve a queue item.
+ queue.put(["abc", "def"], TEST_MESSAGE_1, available_after=-1)
+ queue_item = queue.get(processing_time=10)
+ assert queue_item is not None
- existing_db_item = QueueItem.get(id=queue_item.id)
+ existing_db_item = QueueItem.get(id=queue_item.id)
- # Call extend processing with a timedelta less than the minimum and ensure its
- # processing_expires and state_id do not change.
- changed = queue.extend_processing(queue_item, 10 + MINIMUM_EXTENSION.total_seconds() - 1)
- assert not changed
+ # Call extend processing with a timedelta less than the minimum and ensure its
+ # processing_expires and state_id do not change.
+ changed = queue.extend_processing(
+ queue_item, 10 + MINIMUM_EXTENSION.total_seconds() - 1
+ )
+ assert not changed
- updated_db_item = QueueItem.get(id=queue_item.id)
+ updated_db_item = QueueItem.get(id=queue_item.id)
- assert existing_db_item.processing_expires == updated_db_item.processing_expires
- assert existing_db_item.state_id == updated_db_item.state_id
+ assert existing_db_item.processing_expires == updated_db_item.processing_expires
+ assert existing_db_item.state_id == updated_db_item.state_id
- # Call extend processing with a timedelta greater than the minimum and ensure its
- # processing_expires and state_id are changed.
- changed = queue.extend_processing(queue_item, 10 + MINIMUM_EXTENSION.total_seconds() + 1)
- assert changed
+ # Call extend processing with a timedelta greater than the minimum and ensure its
+ # processing_expires and state_id are changed.
+ changed = queue.extend_processing(
+ queue_item, 10 + MINIMUM_EXTENSION.total_seconds() + 1
+ )
+ assert changed
- updated_db_item = QueueItem.get(id=queue_item.id)
+ updated_db_item = QueueItem.get(id=queue_item.id)
- assert existing_db_item.processing_expires != updated_db_item.processing_expires
- assert existing_db_item.state_id != updated_db_item.state_id
+ assert existing_db_item.processing_expires != updated_db_item.processing_expires
+ assert existing_db_item.state_id != updated_db_item.state_id
- # Call extend processing with a timedelta less than the minimum but also with new data and
- # ensure its processing_expires and state_id are changed.
- changed = queue.extend_processing(queue_item, 10 + MINIMUM_EXTENSION.total_seconds() - 1,
- updated_data='newbody')
- assert changed
+ # Call extend processing with a timedelta less than the minimum but also with new data and
+ # ensure its processing_expires and state_id are changed.
+ changed = queue.extend_processing(
+ queue_item, 10 + MINIMUM_EXTENSION.total_seconds() - 1, updated_data="newbody"
+ )
+ assert changed
- updated_db_item = QueueItem.get(id=queue_item.id)
+ updated_db_item = QueueItem.get(id=queue_item.id)
+
+ assert existing_db_item.processing_expires != updated_db_item.processing_expires
+ assert existing_db_item.state_id != updated_db_item.state_id
+ assert updated_db_item.body == "newbody"
- assert existing_db_item.processing_expires != updated_db_item.processing_expires
- assert existing_db_item.state_id != updated_db_item.state_id
- assert updated_db_item.body == 'newbody'
def test_same_canonical_names(queue, reporter, transaction_factory):
- assert reporter.currently_processing is None
- assert reporter.running_count is None
- assert reporter.total is None
+ assert reporter.currently_processing is None
+ assert reporter.running_count is None
+ assert reporter.total is None
- id_1 = int(queue.put(['abc', 'def'], TEST_MESSAGE_1, available_after=-1))
- id_2 = int(queue.put(['abc', 'def'], TEST_MESSAGE_2, available_after=-1))
- assert id_1 + 1 == id_2
- assert not reporter.currently_processing
- assert reporter.running_count == 0
- assert reporter.total == 1
+ id_1 = int(queue.put(["abc", "def"], TEST_MESSAGE_1, available_after=-1))
+ id_2 = int(queue.put(["abc", "def"], TEST_MESSAGE_2, available_after=-1))
+ assert id_1 + 1 == id_2
+ assert not reporter.currently_processing
+ assert reporter.running_count == 0
+ assert reporter.total == 1
- one = queue.get(ordering_required=True)
- assert one is not None
- assert one.body == TEST_MESSAGE_1
- assert reporter.currently_processing
- assert reporter.running_count == 1
- assert reporter.total == 1
+ one = queue.get(ordering_required=True)
+ assert one is not None
+ assert one.body == TEST_MESSAGE_1
+ assert reporter.currently_processing
+ assert reporter.running_count == 1
+ assert reporter.total == 1
- two_fail = queue.get(ordering_required=True)
- assert two_fail is None
- assert reporter.running_count == 1
- assert reporter.total == 1
+ two_fail = queue.get(ordering_required=True)
+ assert two_fail is None
+ assert reporter.running_count == 1
+ assert reporter.total == 1
- queue.complete(one)
- assert not reporter.currently_processing
- assert reporter.running_count == 0
- assert reporter.total == 1
+ queue.complete(one)
+ assert not reporter.currently_processing
+ assert reporter.running_count == 0
+ assert reporter.total == 1
+
+ two = queue.get(ordering_required=True)
+ assert two is not None
+ assert reporter.currently_processing
+ assert two.body == TEST_MESSAGE_2
+ assert reporter.running_count == 1
+ assert reporter.total == 1
- two = queue.get(ordering_required=True)
- assert two is not None
- assert reporter.currently_processing
- assert two.body == TEST_MESSAGE_2
- assert reporter.running_count == 1
- assert reporter.total == 1
def test_different_canonical_names(queue, reporter, transaction_factory):
- queue.put(['abc', 'def'], TEST_MESSAGE_1, available_after=-1)
- queue.put(['abc', 'ghi'], TEST_MESSAGE_2, available_after=-1)
- assert reporter.running_count == 0
- assert reporter.total == 2
+ queue.put(["abc", "def"], TEST_MESSAGE_1, available_after=-1)
+ queue.put(["abc", "ghi"], TEST_MESSAGE_2, available_after=-1)
+ assert reporter.running_count == 0
+ assert reporter.total == 2
- one = queue.get(ordering_required=True)
- assert one is not None
- assert one.body == TEST_MESSAGE_1
- assert reporter.running_count == 1
- assert reporter.total == 2
+ one = queue.get(ordering_required=True)
+ assert one is not None
+ assert one.body == TEST_MESSAGE_1
+ assert reporter.running_count == 1
+ assert reporter.total == 2
+
+ two = queue.get(ordering_required=True)
+ assert two is not None
+ assert two.body == TEST_MESSAGE_2
+ assert reporter.running_count == 2
+ assert reporter.total == 2
- two = queue.get(ordering_required=True)
- assert two is not None
- assert two.body == TEST_MESSAGE_2
- assert reporter.running_count == 2
- assert reporter.total == 2
def test_canonical_name(queue, reporter, transaction_factory):
- queue.put(['abc', 'def'], TEST_MESSAGE_1, available_after=-1)
- queue.put(['abc', 'def', 'ghi'], TEST_MESSAGE_1, available_after=-1)
+ queue.put(["abc", "def"], TEST_MESSAGE_1, available_after=-1)
+ queue.put(["abc", "def", "ghi"], TEST_MESSAGE_1, available_after=-1)
- one = queue.get(ordering_required=True)
- assert QUEUE_NAME + '/abc/def/' != one
+ one = queue.get(ordering_required=True)
+ assert QUEUE_NAME + "/abc/def/" != one
+
+ two = queue.get(ordering_required=True)
+ assert QUEUE_NAME + "/abc/def/ghi/" != two
- two = queue.get(ordering_required=True)
- assert QUEUE_NAME + '/abc/def/ghi/' != two
def test_expiration(queue, reporter, transaction_factory):
- queue.put(['abc', 'def'], TEST_MESSAGE_1, available_after=-1)
- assert reporter.running_count == 0
- assert reporter.total == 1
+ queue.put(["abc", "def"], TEST_MESSAGE_1, available_after=-1)
+ assert reporter.running_count == 0
+ assert reporter.total == 1
- one = queue.get(processing_time=0.5, ordering_required=True)
- assert one is not None
- assert reporter.running_count == 1
- assert reporter.total == 1
+ one = queue.get(processing_time=0.5, ordering_required=True)
+ assert one is not None
+ assert reporter.running_count == 1
+ assert reporter.total == 1
- one_fail = queue.get(ordering_required=True)
- assert one_fail is None
+ one_fail = queue.get(ordering_required=True)
+ assert one_fail is None
- time.sleep(1)
- queue.update_metrics()
- assert reporter.running_count == 0
- assert reporter.total == 1
+ time.sleep(1)
+ queue.update_metrics()
+ assert reporter.running_count == 0
+ assert reporter.total == 1
+
+ one_again = queue.get(ordering_required=True)
+ assert one_again is not None
+ assert reporter.running_count == 1
+ assert reporter.total == 1
- one_again = queue.get(ordering_required=True)
- assert one_again is not None
- assert reporter.running_count == 1
- assert reporter.total == 1
def test_alive(queue, reporter, transaction_factory):
- # No queue item = not alive.
- assert not queue.alive(['abc', 'def'])
+ # No queue item = not alive.
+ assert not queue.alive(["abc", "def"])
- # Add a queue item.
- queue.put(['abc', 'def'], TEST_MESSAGE_1, available_after=-1)
- assert queue.alive(['abc', 'def'])
+ # Add a queue item.
+ queue.put(["abc", "def"], TEST_MESSAGE_1, available_after=-1)
+ assert queue.alive(["abc", "def"])
- # Retrieve the queue item.
- queue_item = queue.get()
- assert queue_item is not None
- assert queue.alive(['abc', 'def'])
+ # Retrieve the queue item.
+ queue_item = queue.get()
+ assert queue_item is not None
+ assert queue.alive(["abc", "def"])
- # Make sure it is running by trying to retrieve it again.
- assert queue.get() is None
+ # Make sure it is running by trying to retrieve it again.
+ assert queue.get() is None
+
+ # Delete the queue item.
+ queue.complete(queue_item)
+ assert not queue.alive(["abc", "def"])
- # Delete the queue item.
- queue.complete(queue_item)
- assert not queue.alive(['abc', 'def'])
def test_specialized_queue(queue, reporter, transaction_factory):
- queue.put(['abc', 'def'], TEST_MESSAGE_1, available_after=-1)
- queue.put(['def', 'def'], TEST_MESSAGE_2, available_after=-1)
+ queue.put(["abc", "def"], TEST_MESSAGE_1, available_after=-1)
+ queue.put(["def", "def"], TEST_MESSAGE_2, available_after=-1)
- my_queue = AutoUpdatingQueue(WorkQueue(QUEUE_NAME, transaction_factory, ['def']))
+ my_queue = AutoUpdatingQueue(WorkQueue(QUEUE_NAME, transaction_factory, ["def"]))
- two = my_queue.get(ordering_required=True)
- assert two is not None
- assert two.body == TEST_MESSAGE_2
+ two = my_queue.get(ordering_required=True)
+ assert two is not None
+ assert two.body == TEST_MESSAGE_2
- one_fail = my_queue.get(ordering_required=True)
- assert one_fail is None
+ one_fail = my_queue.get(ordering_required=True)
+ assert one_fail is None
+
+ one = queue.get(ordering_required=True)
+ assert one is not None
+ assert one.body == TEST_MESSAGE_1
- one = queue.get(ordering_required=True)
- assert one is not None
- assert one.body == TEST_MESSAGE_1
def test_random_queue_no_duplicates(queue, reporter, transaction_factory):
- for msg in TEST_MESSAGES:
- queue.put(['abc', 'def'], msg, available_after=-1)
- seen = set()
+ for msg in TEST_MESSAGES:
+ queue.put(["abc", "def"], msg, available_after=-1)
+ seen = set()
- for _ in range(1, 101):
- item = queue.get()
- json_body = json.loads(item.body)
- msg = str(json_body['data'])
- assert msg not in seen
- seen.add(msg)
+ for _ in range(1, 101):
+ item = queue.get()
+ json_body = json.loads(item.body)
+ msg = str(json_body["data"])
+ assert msg not in seen
+ seen.add(msg)
+
+ for body in TEST_MESSAGES:
+ json_body = json.loads(body)
+ msg = str(json_body["data"])
+ assert msg in seen
- for body in TEST_MESSAGES:
- json_body = json.loads(body)
- msg = str(json_body['data'])
- assert msg in seen
def test_bulk_insert(queue, reporter, transaction_factory):
- assert reporter.currently_processing is None
- assert reporter.running_count is None
- assert reporter.total is None
+ assert reporter.currently_processing is None
+ assert reporter.running_count is None
+ assert reporter.total is None
- with queue.batch_insert() as queue_put:
- queue_put(['abc', 'def'], TEST_MESSAGE_1, available_after=-1)
- queue_put(['abc', 'def'], TEST_MESSAGE_2, available_after=-1)
+ with queue.batch_insert() as queue_put:
+ queue_put(["abc", "def"], TEST_MESSAGE_1, available_after=-1)
+ queue_put(["abc", "def"], TEST_MESSAGE_2, available_after=-1)
- queue.update_metrics()
- assert not reporter.currently_processing
- assert reporter.running_count == 0
- assert reporter.total == 1
+ queue.update_metrics()
+ assert not reporter.currently_processing
+ assert reporter.running_count == 0
+ assert reporter.total == 1
- with queue.batch_insert() as queue_put:
- queue_put(['abd', 'def'], TEST_MESSAGE_1, available_after=-1)
- queue_put(['abd', 'ghi'], TEST_MESSAGE_2, available_after=-1)
+ with queue.batch_insert() as queue_put:
+ queue_put(["abd", "def"], TEST_MESSAGE_1, available_after=-1)
+ queue_put(["abd", "ghi"], TEST_MESSAGE_2, available_after=-1)
+
+ queue.update_metrics()
+ assert not reporter.currently_processing
+ assert reporter.running_count == 0
+ assert reporter.total == 3
- queue.update_metrics()
- assert not reporter.currently_processing
- assert reporter.running_count == 0
- assert reporter.total == 3
def test_num_available_between(queue, reporter, transaction_factory):
- now = datetime.utcnow()
- queue.put(['abc', 'def'], TEST_MESSAGE_1, available_after=-10)
- queue.put(['abc', 'ghi'], TEST_MESSAGE_2, available_after=-5)
+ now = datetime.utcnow()
+ queue.put(["abc", "def"], TEST_MESSAGE_1, available_after=-10)
+ queue.put(["abc", "ghi"], TEST_MESSAGE_2, available_after=-5)
- # Partial results
- count = queue.num_available_jobs_between(now-timedelta(seconds=8), now, ['abc'])
- assert count == 1
+ # Partial results
+ count = queue.num_available_jobs_between(now - timedelta(seconds=8), now, ["abc"])
+ assert count == 1
- # All results
- count = queue.num_available_jobs_between(now-timedelta(seconds=20), now, ['/abc'])
- assert count == 2
+ # All results
+ count = queue.num_available_jobs_between(now - timedelta(seconds=20), now, ["/abc"])
+ assert count == 2
+
+ # No results
+ count = queue.num_available_jobs_between(now, now, "abc")
+ assert count == 0
- # No results
- count = queue.num_available_jobs_between(now, now, 'abc')
- assert count == 0
def test_incomplete(queue, reporter, transaction_factory):
- # Add an item.
- queue.put(['somenamespace', 'abc', 'def'], TEST_MESSAGE_1, available_after=-10)
+ # Add an item.
+ queue.put(["somenamespace", "abc", "def"], TEST_MESSAGE_1, available_after=-10)
- now = datetime.utcnow()
- count = queue.num_available_jobs_between(now - timedelta(seconds=60), now, ['/somenamespace'])
- assert count == 1
+ now = datetime.utcnow()
+ count = queue.num_available_jobs_between(
+ now - timedelta(seconds=60), now, ["/somenamespace"]
+ )
+ assert count == 1
- # Retrieve it.
- item = queue.get()
- assert item is not None
- assert reporter.currently_processing
+ # Retrieve it.
+ item = queue.get()
+ assert item is not None
+ assert reporter.currently_processing
- # Mark it as incomplete.
- queue.incomplete(item, retry_after=-1)
- assert not reporter.currently_processing
+ # Mark it as incomplete.
+ queue.incomplete(item, retry_after=-1)
+ assert not reporter.currently_processing
- # Retrieve again to ensure it is once again available.
- same_item = queue.get()
- assert same_item is not None
- assert reporter.currently_processing
+ # Retrieve again to ensure it is once again available.
+ same_item = queue.get()
+ assert same_item is not None
+ assert reporter.currently_processing
+
+ assert item.id == same_item.id
- assert item.id == same_item.id
def test_complete(queue, reporter, transaction_factory):
- # Add an item.
- queue.put(['somenamespace', 'abc', 'def'], TEST_MESSAGE_1, available_after=-10)
+ # Add an item.
+ queue.put(["somenamespace", "abc", "def"], TEST_MESSAGE_1, available_after=-10)
- now = datetime.utcnow()
- count = queue.num_available_jobs_between(now - timedelta(seconds=60), now, ['/somenamespace'])
- assert count == 1
+ now = datetime.utcnow()
+ count = queue.num_available_jobs_between(
+ now - timedelta(seconds=60), now, ["/somenamespace"]
+ )
+ assert count == 1
- # Retrieve it.
- item = queue.get()
- assert item is not None
- assert reporter.currently_processing
+ # Retrieve it.
+ item = queue.get()
+ assert item is not None
+ assert reporter.currently_processing
+
+ # Mark it as complete.
+ queue.complete(item)
+ assert not reporter.currently_processing
- # Mark it as complete.
- queue.complete(item)
- assert not reporter.currently_processing
def test_cancel(queue, reporter, transaction_factory):
- # Add an item.
- queue.put(['somenamespace', 'abc', 'def'], TEST_MESSAGE_1, available_after=-10)
- queue.put(['somenamespace', 'abc', 'def'], TEST_MESSAGE_2, available_after=-5)
+ # Add an item.
+ queue.put(["somenamespace", "abc", "def"], TEST_MESSAGE_1, available_after=-10)
+ queue.put(["somenamespace", "abc", "def"], TEST_MESSAGE_2, available_after=-5)
- now = datetime.utcnow()
- count = queue.num_available_jobs_between(now - timedelta(seconds=60), now, ['/somenamespace'])
- assert count == 2
+ now = datetime.utcnow()
+ count = queue.num_available_jobs_between(
+ now - timedelta(seconds=60), now, ["/somenamespace"]
+ )
+ assert count == 2
- # Retrieve it.
- item = queue.get()
- assert item is not None
+ # Retrieve it.
+ item = queue.get()
+ assert item is not None
- # Make sure we can cancel it.
- assert queue.cancel(item.id)
+ # Make sure we can cancel it.
+ assert queue.cancel(item.id)
- now = datetime.utcnow()
- count = queue.num_available_jobs_between(now - timedelta(seconds=60), now, ['/somenamespace'])
- assert count == 1
+ now = datetime.utcnow()
+ count = queue.num_available_jobs_between(
+ now - timedelta(seconds=60), now, ["/somenamespace"]
+ )
+ assert count == 1
+
+ # Make sure it is gone.
+ assert not queue.cancel(item.id)
- # Make sure it is gone.
- assert not queue.cancel(item.id)
def test_deleted_namespaced_items(queue, reporter, transaction_factory):
- queue = AutoUpdatingQueue(WorkQueue(QUEUE_NAME, transaction_factory,
- reporter=reporter,
- has_namespace=True))
+ queue = AutoUpdatingQueue(
+ WorkQueue(
+ QUEUE_NAME, transaction_factory, reporter=reporter, has_namespace=True
+ )
+ )
- queue.put(['somenamespace', 'abc', 'def'], TEST_MESSAGE_1, available_after=-10)
- queue.put(['somenamespace', 'abc', 'ghi'], TEST_MESSAGE_2, available_after=-5)
- queue.put(['anothernamespace', 'abc', 'def'], TEST_MESSAGE_1, available_after=-10)
+ queue.put(["somenamespace", "abc", "def"], TEST_MESSAGE_1, available_after=-10)
+ queue.put(["somenamespace", "abc", "ghi"], TEST_MESSAGE_2, available_after=-5)
+ queue.put(["anothernamespace", "abc", "def"], TEST_MESSAGE_1, available_after=-10)
- # Ensure we have 2 items under `somenamespace` and 1 item under `anothernamespace`.
- now = datetime.utcnow()
- count = queue.num_available_jobs_between(now - timedelta(seconds=60), now, ['/somenamespace'])
- assert count == 2
+ # Ensure we have 2 items under `somenamespace` and 1 item under `anothernamespace`.
+ now = datetime.utcnow()
+ count = queue.num_available_jobs_between(
+ now - timedelta(seconds=60), now, ["/somenamespace"]
+ )
+ assert count == 2
- count = queue.num_available_jobs_between(now - timedelta(seconds=60), now, ['/anothernamespace'])
- assert count == 1
+ count = queue.num_available_jobs_between(
+ now - timedelta(seconds=60), now, ["/anothernamespace"]
+ )
+ assert count == 1
- # Delete all `somenamespace` items.
- queue.delete_namespaced_items('somenamespace')
+ # Delete all `somenamespace` items.
+ queue.delete_namespaced_items("somenamespace")
- # Check the updated counts.
- count = queue.num_available_jobs_between(now - timedelta(seconds=60), now, ['/somenamespace'])
- assert count == 0
+ # Check the updated counts.
+ count = queue.num_available_jobs_between(
+ now - timedelta(seconds=60), now, ["/somenamespace"]
+ )
+ assert count == 0
- count = queue.num_available_jobs_between(now - timedelta(seconds=60), now, ['/anothernamespace'])
- assert count == 1
+ count = queue.num_available_jobs_between(
+ now - timedelta(seconds=60), now, ["/anothernamespace"]
+ )
+ assert count == 1
- # Delete all `anothernamespace` items.
- queue.delete_namespaced_items('anothernamespace')
+ # Delete all `anothernamespace` items.
+ queue.delete_namespaced_items("anothernamespace")
- # Check the updated counts.
- count = queue.num_available_jobs_between(now - timedelta(seconds=60), now, ['/somenamespace'])
- assert count == 0
+ # Check the updated counts.
+ count = queue.num_available_jobs_between(
+ now - timedelta(seconds=60), now, ["/somenamespace"]
+ )
+ assert count == 0
- count = queue.num_available_jobs_between(now - timedelta(seconds=60), now, ['/anothernamespace'])
- assert count == 0
+ count = queue.num_available_jobs_between(
+ now - timedelta(seconds=60), now, ["/anothernamespace"]
+ )
+ assert count == 0
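The queue tests above walk through the whole WorkQueue lifecycle; condensed into a sketch for orientation, assuming an initialized database as provided by the fixtures (the queue name and payload mirror the test constants).

# Sketch: the basic put / get / complete lifecycle covered by the tests above.
# Assumes an initialized database, as provided by the test fixtures.
import json
from contextlib import contextmanager

from data.queue import WorkQueue

@contextmanager
def fake_transaction(arg):
    yield

queue = WorkQueue("testqueuename", fake_transaction)

# Enqueue under a canonical name; a negative available_after makes it claimable now.
queue.put(["abc", "def"], json.dumps({"data": 1}), available_after=-1)

item = queue.get(processing_time=60)
if item is not None:
    # ... do the work; extend_processing(item, seconds) can push the lease out ...
    queue.complete(item)  # or queue.incomplete(item, retry_after=5) to retry later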
diff --git a/data/test/test_readreplica.py b/data/test/test_readreplica.py
index 7f7111d2a..94ad1328a 100644
--- a/data/test/test_readreplica.py
+++ b/data/test/test_readreplica.py
@@ -11,92 +11,88 @@ from test.testconfig import FakeTransaction
from test.fixtures import *
-@pytest.mark.skipif(bool(os.environ.get('TEST_DATABASE_URI')), reason='Testing requires SQLite')
+@pytest.mark.skipif(
+ bool(os.environ.get("TEST_DATABASE_URI")), reason="Testing requires SQLite"
+)
def test_readreplica(init_db_path, tmpdir_factory):
- primary_file = str(tmpdir_factory.mktemp("data").join("primary.db"))
- replica_file = str(tmpdir_factory.mktemp("data").join("replica.db"))
+ primary_file = str(tmpdir_factory.mktemp("data").join("primary.db"))
+ replica_file = str(tmpdir_factory.mktemp("data").join("replica.db"))
- # Copy the initialized database to two different locations.
- shutil.copy2(init_db_path, primary_file)
- shutil.copy2(init_db_path, replica_file)
+ # Copy the initialized database to two different locations.
+ shutil.copy2(init_db_path, primary_file)
+ shutil.copy2(init_db_path, replica_file)
- db_config = {
- 'DB_URI': 'sqlite:///{0}'.format(primary_file),
- 'DB_READ_REPLICAS': [
- {'DB_URI': 'sqlite:///{0}'.format(replica_file)},
- ],
- "DB_CONNECTION_ARGS": {
- 'threadlocals': True,
- 'autorollback': True,
- },
- "DB_TRANSACTION_FACTORY": lambda x: FakeTransaction(),
- "FOR_TESTING": True,
- "DATABASE_SECRET_KEY": "anothercrazykey!",
- }
+ db_config = {
+ "DB_URI": "sqlite:///{0}".format(primary_file),
+ "DB_READ_REPLICAS": [{"DB_URI": "sqlite:///{0}".format(replica_file)}],
+ "DB_CONNECTION_ARGS": {"threadlocals": True, "autorollback": True},
+ "DB_TRANSACTION_FACTORY": lambda x: FakeTransaction(),
+ "FOR_TESTING": True,
+ "DATABASE_SECRET_KEY": "anothercrazykey!",
+ }
- # Initialize the DB with the primary and the replica.
- configure(db_config)
- assert not read_only_config.obj.is_readonly
- assert read_only_config.obj.read_replicas
+ # Initialize the DB with the primary and the replica.
+ configure(db_config)
+ assert not read_only_config.obj.is_readonly
+ assert read_only_config.obj.read_replicas
- # Ensure we can read the data.
- devtable_user = User.get(username='devtable')
- assert devtable_user.username == 'devtable'
+ # Ensure we can read the data.
+ devtable_user = User.get(username="devtable")
+ assert devtable_user.username == "devtable"
- # Configure with a bad primary. Reading should still work since we're hitting the replica.
- db_config['DB_URI'] = 'sqlite:///does/not/exist'
- configure(db_config)
+ # Configure with a bad primary. Reading should still work since we're hitting the replica.
+ db_config["DB_URI"] = "sqlite:///does/not/exist"
+ configure(db_config)
- assert not read_only_config.obj.is_readonly
- assert read_only_config.obj.read_replicas
+ assert not read_only_config.obj.is_readonly
+ assert read_only_config.obj.read_replicas
- devtable_user = User.get(username='devtable')
- assert devtable_user.username == 'devtable'
+ devtable_user = User.get(username="devtable")
+ assert devtable_user.username == "devtable"
- # Try to change some data. This should fail because the primary is broken.
- with pytest.raises(OperationalError):
- devtable_user.email = 'newlychanged'
+ # Try to change some data. This should fail because the primary is broken.
+ with pytest.raises(OperationalError):
+ devtable_user.email = "newlychanged"
+ devtable_user.save()
+
+ # Fix the primary and try again.
+ db_config["DB_URI"] = "sqlite:///{0}".format(primary_file)
+ configure(db_config)
+
+ assert not read_only_config.obj.is_readonly
+ assert read_only_config.obj.read_replicas
+
+ devtable_user.email = "newlychanged"
devtable_user.save()
- # Fix the primary and try again.
- db_config['DB_URI'] = 'sqlite:///{0}'.format(primary_file)
- configure(db_config)
+ # Mark the system as readonly.
+ db_config["DB_URI"] = "sqlite:///{0}".format(primary_file)
+ db_config["REGISTRY_STATE"] = "readonly"
+ configure(db_config)
- assert not read_only_config.obj.is_readonly
- assert read_only_config.obj.read_replicas
+ assert read_only_config.obj.is_readonly
+ assert read_only_config.obj.read_replicas
- devtable_user.email = 'newlychanged'
- devtable_user.save()
+ # Ensure all write operations raise a readonly mode exception.
+ with pytest.raises(ReadOnlyModeException):
+ devtable_user.email = "newlychanged2"
+ devtable_user.save()
- # Mark the system as readonly.
- db_config['DB_URI'] = 'sqlite:///{0}'.format(primary_file)
- db_config['REGISTRY_STATE'] = 'readonly'
- configure(db_config)
+ with pytest.raises(ReadOnlyModeException):
+ User.create(username="foo")
- assert read_only_config.obj.is_readonly
- assert read_only_config.obj.read_replicas
+ with pytest.raises(ReadOnlyModeException):
+ User.delete().where(User.username == "foo").execute()
- # Ensure all write operations raise a readonly mode exception.
- with pytest.raises(ReadOnlyModeException):
- devtable_user.email = 'newlychanged2'
- devtable_user.save()
+ with pytest.raises(ReadOnlyModeException):
+ User.update(username="bar").where(User.username == "foo").execute()
- with pytest.raises(ReadOnlyModeException):
- User.create(username='foo')
-
- with pytest.raises(ReadOnlyModeException):
- User.delete().where(User.username == 'foo').execute()
-
- with pytest.raises(ReadOnlyModeException):
- User.update(username='bar').where(User.username == 'foo').execute()
-
- # Reset the config on the DB, so we don't mess up other tests.
- configure({
- 'DB_URI': 'sqlite:///{0}'.format(primary_file),
- "DB_CONNECTION_ARGS": {
- 'threadlocals': True,
- 'autorollback': True,
- },
- "DB_TRANSACTION_FACTORY": lambda x: FakeTransaction(),
- "DATABASE_SECRET_KEY": "anothercrazykey!",
- })
+ # Reset the config on the DB, so we don't mess up other tests.
+ configure(
+ {
+ "DB_URI": "sqlite:///{0}".format(primary_file),
+ "DB_CONNECTION_ARGS": {"threadlocals": True, "autorollback": True},
+ "DB_TRANSACTION_FACTORY": lambda x: FakeTransaction(),
+ "DATABASE_SECRET_KEY": "anothercrazykey!",
+ }
+ )
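For reference, the whole test above is driven by the shape of the config dict handed to configure(); a hedged sketch of that shape follows, with the sqlite URIs as placeholders.

# Sketch: the configuration shape the read-replica test feeds to configure().
# The sqlite URIs are placeholders for a primary and one read replica.
from test.testconfig import FakeTransaction

db_config = {
    "DB_URI": "sqlite:///primary.db",
    "DB_READ_REPLICAS": [{"DB_URI": "sqlite:///replica.db"}],
    "DB_CONNECTION_ARGS": {"threadlocals": True, "autorollback": True},
    "DB_TRANSACTION_FACTORY": lambda x: FakeTransaction(),
    "FOR_TESTING": True,
    "DATABASE_SECRET_KEY": "anothercrazykey!",
}
# Passing this to configure() wires up the primary plus replicas; additionally
# setting db_config["REGISTRY_STATE"] = "readonly" makes every write raise
# ReadOnlyModeException, which is exactly what the test asserts.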
diff --git a/data/test/test_text.py b/data/test/test_text.py
index 14b4519d1..bd73c38a2 100644
--- a/data/test/test_text.py
+++ b/data/test/test_text.py
@@ -4,26 +4,31 @@ from data.text import match_mysql, match_like
from data.database import Repository
from test.fixtures import *
-@pytest.mark.parametrize('input', [
- ('hello world'),
- ('hello \' world'),
- ('hello " world'),
- ('hello ` world'),
-])
+
+@pytest.mark.parametrize(
+ "input", [("hello world"), ("hello ' world"), ('hello " world'), ("hello ` world")]
+)
def test_mysql_text_escaping(input):
- query, values = Repository.select().where(match_mysql(Repository.description, input)).sql()
- assert input not in query
+ query, values = (
+ Repository.select().where(match_mysql(Repository.description, input)).sql()
+ )
+ assert input not in query
-@pytest.mark.parametrize('input, expected', [
- ('hello world', 'hello world'),
- ('hello \'world', 'hello world'),
- ('hello "world', 'hello world'),
- ('hello `world', 'hello world'),
- ('hello !world', 'hello !!world'),
- ('hello %world', 'hello !%world'),
-])
+@pytest.mark.parametrize(
+ "input, expected",
+ [
+ ("hello world", "hello world"),
+ ("hello 'world", "hello world"),
+ ('hello "world', "hello world"),
+ ("hello `world", "hello world"),
+ ("hello !world", "hello !!world"),
+ ("hello %world", "hello !%world"),
+ ],
+)
def test_postgres_text_escaping(input, expected):
- query, values = Repository.select().where(match_like(Repository.description, input)).sql()
- assert input not in query
- assert values[0] == '%' + expected + '%'
+ query, values = (
+ Repository.select().where(match_like(Repository.description, input)).sql()
+ )
+ assert input not in query
+ assert values[0] == "%" + expected + "%"
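What the Postgres/SQLite escaping test above pins down, shown on a single concrete input; this needs the same test fixtures to compile the query.

# Sketch: the escaping contract asserted by test_postgres_text_escaping,
# isolated to one input. Requires the same fixtures/initialized models.
from data.database import Repository
from data.text import match_like

query, values = (
    Repository.select().where(match_like(Repository.description, "hello %world")).sql()
)
# The raw user input never appears verbatim in the SQL text, and the bound
# parameter is wrapped in wildcards with '%' escaped via '!'.
assert "hello %world" not in query
assert values[0] == "%hello !%world%"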
diff --git a/data/test/test_userfiles.py b/data/test/test_userfiles.py
index 671011e58..c832855b7 100644
--- a/data/test/test_userfiles.py
+++ b/data/test/test_userfiles.py
@@ -7,48 +7,54 @@ from data.userfiles import DelegateUserfiles, Userfiles
from test.fixtures import *
-@pytest.mark.parametrize('prefix,path,expected', [
- ('test', 'foo', 'test/foo'),
- ('test', 'bar', 'test/bar'),
- ('test', '/bar', 'test/bar'),
- ('test', '../foo', 'test/foo'),
- ('test', 'foo/bar/baz', 'test/baz'),
- ('test', 'foo/../baz', 'test/baz'),
-
- (None, 'foo', 'foo'),
- (None, 'foo/bar/baz', 'baz'),
-])
+@pytest.mark.parametrize(
+ "prefix,path,expected",
+ [
+ ("test", "foo", "test/foo"),
+ ("test", "bar", "test/bar"),
+ ("test", "/bar", "test/bar"),
+ ("test", "../foo", "test/foo"),
+ ("test", "foo/bar/baz", "test/baz"),
+ ("test", "foo/../baz", "test/baz"),
+ (None, "foo", "foo"),
+ (None, "foo/bar/baz", "baz"),
+ ],
+)
def test_filepath(prefix, path, expected):
- userfiles = DelegateUserfiles(None, None, 'local_us', prefix)
- assert userfiles.get_file_id_path(path) == expected
+ userfiles = DelegateUserfiles(None, None, "local_us", prefix)
+ assert userfiles.get_file_id_path(path) == expected
def test_lookup_userfile(app, client):
- uuid = 'deadbeef-dead-beef-dead-beefdeadbeef'
- bad_uuid = 'deadduck-dead-duck-dead-duckdeadduck'
- upper_uuid = 'DEADBEEF-DEAD-BEEF-DEAD-BEEFDEADBEEF'
+ uuid = "deadbeef-dead-beef-dead-beefdeadbeef"
+ bad_uuid = "deadduck-dead-duck-dead-duckdeadduck"
+ upper_uuid = "DEADBEEF-DEAD-BEEF-DEAD-BEEFDEADBEEF"
- def _stream_read_file(locations, path):
- if path.find(uuid) > 0 or path.find(upper_uuid) > 0:
- return BytesIO("hello world")
+ def _stream_read_file(locations, path):
+ if path.find(uuid) > 0 or path.find(upper_uuid) > 0:
+ return BytesIO("hello world")
- raise IOError('Not found!')
+ raise IOError("Not found!")
- storage_mock = Mock()
- storage_mock.stream_read_file = _stream_read_file
+ storage_mock = Mock()
+ storage_mock.stream_read_file = _stream_read_file
- app.config['USERFILES_PATH'] = 'foo'
- Userfiles(app, distributed_storage=storage_mock, path='mockuserfiles',
- handler_name='mockuserfiles')
+ app.config["USERFILES_PATH"] = "foo"
+ Userfiles(
+ app,
+ distributed_storage=storage_mock,
+ path="mockuserfiles",
+ handler_name="mockuserfiles",
+ )
- rv = client.open('/mockuserfiles/' + uuid, method='GET')
- assert rv.status_code == 200
+ rv = client.open("/mockuserfiles/" + uuid, method="GET")
+ assert rv.status_code == 200
- rv = client.open('/mockuserfiles/' + upper_uuid, method='GET')
- assert rv.status_code == 200
+ rv = client.open("/mockuserfiles/" + upper_uuid, method="GET")
+ assert rv.status_code == 200
- rv = client.open('/mockuserfiles/' + bad_uuid, method='GET')
- assert rv.status_code == 404
+ rv = client.open("/mockuserfiles/" + bad_uuid, method="GET")
+ assert rv.status_code == 404
- rv = client.open('/mockuserfiles/foo/bar/baz', method='GET')
- assert rv.status_code == 404
+ rv = client.open("/mockuserfiles/foo/bar/baz", method="GET")
+ assert rv.status_code == 404
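The first test above is entirely about path flattening; the behavior it pins down can be seen directly in isolation.

# Sketch: DelegateUserfiles collapses user-supplied ids to a basename under
# the configured prefix, which is what the parametrized cases above assert.
from data.userfiles import DelegateUserfiles

userfiles = DelegateUserfiles(None, None, "local_us", "test")
assert userfiles.get_file_id_path("../foo") == "test/foo"
assert userfiles.get_file_id_path("foo/bar/baz") == "test/baz"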
diff --git a/data/text.py b/data/text.py
index 9fa6bbf3e..9ef4449d2 100644
--- a/data/text.py
+++ b/data/text.py
@@ -1,53 +1,59 @@
from peewee import NodeList, SQL, fn, TextField, Field
+
def _escape_wildcard(search_query):
- """ Escapes the wildcards found in the given search query so that they are treated as *characters*
+ """ Escapes the wildcards found in the given search query so that they are treated as *characters*
rather than wildcards when passed to a LIKE or ILIKE clause with an ESCAPE '!'.
"""
- search_query = (search_query
- .replace('!', '!!')
- .replace('%', '!%')
- .replace('_', '!_')
- .replace('[', '!['))
+ search_query = (
+ search_query.replace("!", "!!")
+ .replace("%", "!%")
+ .replace("_", "!_")
+ .replace("[", "![")
+ )
- # Just to be absolutely sure.
- search_query = search_query.replace('\'', '')
- search_query = search_query.replace('"', '')
- search_query = search_query.replace('`', '')
+ # Just to be absolutely sure.
+ search_query = search_query.replace("'", "")
+ search_query = search_query.replace('"', "")
+ search_query = search_query.replace("`", "")
- return search_query
+ return search_query
def prefix_search(field, prefix_query):
- """ Returns the wildcard match for searching for the given prefix query. """
- # Escape the known wildcard characters.
- prefix_query = _escape_wildcard(prefix_query)
- return Field.__pow__(field, NodeList((prefix_query + '%', SQL("ESCAPE '!'"))))
+ """ Returns the wildcard match for searching for the given prefix query. """
+ # Escape the known wildcard characters.
+ prefix_query = _escape_wildcard(prefix_query)
+ return Field.__pow__(field, NodeList((prefix_query + "%", SQL("ESCAPE '!'"))))
def match_mysql(field, search_query):
- """ Generates a full-text match query using a Match operation, which is needed for MySQL.
+ """ Generates a full-text match query using a Match operation, which is needed for MySQL.
"""
- if field.name.find('`') >= 0: # Just to be safe.
- raise Exception("How did field name '%s' end up containing a backtick?" % field.name)
+ if field.name.find("`") >= 0: # Just to be safe.
+ raise Exception(
+ "How did field name '%s' end up containing a backtick?" % field.name
+ )
- # Note: There is a known bug in MySQL (https://bugs.mysql.com/bug.php?id=78485) that causes
- # queries of the form `*` to raise a parsing error. If found, simply filter out.
- search_query = search_query.replace('*', '')
+ # Note: There is a known bug in MySQL (https://bugs.mysql.com/bug.php?id=78485) that causes
+ # queries of the form `*` to raise a parsing error. If found, simply filter out.
+ search_query = search_query.replace("*", "")
- # Just to be absolutely sure.
- search_query = search_query.replace('\'', '')
- search_query = search_query.replace('"', '')
- search_query = search_query.replace('`', '')
+ # Just to be absolutely sure.
+ search_query = search_query.replace("'", "")
+ search_query = search_query.replace('"', "")
+ search_query = search_query.replace("`", "")
- return NodeList((fn.MATCH(SQL("`%s`" % field.name)), fn.AGAINST(SQL('%s', [search_query]))),
- parens=True)
+ return NodeList(
+ (fn.MATCH(SQL("`%s`" % field.name)), fn.AGAINST(SQL("%s", [search_query]))),
+ parens=True,
+ )
def match_like(field, search_query):
- """ Generates a full-text match query using an ILIKE operation, which is needed for SQLite and
+ """ Generates a full-text match query using an ILIKE operation, which is needed for SQLite and
Postgres.
"""
- escaped_query = _escape_wildcard(search_query)
- clause = NodeList(('%' + escaped_query + '%', SQL("ESCAPE '!'")))
- return Field.__pow__(field, clause)
+ escaped_query = _escape_wildcard(search_query)
+ clause = NodeList(("%" + escaped_query + "%", SQL("ESCAPE '!'")))
+ return Field.__pow__(field, clause)
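A short usage sketch for the helpers above; Repository.name is used purely as an illustrative column, and the sample string is ours.

# Sketch: how the search helpers above are meant to be composed into queries.
# Repository.name is an illustrative column choice.
from data.database import Repository
from data.text import _escape_wildcard, prefix_search

# LIKE metacharacters in user input are neutralized before hitting the clause.
assert _escape_wildcard("50%_off [sale]!") == "50!%!_off ![sale]!!"

# prefix_search() appends '%' and attaches the ESCAPE '!' marker for us.
query = Repository.select().where(prefix_search(Repository.name, "quay"))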
diff --git a/data/userevent.py b/data/userevent.py
index b4f340e5e..384c18f45 100644
--- a/data/userevent.py
+++ b/data/userevent.py
@@ -6,149 +6,156 @@ import redis
logger = logging.getLogger(__name__)
+
class CannotReadUserEventsException(Exception):
- """ Exception raised if user events cannot be read. """
+ """ Exception raised if user events cannot be read. """
+
class UserEventBuilder(object):
- """
+ """
Defines a helper class for constructing UserEvent and UserEventListener
instances.
"""
- def __init__(self, redis_config):
- self._redis_config = redis_config
- def get_event(self, username):
- return UserEvent(self._redis_config, username)
+ def __init__(self, redis_config):
+ self._redis_config = redis_config
- def get_listener(self, username, events):
- return UserEventListener(self._redis_config, username, events)
+ def get_event(self, username):
+ return UserEvent(self._redis_config, username)
+
+ def get_listener(self, username, events):
+ return UserEventListener(self._redis_config, username, events)
class UserEventsBuilderModule(object):
- def __init__(self, app=None):
- self.app = app
- if app is not None:
- self.state = self.init_app(app)
- else:
- self.state = None
+ def __init__(self, app=None):
+ self.app = app
+ if app is not None:
+ self.state = self.init_app(app)
+ else:
+ self.state = None
- def init_app(self, app):
- redis_config = app.config.get('USER_EVENTS_REDIS')
- if not redis_config:
- # This is the old key name.
- redis_config = {
- 'host': app.config.get('USER_EVENTS_REDIS_HOSTNAME'),
- }
+ def init_app(self, app):
+ redis_config = app.config.get("USER_EVENTS_REDIS")
+ if not redis_config:
+ # This is the old key name.
+ redis_config = {"host": app.config.get("USER_EVENTS_REDIS_HOSTNAME")}
- user_events = UserEventBuilder(redis_config)
+ user_events = UserEventBuilder(redis_config)
- # register extension with app
- app.extensions = getattr(app, 'extensions', {})
- app.extensions['userevents'] = user_events
- return user_events
+ # register extension with app
+ app.extensions = getattr(app, "extensions", {})
+ app.extensions["userevents"] = user_events
+ return user_events
- def __getattr__(self, name):
- return getattr(self.state, name, None)
+ def __getattr__(self, name):
+ return getattr(self.state, name, None)
class UserEvent(object):
- """
+ """
Defines a helper class for publishing to realtime user events
as backed by Redis.
"""
- def __init__(self, redis_config, username):
- self._redis = redis.StrictRedis(socket_connect_timeout=2, socket_timeout=2, **redis_config)
- self._username = username
- @staticmethod
- def _user_event_key(username, event_id):
- return 'user/%s/events/%s' % (username, event_id)
+ def __init__(self, redis_config, username):
+ self._redis = redis.StrictRedis(
+ socket_connect_timeout=2, socket_timeout=2, **redis_config
+ )
+ self._username = username
- def publish_event_data_sync(self, event_id, data_obj):
- return self._redis.publish(self._user_event_key(self._username, event_id), json.dumps(data_obj))
+ @staticmethod
+ def _user_event_key(username, event_id):
+ return "user/%s/events/%s" % (username, event_id)
- def publish_event_data(self, event_id, data_obj):
- """
+ def publish_event_data_sync(self, event_id, data_obj):
+ return self._redis.publish(
+ self._user_event_key(self._username, event_id), json.dumps(data_obj)
+ )
+
+ def publish_event_data(self, event_id, data_obj):
+ """
Publishes the serialized form of the data object for the given event. Note that this occurs
in a thread to prevent blocking.
"""
- def conduct():
- try:
- self.publish_event_data_sync(event_id, data_obj)
- logger.debug('Published user event %s: %s', event_id, data_obj)
- except redis.RedisError:
- logger.exception('Could not publish user event')
- thread = threading.Thread(target=conduct)
- thread.start()
+ def conduct():
+ try:
+ self.publish_event_data_sync(event_id, data_obj)
+ logger.debug("Published user event %s: %s", event_id, data_obj)
+ except redis.RedisError:
+ logger.exception("Could not publish user event")
+
+ thread = threading.Thread(target=conduct)
+ thread.start()
class UserEventListener(object):
- """
+ """
Defines a helper class for subscribing to realtime user events as
backed by Redis.
"""
- def __init__(self, redis_config, username, events=None):
- events = events or set([])
- channels = [self._user_event_key(username, e) for e in events]
- args = dict(redis_config)
- args.update({'socket_connect_timeout': 5,
- 'single_connection_client': True})
+ def __init__(self, redis_config, username, events=None):
+ events = events or set([])
+ channels = [self._user_event_key(username, e) for e in events]
- try:
- self._redis = redis.StrictRedis(**args)
- self._pubsub = self._redis.pubsub(ignore_subscribe_messages=True)
- self._pubsub.subscribe(channels)
- except redis.RedisError as re:
- logger.exception('Could not reach user events redis: %s', re)
- raise CannotReadUserEventsException
+ args = dict(redis_config)
+ args.update({"socket_connect_timeout": 5, "single_connection_client": True})
- @staticmethod
- def _user_event_key(username, event_id):
- return 'user/%s/events/%s' % (username, event_id)
+ try:
+ self._redis = redis.StrictRedis(**args)
+ self._pubsub = self._redis.pubsub(ignore_subscribe_messages=True)
+ self._pubsub.subscribe(channels)
+ except redis.RedisError as re:
+ logger.exception("Could not reach user events redis: %s", re)
+ raise CannotReadUserEventsException
- def event_stream(self):
- """
+ @staticmethod
+ def _user_event_key(username, event_id):
+ return "user/%s/events/%s" % (username, event_id)
+
+ def event_stream(self):
+ """
Starts listening for events on the channel(s), yielding for each event
found. Will yield a "pulse" event (a custom event we've decided) as a heartbeat
every few seconds.
"""
- while True:
- pubsub = self._pubsub
- if pubsub is None:
- raise StopIteration
+ while True:
+ pubsub = self._pubsub
+ if pubsub is None:
+ raise StopIteration
- try:
- item = pubsub.get_message(ignore_subscribe_messages=True, timeout=5)
- except redis.RedisError:
- item = None
+ try:
+ item = pubsub.get_message(ignore_subscribe_messages=True, timeout=5)
+ except redis.RedisError:
+ item = None
- if item is None:
- yield 'pulse', {}
- else:
- channel = item['channel']
- event_id = channel.split('/')[3] # user/{username}/{events}/{id}
- data = None
+ if item is None:
+ yield "pulse", {}
+ else:
+ channel = item["channel"]
+ event_id = channel.split("/")[3] # user/{username}/{events}/{id}
+ data = None
- try:
- data = json.loads(item['data'] or '{}')
- except ValueError:
- continue
+ try:
+ data = json.loads(item["data"] or "{}")
+ except ValueError:
+ continue
- if data:
- yield event_id, data
+ if data:
+ yield event_id, data
- def stop(self):
- """
+ def stop(self):
+ """
Unsubscribes from the channel(s). Should be called once the connection
has terminated.
"""
- if self._pubsub is not None:
- self._pubsub.unsubscribe()
- self._pubsub.close()
- if self._redis is not None:
- self._redis.close()
+ if self._pubsub is not None:
+ self._pubsub.unsubscribe()
+ self._pubsub.close()
+ if self._redis is not None:
+ self._redis.close()
- self._pubsub = None
- self._redis = None
+ self._pubsub = None
+ self._redis = None
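To round out the module above, a hedged sketch of how the publisher and listener sides fit together; the Redis settings and the event name are placeholders, and "devtable" mirrors the username used in the tests elsewhere in this diff.

# Sketch: publishing and consuming realtime user events with the classes above.
# The redis host/port and the "docker-cli" event name are placeholders.
from data.userevent import UserEventBuilder

builder = UserEventBuilder({"host": "localhost", "port": 6379})

# Publisher side: serialized and pushed from a background thread (non-blocking).
builder.get_event("devtable").publish_event_data("docker-cli", {"action": "push"})

# Listener side: blocks, yielding ("pulse", {}) heartbeats roughly every five
# seconds and (event_id, data) tuples whenever something is published.
listener = builder.get_listener("devtable", ["docker-cli"])
for event_id, data in listener.event_stream():
    if event_id != "pulse":
        print(event_id, data)
        break
listener.stop()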
diff --git a/data/userfiles.py b/data/userfiles.py
index 1803c94ef..617e3dd7a 100644
--- a/data/userfiles.py
+++ b/data/userfiles.py
@@ -16,145 +16,183 @@ logger = logging.getLogger(__name__)
class UserfilesHandlers(View):
- methods = ['GET', 'PUT']
+ methods = ["GET", "PUT"]
- def __init__(self, distributed_storage, location, files):
- self._storage = distributed_storage
- self._files = files
- self._locations = {location}
- self._magic = magic.Magic(mime=True)
+ def __init__(self, distributed_storage, location, files):
+ self._storage = distributed_storage
+ self._files = files
+ self._locations = {location}
+ self._magic = magic.Magic(mime=True)
- def get(self, file_id):
- path = self._files.get_file_id_path(file_id)
- try:
- file_stream = self._storage.stream_read_file(self._locations, path)
- buffered = BufferedReader(file_stream)
- file_header_bytes = buffered.peek(1024)
- return send_file(buffered, mimetype=self._magic.from_buffer(file_header_bytes),
- as_attachment=True, attachment_filename=file_id)
- except IOError:
- logger.exception('Error reading user file')
- abort(404)
+ def get(self, file_id):
+ path = self._files.get_file_id_path(file_id)
+ try:
+ file_stream = self._storage.stream_read_file(self._locations, path)
+ buffered = BufferedReader(file_stream)
+ file_header_bytes = buffered.peek(1024)
+ return send_file(
+ buffered,
+ mimetype=self._magic.from_buffer(file_header_bytes),
+ as_attachment=True,
+ attachment_filename=file_id,
+ )
+ except IOError:
+ logger.exception("Error reading user file")
+ abort(404)
- def put(self, file_id):
- input_stream = request.stream
- if request.headers.get('transfer-encoding') == 'chunked':
- # Careful, might work only with WSGI servers supporting chunked
- # encoding (Gunicorn)
- input_stream = request.environ['wsgi.input']
+ def put(self, file_id):
+ input_stream = request.stream
+ if request.headers.get("transfer-encoding") == "chunked":
+ # Careful, might work only with WSGI servers supporting chunked
+ # encoding (Gunicorn)
+ input_stream = request.environ["wsgi.input"]
- c_type = request.headers.get('Content-Type', None)
+ c_type = request.headers.get("Content-Type", None)
- path = self._files.get_file_id_path(file_id)
- self._storage.stream_write(self._locations, path, input_stream, c_type)
+ path = self._files.get_file_id_path(file_id)
+ self._storage.stream_write(self._locations, path, input_stream, c_type)
- return make_response('Okay')
+ return make_response("Okay")
- def dispatch_request(self, file_id):
- if request.method == 'GET':
- return self.get(file_id)
- elif request.method == 'PUT':
- return self.put(file_id)
+ def dispatch_request(self, file_id):
+ if request.method == "GET":
+ return self.get(file_id)
+ elif request.method == "PUT":
+ return self.put(file_id)
class MissingHandlerException(Exception):
- pass
+ pass
class DelegateUserfiles(object):
- def __init__(self, app, distributed_storage, location, path, handler_name=None):
- self._app = app
- self._storage = distributed_storage
- self._locations = {location}
- self._prefix = path
- self._handler_name = handler_name
+ def __init__(self, app, distributed_storage, location, path, handler_name=None):
+ self._app = app
+ self._storage = distributed_storage
+ self._locations = {location}
+ self._prefix = path
+ self._handler_name = handler_name
- def _build_url_adapter(self):
- return self._app.url_map.bind(self._app.config['SERVER_HOSTNAME'],
- script_name=self._app.config['APPLICATION_ROOT'] or '/',
- url_scheme=self._app.config['PREFERRED_URL_SCHEME'])
+ def _build_url_adapter(self):
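+        # Bind the URL map to the configured hostname so url_for() can resolve
+        # routes outside of a request context.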
+ return self._app.url_map.bind(
+ self._app.config["SERVER_HOSTNAME"],
+ script_name=self._app.config["APPLICATION_ROOT"] or "/",
+ url_scheme=self._app.config["PREFERRED_URL_SCHEME"],
+ )
- def get_file_id_path(self, file_id):
- # Note: We use basename here to prevent paths with ..'s and absolute paths.
- return os.path.join(self._prefix or '', os.path.basename(file_id))
+ def get_file_id_path(self, file_id):
+ # Note: We use basename here to prevent paths with ..'s and absolute paths.
+ return os.path.join(self._prefix or "", os.path.basename(file_id))
- def prepare_for_drop(self, mime_type, requires_cors=True):
- """ Returns a signed URL to upload a file to our bucket. """
- logger.debug('Requested upload url with content type: %s' % mime_type)
- file_id = str(uuid4())
- path = self.get_file_id_path(file_id)
- url = self._storage.get_direct_upload_url(self._locations, path, mime_type, requires_cors)
+ def prepare_for_drop(self, mime_type, requires_cors=True):
+ """ Returns a signed URL to upload a file to our bucket. """
+ logger.debug("Requested upload url with content type: %s" % mime_type)
+ file_id = str(uuid4())
+ path = self.get_file_id_path(file_id)
+ url = self._storage.get_direct_upload_url(
+ self._locations, path, mime_type, requires_cors
+ )
- if url is None:
- if self._handler_name is None:
- raise MissingHandlerException()
+ if url is None:
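+            # The storage engine cannot provide a signed upload URL; fall back to
+            # uploading through the registered userfiles handler route.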
+ if self._handler_name is None:
+ raise MissingHandlerException()
- with self._app.app_context() as ctx:
- ctx.url_adapter = self._build_url_adapter()
- file_relative_url = url_for(self._handler_name, file_id=file_id)
- file_url = urlparse.urljoin(get_app_url(self._app.config), file_relative_url)
- return (file_url, file_id)
+ with self._app.app_context() as ctx:
+ ctx.url_adapter = self._build_url_adapter()
+ file_relative_url = url_for(self._handler_name, file_id=file_id)
+ file_url = urlparse.urljoin(
+ get_app_url(self._app.config), file_relative_url
+ )
+ return (file_url, file_id)
- return (url, file_id)
+ return (url, file_id)
- def store_file(self, file_like_obj, content_type, content_encoding=None, file_id=None):
- if file_id is None:
- file_id = str(uuid4())
+ def store_file(
+ self, file_like_obj, content_type, content_encoding=None, file_id=None
+ ):
+ if file_id is None:
+ file_id = str(uuid4())
- path = self.get_file_id_path(file_id)
- self._storage.stream_write(self._locations, path, file_like_obj, content_type,
- content_encoding)
- return file_id
+ path = self.get_file_id_path(file_id)
+ self._storage.stream_write(
+ self._locations, path, file_like_obj, content_type, content_encoding
+ )
+ return file_id
- def get_file_url(self, file_id, remote_ip, expires_in=300, requires_cors=False):
- path = self.get_file_id_path(file_id)
- url = self._storage.get_direct_download_url(self._locations, path, remote_ip, expires_in,
- requires_cors)
+ def get_file_url(self, file_id, remote_ip, expires_in=300, requires_cors=False):
+ path = self.get_file_id_path(file_id)
+ url = self._storage.get_direct_download_url(
+ self._locations, path, remote_ip, expires_in, requires_cors
+ )
- if url is None:
- if self._handler_name is None:
- raise MissingHandlerException()
+ if url is None:
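+            # No direct download URL is available from storage; serve the file
+            # through the registered userfiles handler route instead.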
+ if self._handler_name is None:
+ raise MissingHandlerException()
- with self._app.app_context() as ctx:
- ctx.url_adapter = self._build_url_adapter()
- file_relative_url = url_for(self._handler_name, file_id=file_id)
- return urlparse.urljoin(get_app_url(self._app.config), file_relative_url)
+ with self._app.app_context() as ctx:
+ ctx.url_adapter = self._build_url_adapter()
+ file_relative_url = url_for(self._handler_name, file_id=file_id)
+ return urlparse.urljoin(
+ get_app_url(self._app.config), file_relative_url
+ )
- return url
+ return url
- def get_file_checksum(self, file_id):
- path = self.get_file_id_path(file_id)
- return self._storage.get_checksum(self._locations, path)
+ def get_file_checksum(self, file_id):
+ path = self.get_file_id_path(file_id)
+ return self._storage.get_checksum(self._locations, path)
class Userfiles(object):
- def __init__(self, app=None, distributed_storage=None, path='userfiles',
- handler_name='userfiles_handler'):
- self.app = app
- if app is not None:
- self.state = self.init_app(app, distributed_storage, path=path, handler_name=handler_name)
- else:
- self.state = None
+ def __init__(
+ self,
+ app=None,
+ distributed_storage=None,
+ path="userfiles",
+ handler_name="userfiles_handler",
+ ):
+ self.app = app
+ if app is not None:
+ self.state = self.init_app(
+ app, distributed_storage, path=path, handler_name=handler_name
+ )
+ else:
+ self.state = None
- def init_app(self, app, distributed_storage, path='userfiles', handler_name='userfiles_handler'):
- location = app.config.get('USERFILES_LOCATION')
- userfiles_path = app.config.get('USERFILES_PATH', None)
+ def init_app(
+ self,
+ app,
+ distributed_storage,
+ path="userfiles",
+ handler_name="userfiles_handler",
+ ):
+ location = app.config.get("USERFILES_LOCATION")
+ userfiles_path = app.config.get("USERFILES_PATH", None)
- if userfiles_path is not None:
- userfiles = DelegateUserfiles(app, distributed_storage, location, userfiles_path,
- handler_name=handler_name)
+ if userfiles_path is not None:
+ userfiles = DelegateUserfiles(
+ app,
+ distributed_storage,
+ location,
+ userfiles_path,
+ handler_name=handler_name,
+ )
- app.add_url_rule('/%s/' % path,
- view_func=UserfilesHandlers.as_view(handler_name,
- distributed_storage=distributed_storage,
- location=location,
- files=userfiles))
+ app.add_url_rule(
+            "/%s/" % path,
+ view_func=UserfilesHandlers.as_view(
+ handler_name,
+ distributed_storage=distributed_storage,
+ location=location,
+ files=userfiles,
+ ),
+ )
- # register extension with app
- app.extensions = getattr(app, 'extensions', {})
- app.extensions['userfiles'] = userfiles
+ # register extension with app
+ app.extensions = getattr(app, "extensions", {})
+ app.extensions["userfiles"] = userfiles
- return userfiles
+ return userfiles
- def __getattr__(self, name):
- return getattr(self.state, name, None)
+ def __getattr__(self, name):
+ return getattr(self.state, name, None)
diff --git a/data/users/__init__.py b/data/users/__init__.py
index 78e025028..b9eb748f4 100644
--- a/data/users/__init__.py
+++ b/data/users/__init__.py
@@ -17,165 +17,193 @@ from util.security.secret import convert_secret_key
logger = logging.getLogger(__name__)
+
def get_federated_service_name(authentication_type):
- if authentication_type == 'LDAP':
- return 'ldap'
+ if authentication_type == "LDAP":
+ return "ldap"
- if authentication_type == 'JWT':
- return 'jwtauthn'
+ if authentication_type == "JWT":
+ return "jwtauthn"
- if authentication_type == 'Keystone':
- return 'keystone'
+ if authentication_type == "Keystone":
+ return "keystone"
- if authentication_type == 'AppToken':
- return None
+ if authentication_type == "AppToken":
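+        # AppToken (like Database below) is internal-only and has no federated service.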
+ return None
- if authentication_type == 'Database':
- return None
+ if authentication_type == "Database":
+ return None
- raise Exception('Unknown auth type: %s' % authentication_type)
+ raise Exception("Unknown auth type: %s" % authentication_type)
-LDAP_CERT_FILENAME = 'ldap.crt'
+LDAP_CERT_FILENAME = "ldap.crt"
+
def get_users_handler(config, _, override_config_dir):
- """ Returns a users handler for the authentication configured in the given config object. """
- authentication_type = config.get('AUTHENTICATION_TYPE', 'Database')
+ """ Returns a users handler for the authentication configured in the given config object. """
+ authentication_type = config.get("AUTHENTICATION_TYPE", "Database")
- if authentication_type == 'Database':
- return DatabaseUsers()
+ if authentication_type == "Database":
+ return DatabaseUsers()
- if authentication_type == 'LDAP':
- ldap_uri = config.get('LDAP_URI', 'ldap://localhost')
- base_dn = config.get('LDAP_BASE_DN')
- admin_dn = config.get('LDAP_ADMIN_DN')
- admin_passwd = config.get('LDAP_ADMIN_PASSWD')
- user_rdn = config.get('LDAP_USER_RDN', [])
- uid_attr = config.get('LDAP_UID_ATTR', 'uid')
- email_attr = config.get('LDAP_EMAIL_ATTR', 'mail')
- secondary_user_rdns = config.get('LDAP_SECONDARY_USER_RDNS', [])
- timeout = config.get('LDAP_TIMEOUT')
- network_timeout = config.get('LDAP_NETWORK_TIMEOUT')
+ if authentication_type == "LDAP":
+ ldap_uri = config.get("LDAP_URI", "ldap://localhost")
+ base_dn = config.get("LDAP_BASE_DN")
+ admin_dn = config.get("LDAP_ADMIN_DN")
+ admin_passwd = config.get("LDAP_ADMIN_PASSWD")
+ user_rdn = config.get("LDAP_USER_RDN", [])
+ uid_attr = config.get("LDAP_UID_ATTR", "uid")
+ email_attr = config.get("LDAP_EMAIL_ATTR", "mail")
+ secondary_user_rdns = config.get("LDAP_SECONDARY_USER_RDNS", [])
+ timeout = config.get("LDAP_TIMEOUT")
+ network_timeout = config.get("LDAP_NETWORK_TIMEOUT")
- allow_tls_fallback = config.get('LDAP_ALLOW_INSECURE_FALLBACK', False)
- return LDAPUsers(ldap_uri, base_dn, admin_dn, admin_passwd, user_rdn, uid_attr, email_attr,
- allow_tls_fallback, secondary_user_rdns=secondary_user_rdns,
- requires_email=features.MAILING, timeout=timeout,
- network_timeout=network_timeout)
+ allow_tls_fallback = config.get("LDAP_ALLOW_INSECURE_FALLBACK", False)
+ return LDAPUsers(
+ ldap_uri,
+ base_dn,
+ admin_dn,
+ admin_passwd,
+ user_rdn,
+ uid_attr,
+ email_attr,
+ allow_tls_fallback,
+ secondary_user_rdns=secondary_user_rdns,
+ requires_email=features.MAILING,
+ timeout=timeout,
+ network_timeout=network_timeout,
+ )
- if authentication_type == 'JWT':
- verify_url = config.get('JWT_VERIFY_ENDPOINT')
- issuer = config.get('JWT_AUTH_ISSUER')
- max_fresh_s = config.get('JWT_AUTH_MAX_FRESH_S', 300)
+ if authentication_type == "JWT":
+ verify_url = config.get("JWT_VERIFY_ENDPOINT")
+ issuer = config.get("JWT_AUTH_ISSUER")
+ max_fresh_s = config.get("JWT_AUTH_MAX_FRESH_S", 300)
- query_url = config.get('JWT_QUERY_ENDPOINT', None)
- getuser_url = config.get('JWT_GETUSER_ENDPOINT', None)
+ query_url = config.get("JWT_QUERY_ENDPOINT", None)
+ getuser_url = config.get("JWT_GETUSER_ENDPOINT", None)
- return ExternalJWTAuthN(verify_url, query_url, getuser_url, issuer, override_config_dir,
- config['HTTPCLIENT'], max_fresh_s,
- requires_email=features.MAILING)
+ return ExternalJWTAuthN(
+ verify_url,
+ query_url,
+ getuser_url,
+ issuer,
+ override_config_dir,
+ config["HTTPCLIENT"],
+ max_fresh_s,
+ requires_email=features.MAILING,
+ )
- if authentication_type == 'Keystone':
- auth_url = config.get('KEYSTONE_AUTH_URL')
- auth_version = int(config.get('KEYSTONE_AUTH_VERSION', 2))
- timeout = config.get('KEYSTONE_AUTH_TIMEOUT')
- keystone_admin_username = config.get('KEYSTONE_ADMIN_USERNAME')
- keystone_admin_password = config.get('KEYSTONE_ADMIN_PASSWORD')
- keystone_admin_tenant = config.get('KEYSTONE_ADMIN_TENANT')
- return get_keystone_users(auth_version, auth_url, keystone_admin_username,
- keystone_admin_password, keystone_admin_tenant, timeout,
- requires_email=features.MAILING)
+ if authentication_type == "Keystone":
+ auth_url = config.get("KEYSTONE_AUTH_URL")
+ auth_version = int(config.get("KEYSTONE_AUTH_VERSION", 2))
+ timeout = config.get("KEYSTONE_AUTH_TIMEOUT")
+ keystone_admin_username = config.get("KEYSTONE_ADMIN_USERNAME")
+ keystone_admin_password = config.get("KEYSTONE_ADMIN_PASSWORD")
+ keystone_admin_tenant = config.get("KEYSTONE_ADMIN_TENANT")
+ return get_keystone_users(
+ auth_version,
+ auth_url,
+ keystone_admin_username,
+ keystone_admin_password,
+ keystone_admin_tenant,
+ timeout,
+ requires_email=features.MAILING,
+ )
- if authentication_type == 'AppToken':
- if features.DIRECT_LOGIN:
- raise Exception('Direct login feature must be disabled to use AppToken internal auth')
+ if authentication_type == "AppToken":
+ if features.DIRECT_LOGIN:
+ raise Exception(
+ "Direct login feature must be disabled to use AppToken internal auth"
+ )
- if not features.APP_SPECIFIC_TOKENS:
- raise Exception('AppToken internal auth requires app specific token support to be enabled')
+ if not features.APP_SPECIFIC_TOKENS:
+ raise Exception(
+ "AppToken internal auth requires app specific token support to be enabled"
+ )
- return AppTokenInternalAuth()
+ return AppTokenInternalAuth()
+
+ raise RuntimeError("Unknown authentication type: %s" % authentication_type)
- raise RuntimeError('Unknown authentication type: %s' % authentication_type)
class UserAuthentication(object):
- def __init__(self, app=None, config_provider=None, override_config_dir=None):
- self.secret_key = None
- self.app = app
- if app is not None:
- self.state = self.init_app(app, config_provider, override_config_dir)
- else:
- self.state = None
+ def __init__(self, app=None, config_provider=None, override_config_dir=None):
+ self.secret_key = None
+ self.app = app
+ if app is not None:
+ self.state = self.init_app(app, config_provider, override_config_dir)
+ else:
+ self.state = None
- def init_app(self, app, config_provider, override_config_dir):
- self.secret_key = convert_secret_key(app.config['SECRET_KEY'])
- users = get_users_handler(app.config, config_provider, override_config_dir)
+ def init_app(self, app, config_provider, override_config_dir):
+ self.secret_key = convert_secret_key(app.config["SECRET_KEY"])
+ users = get_users_handler(app.config, config_provider, override_config_dir)
- # register extension with app
- app.extensions = getattr(app, 'extensions', {})
- app.extensions['authentication'] = users
+ # register extension with app
+ app.extensions = getattr(app, "extensions", {})
+ app.extensions["authentication"] = users
- return users
+ return users
- def encrypt_user_password(self, password):
- """ Returns an encrypted version of the user's password. """
- data = {
- 'password': password
- }
+ def encrypt_user_password(self, password):
+ """ Returns an encrypted version of the user's password. """
+ data = {"password": password}
- message = json.dumps(data)
- cipher = AESCipher(self.secret_key)
- return cipher.encrypt(message)
+ message = json.dumps(data)
+ cipher = AESCipher(self.secret_key)
+ return cipher.encrypt(message)
- def _decrypt_user_password(self, encrypted):
- """ Attempts to decrypt the given password and returns it. """
- cipher = AESCipher(self.secret_key)
+ def _decrypt_user_password(self, encrypted):
+ """ Attempts to decrypt the given password and returns it. """
+ cipher = AESCipher(self.secret_key)
- try:
- message = cipher.decrypt(encrypted)
- except ValueError:
- return None
- except TypeError:
- return None
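+        # A decryption or JSON failure means the value was not an encrypted password;
+        # callers treat a None result as a plain password.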
+ try:
+ message = cipher.decrypt(encrypted)
+ except ValueError:
+ return None
+ except TypeError:
+ return None
- try:
- data = json.loads(message)
- except ValueError:
- return None
+ try:
+ data = json.loads(message)
+ except ValueError:
+ return None
- return data.get('password', encrypted)
+ return data.get("password", encrypted)
- def ping(self):
- """ Returns whether the authentication engine is reachable and working. """
- return self.state.ping()
+ def ping(self):
+ """ Returns whether the authentication engine is reachable and working. """
+ return self.state.ping()
- @property
- def federated_service(self):
- """ Returns the name of the federated service for the auth system. If none, should return None.
+ @property
+ def federated_service(self):
+ """ Returns the name of the federated service for the auth system. If none, should return None.
"""
- return self.state.federated_service
+ return self.state.federated_service
- @property
- def requires_distinct_cli_password(self):
- """ Returns whether this auth system requires a distinct CLI password to be created,
+ @property
+ def requires_distinct_cli_password(self):
+ """ Returns whether this auth system requires a distinct CLI password to be created,
in-system, before the CLI can be used. """
- return self.state.requires_distinct_cli_password
+ return self.state.requires_distinct_cli_password
- @property
- def supports_encrypted_credentials(self):
- """ Returns whether this auth system supports using encrypted credentials. """
- return self.state.supports_encrypted_credentials
+ @property
+ def supports_encrypted_credentials(self):
+ """ Returns whether this auth system supports using encrypted credentials. """
+ return self.state.supports_encrypted_credentials
- def has_password_set(self, username):
- """ Returns whether the user has a password set in the auth system. """
- return self.state.has_password_set(username)
+ def has_password_set(self, username):
+ """ Returns whether the user has a password set in the auth system. """
+ return self.state.has_password_set(username)
- @property
- def supports_fresh_login(self):
- """ Returns whether this auth system supports the fresh login check. """
- return self.state.supports_fresh_login
+ @property
+ def supports_fresh_login(self):
+ """ Returns whether this auth system supports the fresh login check. """
+ return self.state.supports_fresh_login
- def query_users(self, query, limit=20):
- """ Performs a lookup against the user system for the specified query. The returned tuple
+ def query_users(self, query, limit=20):
+ """ Performs a lookup against the user system for the specified query. The returned tuple
will be of the form (results, federated_login_id, err_msg). If the method is unsupported,
the results portion of the tuple will be None instead of empty list.
@@ -185,79 +213,91 @@ class UserAuthentication(object):
Results will be in the form of objects's with username and email fields.
"""
- return self.state.query_users(query, limit)
+ return self.state.query_users(query, limit)
- def link_user(self, username_or_email):
- """ Returns a tuple containing the database user record linked to the given username/email
+ def link_user(self, username_or_email):
+ """ Returns a tuple containing the database user record linked to the given username/email
and any error that occurred when trying to link the user.
"""
- return self.state.link_user(username_or_email)
+ return self.state.link_user(username_or_email)
- def get_and_link_federated_user_info(self, user_info, internal_create=False):
- """ Returns a tuple containing the database user record linked to the given UserInformation
+ def get_and_link_federated_user_info(self, user_info, internal_create=False):
+ """ Returns a tuple containing the database user record linked to the given UserInformation
pair and any error that occurred when trying to link the user.
If `internal_create` is True, the caller is an internal user creation process (such
as team syncing), and the "can a user be created" check will be bypassed.
"""
- return self.state.get_and_link_federated_user_info(user_info, internal_create=internal_create)
+ return self.state.get_and_link_federated_user_info(
+ user_info, internal_create=internal_create
+ )
- def confirm_existing_user(self, username, password):
- """ Verifies that the given password matches to the given DB username. Unlike
+ def confirm_existing_user(self, username, password):
+ """ Verifies that the given password matches to the given DB username. Unlike
verify_credentials, this call first translates the DB user via the FederatedLogin table
(where applicable).
"""
- return self.state.confirm_existing_user(username, password)
+ return self.state.confirm_existing_user(username, password)
- def verify_credentials(self, username_or_email, password):
- """ Verifies that the given username and password credentials are valid. """
- return self.state.verify_credentials(username_or_email, password)
+ def verify_credentials(self, username_or_email, password):
+ """ Verifies that the given username and password credentials are valid. """
+ return self.state.verify_credentials(username_or_email, password)
- def check_group_lookup_args(self, group_lookup_args):
- """ Verifies that the given group lookup args point to a valid group. Returns a tuple consisting
+ def check_group_lookup_args(self, group_lookup_args):
+ """ Verifies that the given group lookup args point to a valid group. Returns a tuple consisting
of a boolean status and an error message (if any).
"""
- return self.state.check_group_lookup_args(group_lookup_args)
+ return self.state.check_group_lookup_args(group_lookup_args)
- def service_metadata(self):
- """ Returns a dictionary of extra metadata to present to *superusers* about this auth engine.
+ def service_metadata(self):
+ """ Returns a dictionary of extra metadata to present to *superusers* about this auth engine.
For example, LDAP returns the base DN so we can display to the user during sync setup.
"""
- return self.state.service_metadata()
+ return self.state.service_metadata()
- def iterate_group_members(self, group_lookup_args, page_size=None, disable_pagination=False):
- """ Returns a tuple of an iterator over all the members of the group matching the given lookup
+ def iterate_group_members(
+ self, group_lookup_args, page_size=None, disable_pagination=False
+ ):
+ """ Returns a tuple of an iterator over all the members of the group matching the given lookup
args dictionary, or the error that occurred if the initial call failed or is unsupported.
The format of the lookup args dictionary is specific to the implementation.
Each result in the iterator is a tuple of (UserInformation, error_message), and only
one will be not-None.
"""
- return self.state.iterate_group_members(group_lookup_args, page_size=page_size,
- disable_pagination=disable_pagination)
+ return self.state.iterate_group_members(
+ group_lookup_args,
+ page_size=page_size,
+ disable_pagination=disable_pagination,
+ )
- def verify_and_link_user(self, username_or_email, password, basic_auth=False):
- """ Verifies that the given username and password credentials are valid and, if so,
+ def verify_and_link_user(self, username_or_email, password, basic_auth=False):
+ """ Verifies that the given username and password credentials are valid and, if so,
creates or links the database user to the federated identity. """
- # First try to decode the password as a signed token.
- if basic_auth:
- decrypted = self._decrypt_user_password(password)
- if decrypted is None:
- # This is a normal password.
- if features.REQUIRE_ENCRYPTED_BASIC_AUTH:
- msg = ('Client login with unencrypted passwords is disabled. Please generate an ' +
- 'encrypted password in the user admin panel for use here.')
- return (None, msg)
- else:
- password = decrypted
+ # First try to decode the password as a signed token.
+ if basic_auth:
+ decrypted = self._decrypt_user_password(password)
+ if decrypted is None:
+ # This is a normal password.
+ if features.REQUIRE_ENCRYPTED_BASIC_AUTH:
+ msg = (
+ "Client login with unencrypted passwords is disabled. Please generate an "
+ + "encrypted password in the user admin panel for use here."
+ )
+ return (None, msg)
+ else:
+ password = decrypted
- (result, err_msg) = self.state.verify_and_link_user(username_or_email, password)
- if not result:
- return (result, err_msg)
+ (result, err_msg) = self.state.verify_and_link_user(username_or_email, password)
+ if not result:
+ return (result, err_msg)
- if not result.enabled:
- return (None, 'This user has been disabled. Please contact your administrator.')
+ if not result.enabled:
+ return (
+ None,
+ "This user has been disabled. Please contact your administrator.",
+ )
- return (result, err_msg)
+ return (result, err_msg)
- def __getattr__(self, name):
- return getattr(self.state, name, None)
+ def __getattr__(self, name):
+ return getattr(self.state, name, None)
diff --git a/data/users/apptoken.py b/data/users/apptoken.py
index c306e7064..3aa8b09bb 100644
--- a/data/users/apptoken.py
+++ b/data/users/apptoken.py
@@ -8,60 +8,64 @@ from util.security.jwtutil import InvalidTokenError
logger = logging.getLogger(__name__)
+
class AppTokenInternalAuth(object):
- """ Forces all internal credential login to go through an app token, by disabling all other
+ """ Forces all internal credential login to go through an app token, by disabling all other
access.
"""
- @property
- def supports_fresh_login(self):
- # Since there is no password.
- return False
- @property
- def federated_service(self):
- return None
+ @property
+ def supports_fresh_login(self):
+ # Since there is no password.
+ return False
- @property
- def requires_distinct_cli_password(self):
- # Since there is no supported "password".
- return False
+ @property
+ def federated_service(self):
+ return None
- def has_password_set(self, username):
- # Since there is no supported "password".
- return False
+ @property
+ def requires_distinct_cli_password(self):
+ # Since there is no supported "password".
+ return False
- @property
- def supports_encrypted_credentials(self):
- # Since there is no supported "password".
- return False
+ def has_password_set(self, username):
+ # Since there is no supported "password".
+ return False
- def verify_credentials(self, username_or_email, id_token):
- return (None, 'An application specific token is required to login')
+ @property
+ def supports_encrypted_credentials(self):
+ # Since there is no supported "password".
+ return False
- def verify_and_link_user(self, username_or_email, password):
- return self.verify_credentials(username_or_email, password)
+ def verify_credentials(self, username_or_email, id_token):
+ return (None, "An application specific token is required to login")
- def confirm_existing_user(self, username, password):
- return self.verify_credentials(username, password)
+ def verify_and_link_user(self, username_or_email, password):
+ return self.verify_credentials(username_or_email, password)
- def link_user(self, username_or_email):
- return (None, 'Unsupported for this authentication system')
+ def confirm_existing_user(self, username, password):
+ return self.verify_credentials(username, password)
- def get_and_link_federated_user_info(self, user_info):
- return (None, 'Unsupported for this authentication system')
+ def link_user(self, username_or_email):
+ return (None, "Unsupported for this authentication system")
- def query_users(self, query, limit):
- return (None, '', '')
+ def get_and_link_federated_user_info(self, user_info):
+ return (None, "Unsupported for this authentication system")
- def check_group_lookup_args(self, group_lookup_args):
- return (False, 'Not supported')
+ def query_users(self, query, limit):
+ return (None, "", "")
- def iterate_group_members(self, group_lookup_args, page_size=None, disable_pagination=False):
- return (None, 'Not supported')
+ def check_group_lookup_args(self, group_lookup_args):
+ return (False, "Not supported")
- def service_metadata(self):
- return {}
+ def iterate_group_members(
+ self, group_lookup_args, page_size=None, disable_pagination=False
+ ):
+ return (None, "Not supported")
- def ping(self):
- """ Always assumed to be working. If the DB is broken, other checks will handle it. """
- return (True, None)
+ def service_metadata(self):
+ return {}
+
+ def ping(self):
+ """ Always assumed to be working. If the DB is broken, other checks will handle it. """
+ return (True, None)
diff --git a/data/users/database.py b/data/users/database.py
index 2a1780429..581cdff49 100644
--- a/data/users/database.py
+++ b/data/users/database.py
@@ -1,66 +1,69 @@
from data import model
+
class DatabaseUsers(object):
- @property
- def federated_service(self):
- return None
+ @property
+ def federated_service(self):
+ return None
- @property
- def supports_fresh_login(self):
- return True
+ @property
+ def supports_fresh_login(self):
+ return True
- def ping(self):
- """ Always assumed to be working. If the DB is broken, other checks will handle it. """
- return (True, None)
+ def ping(self):
+ """ Always assumed to be working. If the DB is broken, other checks will handle it. """
+ return (True, None)
- @property
- def supports_encrypted_credentials(self):
- return True
+ @property
+ def supports_encrypted_credentials(self):
+ return True
- def has_password_set(self, username):
- user = model.user.get_user(username)
- return user and user.password_hash is not None
+ def has_password_set(self, username):
+ user = model.user.get_user(username)
+ return user and user.password_hash is not None
- @property
- def requires_distinct_cli_password(self):
- # Since the database stores its own password.
- return True
+ @property
+ def requires_distinct_cli_password(self):
+ # Since the database stores its own password.
+ return True
- def verify_credentials(self, username_or_email, password):
- """ Simply delegate to the model implementation. """
- result = model.user.verify_user(username_or_email, password)
- if not result:
- return (None, 'Invalid Username or Password')
+ def verify_credentials(self, username_or_email, password):
+ """ Simply delegate to the model implementation. """
+ result = model.user.verify_user(username_or_email, password)
+ if not result:
+ return (None, "Invalid Username or Password")
- return (result, None)
+ return (result, None)
- def verify_and_link_user(self, username_or_email, password):
- """ Simply delegate to the model implementation. """
- return self.verify_credentials(username_or_email, password)
+ def verify_and_link_user(self, username_or_email, password):
+ """ Simply delegate to the model implementation. """
+ return self.verify_credentials(username_or_email, password)
- def confirm_existing_user(self, username, password):
- return self.verify_credentials(username, password)
+ def confirm_existing_user(self, username, password):
+ return self.verify_credentials(username, password)
- def link_user(self, username_or_email):
- """ Never used since all users being added are already, by definition, in the database. """
- return (None, 'Unsupported for this authentication system')
+ def link_user(self, username_or_email):
+ """ Never used since all users being added are already, by definition, in the database. """
+ return (None, "Unsupported for this authentication system")
- def get_and_link_federated_user_info(self, user_info, internal_create=False):
- """ Never used since all users being added are already, by definition, in the database. """
- return (None, 'Unsupported for this authentication system')
+ def get_and_link_federated_user_info(self, user_info, internal_create=False):
+ """ Never used since all users being added are already, by definition, in the database. """
+ return (None, "Unsupported for this authentication system")
- def query_users(self, query, limit):
- """ No need to implement, as we already query for users directly in the database. """
- return (None, '', '')
+ def query_users(self, query, limit):
+ """ No need to implement, as we already query for users directly in the database. """
+ return (None, "", "")
- def check_group_lookup_args(self, group_lookup_args):
- """ Never used since all groups, by definition, are in the database. """
- return (False, 'Not supported')
+ def check_group_lookup_args(self, group_lookup_args):
+ """ Never used since all groups, by definition, are in the database. """
+ return (False, "Not supported")
- def iterate_group_members(self, group_lookup_args, page_size=None, disable_pagination=False):
- """ Never used since all groups, by definition, are in the database. """
- return (None, 'Not supported')
+ def iterate_group_members(
+ self, group_lookup_args, page_size=None, disable_pagination=False
+ ):
+ """ Never used since all groups, by definition, are in the database. """
+ return (None, "Not supported")
- def service_metadata(self):
- """ Never used since database has no metadata """
- return {}
+ def service_metadata(self):
+ """ Never used since database has no metadata """
+ return {}
diff --git a/data/users/externaljwt.py b/data/users/externaljwt.py
index 7f2fea255..0475e4581 100644
--- a/data/users/externaljwt.py
+++ b/data/users/externaljwt.py
@@ -10,119 +10,152 @@ logger = logging.getLogger(__name__)
class ExternalJWTAuthN(FederatedUsers):
- """ Delegates authentication to a REST endpoint that returns JWTs. """
- PUBLIC_KEY_FILENAME = 'jwt-authn.cert'
+ """ Delegates authentication to a REST endpoint that returns JWTs. """
- def __init__(self, verify_url, query_url, getuser_url, issuer, override_config_dir, http_client,
- max_fresh_s, public_key_path=None, requires_email=True):
- super(ExternalJWTAuthN, self).__init__('jwtauthn', requires_email)
- self.verify_url = verify_url
- self.query_url = query_url
- self.getuser_url = getuser_url
+ PUBLIC_KEY_FILENAME = "jwt-authn.cert"
- self.issuer = issuer
- self.client = http_client
- self.max_fresh_s = max_fresh_s
- self.requires_email = requires_email
+ def __init__(
+ self,
+ verify_url,
+ query_url,
+ getuser_url,
+ issuer,
+ override_config_dir,
+ http_client,
+ max_fresh_s,
+ public_key_path=None,
+ requires_email=True,
+ ):
+ super(ExternalJWTAuthN, self).__init__("jwtauthn", requires_email)
+ self.verify_url = verify_url
+ self.query_url = query_url
+ self.getuser_url = getuser_url
- default_key_path = os.path.join(override_config_dir, ExternalJWTAuthN.PUBLIC_KEY_FILENAME)
- public_key_path = public_key_path or default_key_path
- if not os.path.exists(public_key_path):
- error_message = ('JWT Authentication public key file "%s" not found' % public_key_path)
+ self.issuer = issuer
+ self.client = http_client
+ self.max_fresh_s = max_fresh_s
+ self.requires_email = requires_email
- raise Exception(error_message)
+ default_key_path = os.path.join(
+ override_config_dir, ExternalJWTAuthN.PUBLIC_KEY_FILENAME
+ )
+ public_key_path = public_key_path or default_key_path
+ if not os.path.exists(public_key_path):
+ error_message = (
+ 'JWT Authentication public key file "%s" not found' % public_key_path
+ )
- self.public_key_path = public_key_path
+ raise Exception(error_message)
- with open(public_key_path) as public_key_file:
- self.public_key = public_key_file.read()
+ self.public_key_path = public_key_path
- def has_password_set(self, username):
- return True
+ with open(public_key_path) as public_key_file:
+ self.public_key = public_key_file.read()
- def ping(self):
- result = self.client.get(self.getuser_url, timeout=2)
- # We expect a 401 or 403 of some kind, since we explicitly don't send an auth header
- if result.status_code // 100 != 4:
- return (False, result.text or 'Could not reach JWT authn endpoint')
+ def has_password_set(self, username):
+ return True
- return (True, None)
+ def ping(self):
+ result = self.client.get(self.getuser_url, timeout=2)
+ # We expect a 401 or 403 of some kind, since we explicitly don't send an auth header
+ if result.status_code // 100 != 4:
+ return (False, result.text or "Could not reach JWT authn endpoint")
- def get_user(self, username_or_email):
- if self.getuser_url is None:
- return (None, 'No endpoint defined for retrieving user')
+ return (True, None)
- (payload, err_msg) = self._execute_call(self.getuser_url, 'quay.io/jwtauthn/getuser',
- params=dict(username=username_or_email))
- if err_msg is not None:
- return (None, err_msg)
+ def get_user(self, username_or_email):
+ if self.getuser_url is None:
+ return (None, "No endpoint defined for retrieving user")
- if not 'sub' in payload:
- raise Exception('Missing sub field in JWT')
+ (payload, err_msg) = self._execute_call(
+ self.getuser_url,
+ "quay.io/jwtauthn/getuser",
+ params=dict(username=username_or_email),
+ )
+ if err_msg is not None:
+ return (None, err_msg)
- if self.requires_email and not 'email' in payload:
- raise Exception('Missing email field in JWT')
+ if not "sub" in payload:
+ raise Exception("Missing sub field in JWT")
- # Parse out the username and email.
- user_info = UserInformation(username=payload['sub'], email=payload.get('email'),
- id=payload['sub'])
- return (user_info, None)
+ if self.requires_email and not "email" in payload:
+ raise Exception("Missing email field in JWT")
+ # Parse out the username and email.
+ user_info = UserInformation(
+ username=payload["sub"], email=payload.get("email"), id=payload["sub"]
+ )
+ return (user_info, None)
- def query_users(self, query, limit=20):
- if self.query_url is None:
- return (None, self.federated_service, 'No endpoint defined for querying users')
+ def query_users(self, query, limit=20):
+ if self.query_url is None:
+ return (
+ None,
+ self.federated_service,
+ "No endpoint defined for querying users",
+ )
- (payload, err_msg) = self._execute_call(self.query_url, 'quay.io/jwtauthn/query',
- params=dict(query=query, limit=limit))
- if err_msg is not None:
- return (None, self.federated_service, err_msg)
+ (payload, err_msg) = self._execute_call(
+ self.query_url,
+ "quay.io/jwtauthn/query",
+ params=dict(query=query, limit=limit),
+ )
+ if err_msg is not None:
+ return (None, self.federated_service, err_msg)
- query_results = []
- for result in payload['results'][0:limit]:
- user_info = UserInformation(username=result['username'], email=result.get('email'),
- id=result['username'])
- query_results.append(user_info)
+ query_results = []
+ for result in payload["results"][0:limit]:
+ user_info = UserInformation(
+ username=result["username"],
+ email=result.get("email"),
+ id=result["username"],
+ )
+ query_results.append(user_info)
- return (query_results, self.federated_service, None)
+ return (query_results, self.federated_service, None)
+ def verify_credentials(self, username_or_email, password):
+ (payload, err_msg) = self._execute_call(
+ self.verify_url, "quay.io/jwtauthn", auth=(username_or_email, password)
+ )
+ if err_msg is not None:
+ return (None, err_msg)
- def verify_credentials(self, username_or_email, password):
- (payload, err_msg) = self._execute_call(self.verify_url, 'quay.io/jwtauthn',
- auth=(username_or_email, password))
- if err_msg is not None:
- return (None, err_msg)
+ if not "sub" in payload:
+ raise Exception("Missing sub field in JWT")
- if not 'sub' in payload:
- raise Exception('Missing sub field in JWT')
+ if self.requires_email and not "email" in payload:
+ raise Exception("Missing email field in JWT")
- if self.requires_email and not 'email' in payload:
- raise Exception('Missing email field in JWT')
+ user_info = UserInformation(
+ username=payload["sub"], email=payload.get("email"), id=payload["sub"]
+ )
+ return (user_info, None)
- user_info = UserInformation(username=payload['sub'], email=payload.get('email'),
- id=payload['sub'])
- return (user_info, None)
+ def _execute_call(self, url, aud, auth=None, params=None):
+ """ Executes a call to the external JWT auth provider. """
+ result = self.client.get(url, timeout=2, auth=auth, params=params)
+ if result.status_code != 200:
+ return (None, result.text or "Could not make JWT auth call")
+ try:
+ result_data = json.loads(result.text)
+ except ValueError:
+            raise Exception("Returned JWT body for url %s does not contain JSON" % url)
- def _execute_call(self, url, aud, auth=None, params=None):
- """ Executes a call to the external JWT auth provider. """
- result = self.client.get(url, timeout=2, auth=auth, params=params)
- if result.status_code != 200:
- return (None, result.text or 'Could not make JWT auth call')
-
- try:
- result_data = json.loads(result.text)
- except ValueError:
- raise Exception('Returned JWT body for url %s does not contain JSON', url)
-
- # Load the JWT returned.
- encoded = result_data.get('token', '')
- exp_limit_options = jwtutil.exp_max_s_option(self.max_fresh_s)
- try:
- payload = jwtutil.decode(encoded, self.public_key, algorithms=['RS256'],
- audience=aud, issuer=self.issuer,
- options=exp_limit_options)
- return (payload, None)
- except jwtutil.InvalidTokenError:
- logger.exception('Exception when decoding returned JWT for url %s', url)
- return (None, 'Exception when decoding returned JWT')
+ # Load the JWT returned.
+ encoded = result_data.get("token", "")
+ exp_limit_options = jwtutil.exp_max_s_option(self.max_fresh_s)
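+        # Verify the signature, audience, issuer and expiry before trusting the payload.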
+ try:
+ payload = jwtutil.decode(
+ encoded,
+ self.public_key,
+ algorithms=["RS256"],
+ audience=aud,
+ issuer=self.issuer,
+ options=exp_limit_options,
+ )
+ return (payload, None)
+ except jwtutil.InvalidTokenError:
+ logger.exception("Exception when decoding returned JWT for url %s", url)
+ return (None, "Exception when decoding returned JWT")
diff --git a/data/users/externalldap.py b/data/users/externalldap.py
index c1242e5d1..0fcfe81cf 100644
--- a/data/users/externalldap.py
+++ b/data/users/externalldap.py
@@ -11,403 +11,519 @@ from util.itertoolrecipes import take
logger = logging.getLogger(__name__)
-_DEFAULT_NETWORK_TIMEOUT = 10.0 # seconds
-_DEFAULT_TIMEOUT = 10.0 # seconds
+_DEFAULT_NETWORK_TIMEOUT = 10.0 # seconds
+_DEFAULT_TIMEOUT = 10.0 # seconds
_DEFAULT_PAGE_SIZE = 1000
class LDAPConnectionBuilder(object):
- def __init__(self, ldap_uri, user_dn, user_pw, allow_tls_fallback=False,
- timeout=None, network_timeout=None):
- self._ldap_uri = ldap_uri
- self._user_dn = user_dn
- self._user_pw = user_pw
- self._allow_tls_fallback = allow_tls_fallback
- self._timeout = timeout
- self._network_timeout = network_timeout
+ def __init__(
+ self,
+ ldap_uri,
+ user_dn,
+ user_pw,
+ allow_tls_fallback=False,
+ timeout=None,
+ network_timeout=None,
+ ):
+ self._ldap_uri = ldap_uri
+ self._user_dn = user_dn
+ self._user_pw = user_pw
+ self._allow_tls_fallback = allow_tls_fallback
+ self._timeout = timeout
+ self._network_timeout = network_timeout
- def get_connection(self):
- return LDAPConnection(self._ldap_uri, self._user_dn, self._user_pw, self._allow_tls_fallback,
- self._timeout, self._network_timeout)
+ def get_connection(self):
+ return LDAPConnection(
+ self._ldap_uri,
+ self._user_dn,
+ self._user_pw,
+ self._allow_tls_fallback,
+ self._timeout,
+ self._network_timeout,
+ )
class LDAPConnection(object):
- def __init__(self, ldap_uri, user_dn, user_pw, allow_tls_fallback=False,
- timeout=None, network_timeout=None):
- self._ldap_uri = ldap_uri
- self._user_dn = user_dn
- self._user_pw = user_pw
- self._allow_tls_fallback = allow_tls_fallback
- self._timeout = timeout
- self._network_timeout = network_timeout
- self._conn = None
+ def __init__(
+ self,
+ ldap_uri,
+ user_dn,
+ user_pw,
+ allow_tls_fallback=False,
+ timeout=None,
+ network_timeout=None,
+ ):
+ self._ldap_uri = ldap_uri
+ self._user_dn = user_dn
+ self._user_pw = user_pw
+ self._allow_tls_fallback = allow_tls_fallback
+ self._timeout = timeout
+ self._network_timeout = network_timeout
+ self._conn = None
- def __enter__(self):
- trace_level = 2 if os.environ.get('USERS_DEBUG') == '1' else 0
+ def __enter__(self):
+ trace_level = 2 if os.environ.get("USERS_DEBUG") == "1" else 0
- self._conn = ldap.initialize(self._ldap_uri, trace_level=trace_level)
- self._conn.set_option(ldap.OPT_REFERRALS, 1)
- self._conn.set_option(ldap.OPT_NETWORK_TIMEOUT,
- self._network_timeout or _DEFAULT_NETWORK_TIMEOUT)
- self._conn.set_option(ldap.OPT_TIMEOUT, self._timeout or _DEFAULT_TIMEOUT)
+ self._conn = ldap.initialize(self._ldap_uri, trace_level=trace_level)
+ self._conn.set_option(ldap.OPT_REFERRALS, 1)
+ self._conn.set_option(
+ ldap.OPT_NETWORK_TIMEOUT, self._network_timeout or _DEFAULT_NETWORK_TIMEOUT
+ )
+ self._conn.set_option(ldap.OPT_TIMEOUT, self._timeout or _DEFAULT_TIMEOUT)
- if self._allow_tls_fallback:
- logger.debug('TLS Fallback enabled in LDAP')
- self._conn.set_option(ldap.OPT_X_TLS_TRY, 1)
+ if self._allow_tls_fallback:
+ logger.debug("TLS Fallback enabled in LDAP")
+ self._conn.set_option(ldap.OPT_X_TLS_TRY, 1)
- self._conn.simple_bind_s(self._user_dn, self._user_pw)
- return self._conn
+ self._conn.simple_bind_s(self._user_dn, self._user_pw)
+ return self._conn
- def __exit__(self, exc_type, value, tb):
- self._conn.unbind_s()
+ def __exit__(self, exc_type, value, tb):
+ self._conn.unbind_s()
class LDAPUsers(FederatedUsers):
- _LDAPResult = namedtuple('LDAPResult', ['dn', 'attrs'])
+ _LDAPResult = namedtuple("LDAPResult", ["dn", "attrs"])
- def __init__(self, ldap_uri, base_dn, admin_dn, admin_passwd, user_rdn, uid_attr, email_attr,
- allow_tls_fallback=False, secondary_user_rdns=None, requires_email=True,
- timeout=None, network_timeout=None, force_no_pagination=False):
- super(LDAPUsers, self).__init__('ldap', requires_email)
+ def __init__(
+ self,
+ ldap_uri,
+ base_dn,
+ admin_dn,
+ admin_passwd,
+ user_rdn,
+ uid_attr,
+ email_attr,
+ allow_tls_fallback=False,
+ secondary_user_rdns=None,
+ requires_email=True,
+ timeout=None,
+ network_timeout=None,
+ force_no_pagination=False,
+ ):
+ super(LDAPUsers, self).__init__("ldap", requires_email)
- self._ldap = LDAPConnectionBuilder(ldap_uri, admin_dn, admin_passwd, allow_tls_fallback,
- timeout, network_timeout)
- self._ldap_uri = ldap_uri
- self._uid_attr = uid_attr
- self._email_attr = email_attr
- self._allow_tls_fallback = allow_tls_fallback
- self._requires_email = requires_email
- self._force_no_pagination = force_no_pagination
+ self._ldap = LDAPConnectionBuilder(
+ ldap_uri,
+ admin_dn,
+ admin_passwd,
+ allow_tls_fallback,
+ timeout,
+ network_timeout,
+ )
+ self._ldap_uri = ldap_uri
+ self._uid_attr = uid_attr
+ self._email_attr = email_attr
+ self._allow_tls_fallback = allow_tls_fallback
+ self._requires_email = requires_email
+ self._force_no_pagination = force_no_pagination
- # Note: user_rdn is a list of RDN pieces (for historical reasons), and secondary_user_rds
- # is a list of RDN strings.
- relative_user_dns = [','.join(user_rdn)] + (secondary_user_rdns or [])
+        # Note: user_rdn is a list of RDN pieces (for historical reasons), and secondary_user_rdns
+ # is a list of RDN strings.
+ relative_user_dns = [",".join(user_rdn)] + (secondary_user_rdns or [])
- def get_full_rdn(relative_dn):
- prefix = relative_dn.split(',') if relative_dn else []
- return ','.join(prefix + base_dn)
+ def get_full_rdn(relative_dn):
+ prefix = relative_dn.split(",") if relative_dn else []
+ return ",".join(prefix + base_dn)
- # Create the set of full DN paths.
- self._user_dns = [get_full_rdn(relative_dn) for relative_dn in relative_user_dns]
- self._base_dn = ','.join(base_dn)
+ # Create the set of full DN paths.
+ self._user_dns = [
+ get_full_rdn(relative_dn) for relative_dn in relative_user_dns
+ ]
+ self._base_dn = ",".join(base_dn)
- def _get_ldap_referral_dn(self, referral_exception):
- logger.debug('Got referral: %s', referral_exception.args[0])
- if not referral_exception.args[0] or not referral_exception.args[0].get('info'):
- logger.debug('LDAP referral missing info block')
- return None
+ def _get_ldap_referral_dn(self, referral_exception):
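+        # python-ldap reports referrals via an exception whose info block has the form
+        # "Referral:\nldap:///<dn>"; extract the DN portion if present.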
+ logger.debug("Got referral: %s", referral_exception.args[0])
+ if not referral_exception.args[0] or not referral_exception.args[0].get("info"):
+ logger.debug("LDAP referral missing info block")
+ return None
- referral_info = referral_exception.args[0]['info']
- if not referral_info.startswith('Referral:\n'):
- logger.debug('LDAP referral missing Referral header')
- return None
+ referral_info = referral_exception.args[0]["info"]
+ if not referral_info.startswith("Referral:\n"):
+ logger.debug("LDAP referral missing Referral header")
+ return None
- referral_uri = referral_info[len('Referral:\n'):]
- if not referral_uri.startswith('ldap:///'):
- logger.debug('LDAP referral URI does not start with ldap:///')
- return None
+ referral_uri = referral_info[len("Referral:\n") :]
+ if not referral_uri.startswith("ldap:///"):
+ logger.debug("LDAP referral URI does not start with ldap:///")
+ return None
- referral_dn = referral_uri[len('ldap:///'):]
- return referral_dn
+ referral_dn = referral_uri[len("ldap:///") :]
+ return referral_dn
- def _ldap_user_search_with_rdn(self, conn, username_or_email, user_search_dn, suffix=''):
- query = u'(|({0}={2}{3})({1}={2}{3}))'.format(self._uid_attr, self._email_attr,
- escape_filter_chars(username_or_email),
- suffix)
- logger.debug('Conducting user search: %s under %s', query, user_search_dn)
- try:
- return (conn.search_s(user_search_dn, ldap.SCOPE_SUBTREE, query.encode('utf-8')), None)
- except ldap.REFERRAL as re:
- referral_dn = self._get_ldap_referral_dn(re)
- if not referral_dn:
- return (None, 'Failed to follow referral when looking up username')
-
- try:
- subquery = u'(%s=%s)' % (self._uid_attr, username_or_email)
- return (conn.search_s(referral_dn, ldap.SCOPE_BASE, subquery), None)
- except ldap.LDAPError:
- logger.debug('LDAP referral search exception')
- return (None, 'Username not found')
-
- except ldap.LDAPError:
- logger.debug('LDAP search exception')
- return (None, 'Username not found')
-
- def _ldap_user_search(self, username_or_email, limit=20, suffix=''):
- if not username_or_email:
- return (None, 'Empty username/email')
-
- # Verify the admin connection works first. We do this here to avoid wrapping
- # the entire block in the INVALID CREDENTIALS check.
- try:
- with self._ldap.get_connection():
- pass
- except ldap.INVALID_CREDENTIALS:
- return (None, 'LDAP Admin dn or password is invalid')
-
- with self._ldap.get_connection() as conn:
- logger.debug('Incoming username or email param: %s', username_or_email.__repr__())
-
- for user_search_dn in self._user_dns:
- (pairs, err_msg) = self._ldap_user_search_with_rdn(conn, username_or_email, user_search_dn,
- suffix=suffix)
- if pairs is not None and len(pairs) > 0:
- break
-
- if err_msg is not None:
- return (None, err_msg)
-
- logger.debug('Found matching pairs: %s', pairs)
- results = [LDAPUsers._LDAPResult(*pair) for pair in take(limit, pairs)]
-
- # Filter out pairs without DNs. Some LDAP impls will return such pairs.
- with_dns = [result for result in results if result.dn]
- return (with_dns, None)
-
- def _ldap_single_user_search(self, username_or_email):
- with_dns, err_msg = self._ldap_user_search(username_or_email)
- if err_msg is not None:
- return (None, err_msg)
-
- # Make sure we have at least one result.
- if len(with_dns) < 1:
- return (None, 'Username not found')
-
- # If we have found a single pair, then return it.
- if len(with_dns) == 1:
- return (with_dns[0], None)
-
- # Otherwise, there are multiple pairs with DNs, so find the one with the mail
- # attribute (if any).
- with_mail = [result for result in with_dns if result.attrs.get(self._email_attr)]
- return (with_mail[0] if with_mail else with_dns[0], None)
-
- def _build_user_information(self, response):
- if not response.get(self._uid_attr):
- return (None, 'Missing uid field "%s" in user record' % self._uid_attr)
-
- if self._requires_email and not response.get(self._email_attr):
- return (None, 'Missing mail field "%s" in user record' % self._email_attr)
-
- username = response[self._uid_attr][0].decode('utf-8')
- email = response.get(self._email_attr, [None])[0]
- return (UserInformation(username=username, email=email, id=username), None)
-
- def ping(self):
- try:
- with self._ldap.get_connection():
- pass
- except ldap.INVALID_CREDENTIALS:
- return (False, 'LDAP Admin dn or password is invalid')
- except ldap.LDAPError as lde:
- logger.exception('Exception when trying to health check LDAP')
- return (False, lde.message)
-
- return (True, None)
-
- def at_least_one_user_exists(self):
- logger.debug('Checking if any users exist in LDAP')
- try:
- with self._ldap.get_connection():
- pass
- except ldap.INVALID_CREDENTIALS:
- return (None, 'LDAP Admin dn or password is invalid')
-
- has_pagination = not self._force_no_pagination
- with self._ldap.get_connection() as conn:
- for user_search_dn in self._user_dns:
- lc = ldap.controls.libldap.SimplePagedResultsControl(criticality=True, size=1, cookie='')
+ def _ldap_user_search_with_rdn(
+ self, conn, username_or_email, user_search_dn, suffix=""
+ ):
+ query = u"(|({0}={2}{3})({1}={2}{3}))".format(
+ self._uid_attr,
+ self._email_attr,
+ escape_filter_chars(username_or_email),
+ suffix,
+ )
+ logger.debug("Conducting user search: %s under %s", query, user_search_dn)
try:
- if has_pagination:
- msgid = conn.search_ext(user_search_dn, ldap.SCOPE_SUBTREE, serverctrls=[lc])
- _, rdata, _, serverctrls = conn.result3(msgid)
- else:
- msgid = conn.search(user_search_dn, ldap.SCOPE_SUBTREE)
- _, rdata = conn.result(msgid)
+ return (
+ conn.search_s(
+ user_search_dn, ldap.SCOPE_SUBTREE, query.encode("utf-8")
+ ),
+ None,
+ )
+ except ldap.REFERRAL as re:
+ referral_dn = self._get_ldap_referral_dn(re)
+ if not referral_dn:
+ return (None, "Failed to follow referral when looking up username")
- for entry in rdata: # Handles both lists and iterators.
- return (True, None)
+ try:
+ subquery = u"(%s=%s)" % (self._uid_attr, username_or_email)
+ return (conn.search_s(referral_dn, ldap.SCOPE_BASE, subquery), None)
+ except ldap.LDAPError:
+ logger.debug("LDAP referral search exception")
+ return (None, "Username not found")
- except ldap.LDAPError as lde:
- return (False, str(lde) or 'Could not find DN %s' % user_search_dn)
+ except ldap.LDAPError:
+ logger.debug("LDAP search exception")
+ return (None, "Username not found")
- return (False, None)
+ def _ldap_user_search(self, username_or_email, limit=20, suffix=""):
+ if not username_or_email:
+ return (None, "Empty username/email")
- def get_user(self, username_or_email):
- """ Looks up a username or email in LDAP. """
- logger.debug('Looking up LDAP username or email %s', username_or_email)
- (found_user, err_msg) = self._ldap_single_user_search(username_or_email)
- if err_msg is not None:
- return (None, err_msg)
-
- logger.debug('Found user for LDAP username or email %s', username_or_email)
- _, found_response = found_user
- return self._build_user_information(found_response)
-
- def query_users(self, query, limit=20):
- """ Queries LDAP for matching users. """
- if not query:
- return (None, self.federated_service, 'Empty query')
-
- logger.debug('Got query %s with limit %s', query, limit)
- (results, err_msg) = self._ldap_user_search(query, limit=limit, suffix='*')
- if err_msg is not None:
- return (None, self.federated_service, err_msg)
-
- final_results = []
- for result in results[0:limit]:
- credentials, err_msg = self._build_user_information(result.attrs)
- if err_msg is not None:
- continue
-
- final_results.append(credentials)
-
- logger.debug('For query %s found results %s', query, final_results)
- return (final_results, self.federated_service, None)
-
- def verify_credentials(self, username_or_email, password):
- """ Verify the credentials with LDAP. """
- # Make sure that even if the server supports anonymous binds, we don't allow it
- if not password:
- return (None, 'Anonymous binding not allowed')
-
- (found_user, err_msg) = self._ldap_single_user_search(username_or_email)
- if found_user is None:
- return (None, err_msg)
-
- found_dn, found_response = found_user
- logger.debug('Found user for LDAP username %s; validating password', username_or_email)
- logger.debug('DN %s found: %s', found_dn, found_response)
-
- # First validate the password by binding as the user
- try:
- with LDAPConnection(self._ldap_uri, found_dn, password.encode('utf-8'),
- self._allow_tls_fallback):
- pass
- except ldap.REFERRAL as re:
- referral_dn = self._get_ldap_referral_dn(re)
- if not referral_dn:
- return (None, 'Invalid username')
-
- try:
- with LDAPConnection(self._ldap_uri, referral_dn, password.encode('utf-8'),
- self._allow_tls_fallback):
- pass
- except ldap.INVALID_CREDENTIALS:
- logger.debug('Invalid LDAP credentials')
- return (None, 'Invalid password')
-
- except ldap.INVALID_CREDENTIALS:
- logger.debug('Invalid LDAP credentials')
- return (None, 'Invalid password')
-
- return self._build_user_information(found_response)
-
- def service_metadata(self):
- return {
- 'base_dn': self._base_dn,
- }
-
- def check_group_lookup_args(self, group_lookup_args, disable_pagination=False):
- if not group_lookup_args.get('group_dn'):
- return (False, 'Missing group_dn')
-
- (it, err) = self.iterate_group_members(group_lookup_args, page_size=1,
- disable_pagination=disable_pagination)
- if err is not None:
- return (False, err)
-
- if not next(it, False):
- return (False, 'Group does not exist or is empty')
-
- return (True, None)
-
- def iterate_group_members(self, group_lookup_args, page_size=None, disable_pagination=False):
- try:
- with self._ldap.get_connection():
- pass
- except ldap.INVALID_CREDENTIALS:
- return (None, 'LDAP Admin dn or password is invalid')
-
- group_dn = group_lookup_args['group_dn']
- page_size = page_size or _DEFAULT_PAGE_SIZE
- return (self._iterate_members(group_dn, page_size, disable_pagination), None)
-
- def _iterate_members(self, group_dn, page_size, disable_pagination):
- has_pagination = not(self._force_no_pagination or disable_pagination)
- with self._ldap.get_connection() as conn:
- search_flt = filter_format('(memberOf=%s,%s)', (group_dn, self._base_dn))
- attributes = [self._uid_attr, self._email_attr]
-
- for user_search_dn in self._user_dns:
- lc = ldap.controls.libldap.SimplePagedResultsControl(criticality=True, size=page_size,
- cookie='')
-
- # Conduct the initial search for users that are a member of the group.
- logger.debug('Conducting LDAP search of DN: %s and filter %s', user_search_dn, search_flt)
+ # Verify the admin connection works first. We do this here to avoid wrapping
+ # the entire block in the INVALID CREDENTIALS check.
try:
- if has_pagination:
- msgid = conn.search_ext(user_search_dn, ldap.SCOPE_SUBTREE, search_flt,
- serverctrls=[lc], attrlist=attributes)
- else:
- msgid = conn.search(user_search_dn, ldap.SCOPE_SUBTREE, search_flt, attrlist=attributes)
+ with self._ldap.get_connection():
+ pass
+ except ldap.INVALID_CREDENTIALS:
+ return (None, "LDAP Admin dn or password is invalid")
+
+ with self._ldap.get_connection() as conn:
+ logger.debug(
+ "Incoming username or email param: %s", username_or_email.__repr__()
+ )
+
+ for user_search_dn in self._user_dns:
+ (pairs, err_msg) = self._ldap_user_search_with_rdn(
+ conn, username_or_email, user_search_dn, suffix=suffix
+ )
+ if pairs is not None and len(pairs) > 0:
+ break
+
+ if err_msg is not None:
+ return (None, err_msg)
+
+ logger.debug("Found matching pairs: %s", pairs)
+ results = [LDAPUsers._LDAPResult(*pair) for pair in take(limit, pairs)]
+
+ # Filter out pairs without DNs. Some LDAP impls will return such pairs.
+ with_dns = [result for result in results if result.dn]
+ return (with_dns, None)
+
+ def _ldap_single_user_search(self, username_or_email):
+ with_dns, err_msg = self._ldap_user_search(username_or_email)
+ if err_msg is not None:
+ return (None, err_msg)
+
+ # Make sure we have at least one result.
+ if len(with_dns) < 1:
+ return (None, "Username not found")
+
+ # If we have found a single pair, then return it.
+ if len(with_dns) == 1:
+ return (with_dns[0], None)
+
+ # Otherwise, there are multiple pairs with DNs, so find the one with the mail
+ # attribute (if any).
+ with_mail = [
+ result for result in with_dns if result.attrs.get(self._email_attr)
+ ]
+ return (with_mail[0] if with_mail else with_dns[0], None)
+
+ def _build_user_information(self, response):
+ if not response.get(self._uid_attr):
+ return (None, 'Missing uid field "%s" in user record' % self._uid_attr)
+
+ if self._requires_email and not response.get(self._email_attr):
+ return (None, 'Missing mail field "%s" in user record' % self._email_attr)
+
+ username = response[self._uid_attr][0].decode("utf-8")
+ email = response.get(self._email_attr, [None])[0]
+ return (UserInformation(username=username, email=email, id=username), None)
+
+ def ping(self):
+ try:
+ with self._ldap.get_connection():
+ pass
+ except ldap.INVALID_CREDENTIALS:
+ return (False, "LDAP Admin dn or password is invalid")
except ldap.LDAPError as lde:
- logger.exception('Got error when trying to search %s with filter %s: %s',
- user_search_dn, search_flt, lde.message)
- break
+ logger.exception("Exception when trying to health check LDAP")
+ return (False, lde.message)
- while True:
- try:
- if has_pagination:
- _, rdata, _, serverctrls = conn.result3(msgid)
- else:
- _, rdata = conn.result(msgid)
+ return (True, None)
- # Yield any users found.
- found_results = 0
- for userdata in rdata:
- found_results = found_results + 1
- yield self._build_user_information(userdata[1])
+ def at_least_one_user_exists(self):
+ logger.debug("Checking if any users exist in LDAP")
+ try:
+ with self._ldap.get_connection():
+ pass
+ except ldap.INVALID_CREDENTIALS:
+ return (None, "LDAP Admin dn or password is invalid")
- logger.debug('Found %s users in group %s; %s', found_results, user_search_dn,
- search_flt)
- except ldap.NO_SUCH_OBJECT as nsoe:
- logger.debug('NSO when trying to lookup results of search %s with filter %s: %s',
- user_search_dn, search_flt, nsoe.message)
- except ldap.LDAPError as lde:
- logger.exception('Error when trying to lookup results of search %s with filter %s: %s',
- user_search_dn, search_flt, lde.message)
- break
+ has_pagination = not self._force_no_pagination
+ with self._ldap.get_connection() as conn:
+ for user_search_dn in self._user_dns:
+ lc = ldap.controls.libldap.SimplePagedResultsControl(
+ criticality=True, size=1, cookie=""
+ )
+ try:
+ if has_pagination:
+ msgid = conn.search_ext(
+ user_search_dn, ldap.SCOPE_SUBTREE, serverctrls=[lc]
+ )
+ _, rdata, _, serverctrls = conn.result3(msgid)
+ else:
+ msgid = conn.search(user_search_dn, ldap.SCOPE_SUBTREE)
+ _, rdata = conn.result(msgid)
- # If no additional results, nothing more to do.
- if not found_results:
- break
+ for entry in rdata: # Handles both lists and iterators.
+ return (True, None)
- # If pagination is disabled, nothing more to do.
- if not has_pagination:
- logger.debug('Pagination is disabled, no further queries')
- break
+ except ldap.LDAPError as lde:
+ return (False, str(lde) or "Could not find DN %s" % user_search_dn)
- # Filter down the controls with which the server responded, looking for the paging
- # control type. If not found, then the server does not support pagination and we already
- # got all of the results.
- pctrls = [control for control in serverctrls
- if control.controlType == ldap.controls.SimplePagedResultsControl.controlType]
+ return (False, None)
- if pctrls:
- # Server supports pagination. Update the cookie so the next search finds the next page,
- # then conduct the next search.
- cookie = lc.cookie = pctrls[0].cookie
- if cookie:
- logger.debug('Pagination is supported for this LDAP server; trying next page')
- msgid = conn.search_ext(user_search_dn, ldap.SCOPE_SUBTREE, search_flt,
- serverctrls=[lc], attrlist=attributes)
- continue
- else:
- # No additional results.
- logger.debug('Pagination is supported for this LDAP server but on last page')
- break
- else:
- # Pagination is not supported.
- logger.debug('Pagination is not supported for this LDAP server')
- break
+ def get_user(self, username_or_email):
+ """ Looks up a username or email in LDAP. """
+ logger.debug("Looking up LDAP username or email %s", username_or_email)
+ (found_user, err_msg) = self._ldap_single_user_search(username_or_email)
+ if err_msg is not None:
+ return (None, err_msg)
+
+ logger.debug("Found user for LDAP username or email %s", username_or_email)
+ _, found_response = found_user
+ return self._build_user_information(found_response)
+
+ def query_users(self, query, limit=20):
+ """ Queries LDAP for matching users. """
+ if not query:
+ return (None, self.federated_service, "Empty query")
+
+ logger.debug("Got query %s with limit %s", query, limit)
+ (results, err_msg) = self._ldap_user_search(query, limit=limit, suffix="*")
+ if err_msg is not None:
+ return (None, self.federated_service, err_msg)
+
+ final_results = []
+ for result in results[0:limit]:
+ credentials, err_msg = self._build_user_information(result.attrs)
+ if err_msg is not None:
+ continue
+
+ final_results.append(credentials)
+
+ logger.debug("For query %s found results %s", query, final_results)
+ return (final_results, self.federated_service, None)
+
+ def verify_credentials(self, username_or_email, password):
+ """ Verify the credentials with LDAP. """
+ # Make sure that even if the server supports anonymous binds, we don't allow it
+ if not password:
+ return (None, "Anonymous binding not allowed")
+
+ (found_user, err_msg) = self._ldap_single_user_search(username_or_email)
+ if found_user is None:
+ return (None, err_msg)
+
+ found_dn, found_response = found_user
+ logger.debug(
+ "Found user for LDAP username %s; validating password", username_or_email
+ )
+ logger.debug("DN %s found: %s", found_dn, found_response)
+
+ # First validate the password by binding as the user
+ try:
+ with LDAPConnection(
+ self._ldap_uri,
+ found_dn,
+ password.encode("utf-8"),
+ self._allow_tls_fallback,
+ ):
+ pass
+ except ldap.REFERRAL as re:
+ referral_dn = self._get_ldap_referral_dn(re)
+ if not referral_dn:
+ return (None, "Invalid username")
+
+ try:
+ with LDAPConnection(
+ self._ldap_uri,
+ referral_dn,
+ password.encode("utf-8"),
+ self._allow_tls_fallback,
+ ):
+ pass
+ except ldap.INVALID_CREDENTIALS:
+ logger.debug("Invalid LDAP credentials")
+ return (None, "Invalid password")
+
+ except ldap.INVALID_CREDENTIALS:
+ logger.debug("Invalid LDAP credentials")
+ return (None, "Invalid password")
+
+ return self._build_user_information(found_response)
+
+ def service_metadata(self):
+ return {"base_dn": self._base_dn}
+
+ def check_group_lookup_args(self, group_lookup_args, disable_pagination=False):
+ if not group_lookup_args.get("group_dn"):
+ return (False, "Missing group_dn")
+
+ (it, err) = self.iterate_group_members(
+ group_lookup_args, page_size=1, disable_pagination=disable_pagination
+ )
+ if err is not None:
+ return (False, err)
+
+ if not next(it, False):
+ return (False, "Group does not exist or is empty")
+
+ return (True, None)
+
+ def iterate_group_members(
+ self, group_lookup_args, page_size=None, disable_pagination=False
+ ):
+ try:
+ with self._ldap.get_connection():
+ pass
+ except ldap.INVALID_CREDENTIALS:
+ return (None, "LDAP Admin dn or password is invalid")
+
+ group_dn = group_lookup_args["group_dn"]
+ page_size = page_size or _DEFAULT_PAGE_SIZE
+ return (self._iterate_members(group_dn, page_size, disable_pagination), None)
+
+ def _iterate_members(self, group_dn, page_size, disable_pagination):
+ has_pagination = not (self._force_no_pagination or disable_pagination)
+ with self._ldap.get_connection() as conn:
+ search_flt = filter_format("(memberOf=%s,%s)", (group_dn, self._base_dn))
+ attributes = [self._uid_attr, self._email_attr]
+
+ for user_search_dn in self._user_dns:
+ lc = ldap.controls.libldap.SimplePagedResultsControl(
+ criticality=True, size=page_size, cookie=""
+ )
+
+ # Conduct the initial search for users that are a member of the group.
+ logger.debug(
+ "Conducting LDAP search of DN: %s and filter %s",
+ user_search_dn,
+ search_flt,
+ )
+ try:
+ if has_pagination:
+ msgid = conn.search_ext(
+ user_search_dn,
+ ldap.SCOPE_SUBTREE,
+ search_flt,
+ serverctrls=[lc],
+ attrlist=attributes,
+ )
+ else:
+ msgid = conn.search(
+ user_search_dn,
+ ldap.SCOPE_SUBTREE,
+ search_flt,
+ attrlist=attributes,
+ )
+ except ldap.LDAPError as lde:
+ logger.exception(
+ "Got error when trying to search %s with filter %s: %s",
+ user_search_dn,
+ search_flt,
+ lde.message,
+ )
+ break
+
+ while True:
+ try:
+ if has_pagination:
+ _, rdata, _, serverctrls = conn.result3(msgid)
+ else:
+ _, rdata = conn.result(msgid)
+
+ # Yield any users found.
+ found_results = 0
+ for userdata in rdata:
+ found_results = found_results + 1
+ yield self._build_user_information(userdata[1])
+
+ logger.debug(
+ "Found %s users in group %s; %s",
+ found_results,
+ user_search_dn,
+ search_flt,
+ )
+ except ldap.NO_SUCH_OBJECT as nsoe:
+ logger.debug(
+ "NSO when trying to lookup results of search %s with filter %s: %s",
+ user_search_dn,
+ search_flt,
+ nsoe.message,
+ )
+ except ldap.LDAPError as lde:
+ logger.exception(
+ "Error when trying to lookup results of search %s with filter %s: %s",
+ user_search_dn,
+ search_flt,
+ lde.message,
+ )
+ break
+
+ # If no additional results, nothing more to do.
+ if not found_results:
+ break
+
+ # If pagination is disabled, nothing more to do.
+ if not has_pagination:
+ logger.debug("Pagination is disabled, no further queries")
+ break
+
+ # Filter down the controls with which the server responded, looking for the paging
+ # control type. If not found, then the server does not support pagination and we already
+ # got all of the results.
+ pctrls = [
+ control
+ for control in serverctrls
+ if control.controlType
+ == ldap.controls.SimplePagedResultsControl.controlType
+ ]
+
+ if pctrls:
+ # Server supports pagination. Update the cookie so the next search finds the next page,
+ # then conduct the next search.
+ cookie = lc.cookie = pctrls[0].cookie
+ if cookie:
+ logger.debug(
+ "Pagination is supported for this LDAP server; trying next page"
+ )
+ msgid = conn.search_ext(
+ user_search_dn,
+ ldap.SCOPE_SUBTREE,
+ search_flt,
+ serverctrls=[lc],
+ attrlist=attributes,
+ )
+ continue
+ else:
+ # No additional results.
+ logger.debug(
+ "Pagination is supported for this LDAP server but on last page"
+ )
+ break
+ else:
+ # Pagination is not supported.
+ logger.debug("Pagination is not supported for this LDAP server")
+ break
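
The paged group-member iteration reformatted above follows python-ldap's standard simple-paged-results pattern. Below is a minimal, self-contained sketch of that loop using python-ldap directly; the URI, bind DN, password, base DN and filter are placeholders and not values from this repository.

import ldap
import ldap.controls.libldap

conn = ldap.initialize("ldap://ldap.example.com")
conn.simple_bind_s("cn=admin,dc=example,dc=com", "password")

page_size = 10
lc = ldap.controls.libldap.SimplePagedResultsControl(
    criticality=True, size=page_size, cookie=""
)

while True:
    # Re-issue the search each pass, carrying the server-supplied cookie in lc.
    msgid = conn.search_ext(
        "ou=users,dc=example,dc=com",
        ldap.SCOPE_SUBTREE,
        "(objectClass=person)",
        attrlist=["uid", "mail"],
        serverctrls=[lc],
    )
    _, rdata, _, serverctrls = conn.result3(msgid)
    for dn, attrs in rdata:
        print(dn, attrs.get("uid"))

    # The server echoes back a paging control whose cookie points at the next
    # page; an empty cookie (or no control at all) means this was the last page.
    pctrls = [
        c
        for c in serverctrls
        if c.controlType == ldap.controls.SimplePagedResultsControl.controlType
    ]
    if not pctrls or not pctrls[0].cookie:
        break
    lc.cookie = pctrls[0].cookie
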
diff --git a/data/users/federated.py b/data/users/federated.py
index 87210bccd..795875cfa 100644
--- a/data/users/federated.py
+++ b/data/users/federated.py
@@ -9,146 +9,167 @@ from util.validation import generate_valid_usernames
logger = logging.getLogger(__name__)
-UserInformation = namedtuple('UserInformation', ['username', 'email', 'id'])
+UserInformation = namedtuple("UserInformation", ["username", "email", "id"])
+
+DISABLED_MESSAGE = (
+ "User creation is disabled. Please contact your administrator to gain access."
+)
-DISABLED_MESSAGE = 'User creation is disabled. Please contact your administrator to gain access.'
class FederatedUsers(object):
- """ Base class for all federated users systems. """
+ """ Base class for all federated users systems. """
- def __init__(self, federated_service, requires_email):
- self._federated_service = federated_service
- self._requires_email = requires_email
+ def __init__(self, federated_service, requires_email):
+ self._federated_service = federated_service
+ self._requires_email = requires_email
- @property
- def federated_service(self):
- return self._federated_service
+ @property
+ def federated_service(self):
+ return self._federated_service
- @property
- def supports_fresh_login(self):
- return True
+ @property
+ def supports_fresh_login(self):
+ return True
- @property
- def supports_encrypted_credentials(self):
- return True
+ @property
+ def supports_encrypted_credentials(self):
+ return True
- def has_password_set(self, username):
- return True
+ def has_password_set(self, username):
+ return True
- @property
- def requires_distinct_cli_password(self):
- # Since the federated auth provides a password which works on the CLI.
- return False
+ @property
+ def requires_distinct_cli_password(self):
+ # Since the federated auth provides a password which works on the CLI.
+ return False
- def get_user(self, username_or_email):
- """ Retrieves the user with the given username or email, returning a tuple containing
+ def get_user(self, username_or_email):
+ """ Retrieves the user with the given username or email, returning a tuple containing
a UserInformation (if success) and the error message (on failure).
"""
- raise NotImplementedError
+ raise NotImplementedError
- def verify_credentials(self, username_or_email, password):
- """ Verifies the given credentials against the backing federated service, returning
+ def verify_credentials(self, username_or_email, password):
+ """ Verifies the given credentials against the backing federated service, returning
a tuple containing a UserInformation (on success) and the error message (on failure).
"""
- raise NotImplementedError
+ raise NotImplementedError
- def query_users(self, query, limit=20):
- """ If implemented, get_user must be implemented as well. """
- return (None, 'Not supported')
+ def query_users(self, query, limit=20):
+ """ If implemented, get_user must be implemented as well. """
+ return (None, "Not supported")
- def link_user(self, username_or_email):
- (user_info, err_msg) = self.get_user(username_or_email)
- if user_info is None:
- return (None, err_msg)
+ def link_user(self, username_or_email):
+ (user_info, err_msg) = self.get_user(username_or_email)
+ if user_info is None:
+ return (None, err_msg)
- return self.get_and_link_federated_user_info(user_info)
+ return self.get_and_link_federated_user_info(user_info)
- def get_and_link_federated_user_info(self, user_info, internal_create=False):
- return self._get_and_link_federated_user_info(user_info.username, user_info.email,
- internal_create=internal_create)
+ def get_and_link_federated_user_info(self, user_info, internal_create=False):
+ return self._get_and_link_federated_user_info(
+ user_info.username, user_info.email, internal_create=internal_create
+ )
- def verify_and_link_user(self, username_or_email, password):
- """ Verifies the given credentials and, if valid, creates/links a database user to the
+ def verify_and_link_user(self, username_or_email, password):
+ """ Verifies the given credentials and, if valid, creates/links a database user to the
associated federated service.
"""
- (credentials, err_msg) = self.verify_credentials(username_or_email, password)
- if credentials is None:
- return (None, err_msg)
+ (credentials, err_msg) = self.verify_credentials(username_or_email, password)
+ if credentials is None:
+ return (None, err_msg)
- return self._get_and_link_federated_user_info(credentials.username, credentials.email)
+ return self._get_and_link_federated_user_info(
+ credentials.username, credentials.email
+ )
- def confirm_existing_user(self, username, password):
- """ Confirms that the given *database* username and service password are valid for the linked
+ def confirm_existing_user(self, username, password):
+ """ Confirms that the given *database* username and service password are valid for the linked
service. This method is used when the federated service's username is not known.
"""
- db_user = model.user.get_user(username)
- if not db_user:
- return (None, 'Invalid user')
+ db_user = model.user.get_user(username)
+ if not db_user:
+ return (None, "Invalid user")
- federated_login = model.user.lookup_federated_login(db_user, self._federated_service)
- if not federated_login:
- return (None, 'Invalid user')
+ federated_login = model.user.lookup_federated_login(
+ db_user, self._federated_service
+ )
+ if not federated_login:
+ return (None, "Invalid user")
- (credentials, err_msg) = self.verify_credentials(federated_login.service_ident, password)
- if credentials is None:
- return (None, err_msg)
+ (credentials, err_msg) = self.verify_credentials(
+ federated_login.service_ident, password
+ )
+ if credentials is None:
+ return (None, err_msg)
- return (db_user, None)
+ return (db_user, None)
- def service_metadata(self):
- """ Returns a dictionary of extra metadata to present to *superusers* about this auth engine.
+ def service_metadata(self):
+ """ Returns a dictionary of extra metadata to present to *superusers* about this auth engine.
For example, LDAP returns the base DN so we can display to the user during sync setup.
"""
- return {}
+ return {}
- def check_group_lookup_args(self, group_lookup_args):
- """ Verifies that the given group lookup args point to a valid group. Returns a tuple consisting
+ def check_group_lookup_args(self, group_lookup_args):
+ """ Verifies that the given group lookup args point to a valid group. Returns a tuple consisting
of a boolean status and an error message (if any).
"""
- return (False, 'Not supported')
+ return (False, "Not supported")
- def iterate_group_members(self, group_lookup_args, page_size=None, disable_pagination=False):
- """ Returns an iterator over all the members of the group matching the given lookup args
+ def iterate_group_members(
+ self, group_lookup_args, page_size=None, disable_pagination=False
+ ):
+ """ Returns an iterator over all the members of the group matching the given lookup args
dictionary. The format of the lookup args dictionary is specific to the implementation.
"""
- return (None, 'Not supported')
+ return (None, "Not supported")
- def _get_and_link_federated_user_info(self, username, email, internal_create=False):
- db_user = model.user.verify_federated_login(self._federated_service, username)
- if not db_user:
-
- # Fetch list of blacklisted domains
- blacklisted_domains = model.config.app_config.get('BLACKLISTED_EMAIL_DOMAINS')
+ def _get_and_link_federated_user_info(self, username, email, internal_create=False):
+ db_user = model.user.verify_federated_login(self._federated_service, username)
+ if not db_user:
- # We must create the user in our db. Check to see if this is allowed (except for internal
- # creation, which is always allowed).
- if not internal_create and not can_create_user(email, blacklisted_domains):
- return (None, DISABLED_MESSAGE)
+ # Fetch list of blacklisted domains
+ blacklisted_domains = model.config.app_config.get(
+ "BLACKLISTED_EMAIL_DOMAINS"
+ )
- valid_username = None
- for valid_username in generate_valid_usernames(username):
- if model.user.is_username_unique(valid_username):
- break
+ # We must create the user in our db. Check to see if this is allowed (except for internal
+ # creation, which is always allowed).
+ if not internal_create and not can_create_user(email, blacklisted_domains):
+ return (None, DISABLED_MESSAGE)
- if not valid_username:
- logger.error('Unable to pick a username for user: %s', username)
- return (None, 'Unable to pick a username. Please report this to your administrator.')
+ valid_username = None
+ for valid_username in generate_valid_usernames(username):
+ if model.user.is_username_unique(valid_username):
+ break
- prompts = model.user.get_default_user_prompts(features)
- try:
- db_user = model.user.create_federated_user(valid_username, email, self._federated_service,
- username,
- set_password_notification=False,
- email_required=self._requires_email,
- confirm_username=features.USERNAME_CONFIRMATION,
- prompts=prompts)
- except model.InvalidEmailAddressException as iae:
- return (None, str(iae))
+ if not valid_username:
+ logger.error("Unable to pick a username for user: %s", username)
+ return (
+ None,
+ "Unable to pick a username. Please report this to your administrator.",
+ )
- else:
- # Update the db attributes from the federated service.
- if email and db_user.email != email:
- db_user.email = email
- db_user.save()
+ prompts = model.user.get_default_user_prompts(features)
+ try:
+ db_user = model.user.create_federated_user(
+ valid_username,
+ email,
+ self._federated_service,
+ username,
+ set_password_notification=False,
+ email_required=self._requires_email,
+ confirm_username=features.USERNAME_CONFIRMATION,
+ prompts=prompts,
+ )
+ except model.InvalidEmailAddressException as iae:
+ return (None, str(iae))
- return (db_user, None)
+ else:
+ # Update the db attributes from the federated service.
+ if email and db_user.email != email:
+ db_user.email = email
+ db_user.save()
+
+ return (db_user, None)
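
The FederatedUsers base class reformatted above standardizes a (value, error_message) return convention and handles database linkage itself. As an illustration only, here is a toy subclass wired into that interface; the class name and the in-memory user table are invented for the sketch and do not exist in this repository.

from data.users.federated import FederatedUsers, UserInformation


class StaticUsers(FederatedUsers):
    """ Resolves users from an in-memory dictionary, for demonstration only. """

    def __init__(self, users, requires_email=True):
        super(StaticUsers, self).__init__("static", requires_email)
        self._users = users  # {username: (password, email)}

    def get_user(self, username_or_email):
        entry = self._users.get(username_or_email)
        if entry is None:
            return (None, "Username not found")
        _, email = entry
        info = UserInformation(
            username=username_or_email, email=email, id=username_or_email
        )
        return (info, None)

    def verify_credentials(self, username_or_email, password):
        entry = self._users.get(username_or_email)
        if entry is None or entry[0] != password:
            return (None, "Invalid username or password")
        return self.get_user(username_or_email)

With only get_user and verify_credentials supplied, the base class methods link_user and verify_and_link_user take care of creating or linking the corresponding database user.
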
diff --git a/data/users/keystone.py b/data/users/keystone.py
index b8e581e77..64dff0e78 100644
--- a/data/users/keystone.py
+++ b/data/users/keystone.py
@@ -7,7 +7,9 @@ from keystoneauth1 import session
from keystoneauth1.exceptions import ClientException
from keystoneclient.v2_0 import client as client_v2
from keystoneclient.v3 import client as client_v3
-from keystoneclient.exceptions import AuthorizationFailure as KeystoneAuthorizationFailure
+from keystoneclient.exceptions import (
+ AuthorizationFailure as KeystoneAuthorizationFailure,
+)
from keystoneclient.exceptions import Unauthorized as KeystoneUnauthorized
from keystoneclient.exceptions import NotFound as KeystoneNotFound
from data.users.federated import FederatedUsers, UserInformation
@@ -15,286 +17,370 @@ from util.itertoolrecipes import take
logger = logging.getLogger(__name__)
-DEFAULT_TIMEOUT = 10 # seconds
+DEFAULT_TIMEOUT = 10 # seconds
-def get_keystone_users(auth_version, auth_url, admin_username, admin_password, admin_tenant,
- timeout=None, requires_email=True):
- if auth_version == 3:
- return KeystoneV3Users(auth_url, admin_username, admin_password, admin_tenant, timeout,
- requires_email)
- else:
- return KeystoneV2Users(auth_url, admin_username, admin_password, admin_tenant, timeout,
- requires_email)
+
+def get_keystone_users(
+ auth_version,
+ auth_url,
+ admin_username,
+ admin_password,
+ admin_tenant,
+ timeout=None,
+ requires_email=True,
+):
+ if auth_version == 3:
+ return KeystoneV3Users(
+ auth_url,
+ admin_username,
+ admin_password,
+ admin_tenant,
+ timeout,
+ requires_email,
+ )
+ else:
+ return KeystoneV2Users(
+ auth_url,
+ admin_username,
+ admin_password,
+ admin_tenant,
+ timeout,
+ requires_email,
+ )
class KeystoneV2Users(FederatedUsers):
- """ Delegates authentication to OpenStack Keystone V2. """
- def __init__(self, auth_url, admin_username, admin_password, admin_tenant, timeout=None,
- requires_email=True):
- super(KeystoneV2Users, self).__init__('keystone', requires_email)
- self.auth_url = auth_url
- self.admin_username = admin_username
- self.admin_password = admin_password
- self.admin_tenant = admin_tenant
- self.timeout = timeout or DEFAULT_TIMEOUT
- self.debug = os.environ.get('USERS_DEBUG') == '1'
- self.requires_email = requires_email
+ """ Delegates authentication to OpenStack Keystone V2. """
- def _get_client(self, username, password, tenant_name=None):
- if tenant_name:
- auth = keystone_v2_auth.Password(auth_url=self.auth_url,
- username=username,
- password=password,
- tenant_name=tenant_name)
- else:
- auth = keystone_v2_auth.Password(auth_url=self.auth_url,
- username=username,
- password=password)
+ def __init__(
+ self,
+ auth_url,
+ admin_username,
+ admin_password,
+ admin_tenant,
+ timeout=None,
+ requires_email=True,
+ ):
+ super(KeystoneV2Users, self).__init__("keystone", requires_email)
+ self.auth_url = auth_url
+ self.admin_username = admin_username
+ self.admin_password = admin_password
+ self.admin_tenant = admin_tenant
+ self.timeout = timeout or DEFAULT_TIMEOUT
+ self.debug = os.environ.get("USERS_DEBUG") == "1"
+ self.requires_email = requires_email
- sess = session.Session(auth=auth)
- client = client_v2.Client(session=sess,
- timeout=self.timeout,
- debug=self.debug)
- return client, sess
+ def _get_client(self, username, password, tenant_name=None):
+ if tenant_name:
+ auth = keystone_v2_auth.Password(
+ auth_url=self.auth_url,
+ username=username,
+ password=password,
+ tenant_name=tenant_name,
+ )
+ else:
+ auth = keystone_v2_auth.Password(
+ auth_url=self.auth_url, username=username, password=password
+ )
- def ping(self):
- try:
- _, sess = self._get_client(self.admin_username, self.admin_password, self.admin_tenant)
- assert sess.get_user_id() # Make sure we loaded a valid user.
- except KeystoneUnauthorized as kut:
- logger.exception('Keystone unauthorized admin')
- return (False, 'Keystone admin credentials are invalid: %s' % kut.message)
- except ClientException as e:
- logger.exception('Keystone unauthorized admin')
- return (False, 'Keystone ping check failed: %s' % e.message)
+ sess = session.Session(auth=auth)
+ client = client_v2.Client(session=sess, timeout=self.timeout, debug=self.debug)
+ return client, sess
- return (True, None)
+ def ping(self):
+ try:
+ _, sess = self._get_client(
+ self.admin_username, self.admin_password, self.admin_tenant
+ )
+ assert sess.get_user_id() # Make sure we loaded a valid user.
+ except KeystoneUnauthorized as kut:
+ logger.exception("Keystone unauthorized admin")
+ return (False, "Keystone admin credentials are invalid: %s" % kut.message)
+ except ClientException as e:
+ logger.exception("Keystone unauthorized admin")
+ return (False, "Keystone ping check failed: %s" % e.message)
- def at_least_one_user_exists(self):
- logger.debug('Checking if any users exist in Keystone')
- try:
- keystone_client, _ = self._get_client(self.admin_username, self.admin_password,
- self.admin_tenant)
- user_list = keystone_client.users.list(tenant_id=self.admin_tenant, limit=1)
+ return (True, None)
- if len(user_list) < 1:
- return (False, None)
+ def at_least_one_user_exists(self):
+ logger.debug("Checking if any users exist in Keystone")
+ try:
+ keystone_client, _ = self._get_client(
+ self.admin_username, self.admin_password, self.admin_tenant
+ )
+ user_list = keystone_client.users.list(tenant_id=self.admin_tenant, limit=1)
- return (True, None)
- except ClientException as e:
- # Catch exceptions to give the user our custom error message
- logger.exception('Unable to list users in Keystone')
- return (False, e.message)
+ if len(user_list) < 1:
+ return (False, None)
- def verify_credentials(self, username_or_email, password):
- try:
- _, sess = self._get_client(username_or_email, password)
- user_id = sess.get_user_id()
- except KeystoneAuthorizationFailure as kaf:
- logger.exception('Keystone auth failure for user: %s', username_or_email)
- return (None, 'Invalid username or password')
- except KeystoneUnauthorized as kut:
- logger.exception('Keystone unauthorized for user: %s', username_or_email)
- return (None, 'Invalid username or password')
- except ClientException as ex:
- logger.exception('Keystone unauthorized for user: %s', username_or_email)
- return (None, 'Invalid username or password')
+ return (True, None)
+ except ClientException as e:
+ # Catch exceptions to give the user our custom error message
+ logger.exception("Unable to list users in Keystone")
+ return (False, e.message)
- if user_id is None:
- return (None, 'Invalid username or password')
+ def verify_credentials(self, username_or_email, password):
+ try:
+ _, sess = self._get_client(username_or_email, password)
+ user_id = sess.get_user_id()
+ except KeystoneAuthorizationFailure as kaf:
+ logger.exception("Keystone auth failure for user: %s", username_or_email)
+ return (None, "Invalid username or password")
+ except KeystoneUnauthorized as kut:
+ logger.exception("Keystone unauthorized for user: %s", username_or_email)
+ return (None, "Invalid username or password")
+ except ClientException as ex:
+ logger.exception("Keystone unauthorized for user: %s", username_or_email)
+ return (None, "Invalid username or password")
- try:
- admin_client, _ = self._get_client(self.admin_username, self.admin_password,
- self.admin_tenant)
- user = admin_client.users.get(user_id)
- except KeystoneUnauthorized as kut:
- logger.exception('Keystone unauthorized admin')
- return (None, 'Keystone admin credentials are invalid: %s' % kut.message)
+ if user_id is None:
+ return (None, "Invalid username or password")
- if self.requires_email and not hasattr(user, 'email'):
- return (None, 'Missing email field for user %s' % user_id)
+ try:
+ admin_client, _ = self._get_client(
+ self.admin_username, self.admin_password, self.admin_tenant
+ )
+ user = admin_client.users.get(user_id)
+ except KeystoneUnauthorized as kut:
+ logger.exception("Keystone unauthorized admin")
+ return (None, "Keystone admin credentials are invalid: %s" % kut.message)
- email = user.email if hasattr(user, 'email') else None
- return (UserInformation(username=username_or_email, email=email, id=user_id), None)
+ if self.requires_email and not hasattr(user, "email"):
+ return (None, "Missing email field for user %s" % user_id)
- def query_users(self, query, limit=20):
- return (None, self.federated_service, 'Unsupported in Keystone V2')
+ email = user.email if hasattr(user, "email") else None
+ return (
+ UserInformation(username=username_or_email, email=email, id=user_id),
+ None,
+ )
- def get_user(self, username_or_email):
- return (None, 'Unsupported in Keystone V2')
+ def query_users(self, query, limit=20):
+ return (None, self.federated_service, "Unsupported in Keystone V2")
+
+ def get_user(self, username_or_email):
+ return (None, "Unsupported in Keystone V2")
class KeystoneV3Users(FederatedUsers):
- """ Delegates authentication to OpenStack Keystone V3. """
- def __init__(self, auth_url, admin_username, admin_password, admin_tenant, timeout=None,
- requires_email=True, project_domain_id='default', user_domain_id='default'):
- super(KeystoneV3Users, self).__init__('keystone', requires_email)
- self.auth_url = auth_url
- self.admin_username = admin_username
- self.admin_password = admin_password
- self.admin_tenant = admin_tenant
- self.project_domain_id = project_domain_id
- self.user_domain_id = user_domain_id
- self.timeout = timeout or DEFAULT_TIMEOUT
- self.debug = os.environ.get('USERS_DEBUG') == '1'
- self.requires_email = requires_email
+ """ Delegates authentication to OpenStack Keystone V3. """
- def _get_client(self, username, password, project_name=None):
- if project_name:
- auth = keystone_v3_auth.Password(auth_url=self.auth_url,
- username=username,
- password=password,
- project_name=project_name,
- project_domain_id=self.project_domain_id,
- user_domain_id=self.user_domain_id)
- else:
- auth = keystone_v3_auth.Password(auth_url=self.auth_url,
- username=username,
- password=password,
- user_domain_id=self.user_domain_id)
+ def __init__(
+ self,
+ auth_url,
+ admin_username,
+ admin_password,
+ admin_tenant,
+ timeout=None,
+ requires_email=True,
+ project_domain_id="default",
+ user_domain_id="default",
+ ):
+ super(KeystoneV3Users, self).__init__("keystone", requires_email)
+ self.auth_url = auth_url
+ self.admin_username = admin_username
+ self.admin_password = admin_password
+ self.admin_tenant = admin_tenant
+ self.project_domain_id = project_domain_id
+ self.user_domain_id = user_domain_id
+ self.timeout = timeout or DEFAULT_TIMEOUT
+ self.debug = os.environ.get("USERS_DEBUG") == "1"
+ self.requires_email = requires_email
- sess = session.Session(auth=auth)
- client = client_v3.Client(session=sess,
- timeout=self.timeout,
- debug=self.debug)
- return client, sess
+ def _get_client(self, username, password, project_name=None):
+ if project_name:
+ auth = keystone_v3_auth.Password(
+ auth_url=self.auth_url,
+ username=username,
+ password=password,
+ project_name=project_name,
+ project_domain_id=self.project_domain_id,
+ user_domain_id=self.user_domain_id,
+ )
+ else:
+ auth = keystone_v3_auth.Password(
+ auth_url=self.auth_url,
+ username=username,
+ password=password,
+ user_domain_id=self.user_domain_id,
+ )
- def ping(self):
- try:
- _, sess = self._get_client(self.admin_username, self.admin_password)
- assert sess.get_user_id() # Make sure we loaded a valid user.
- except KeystoneUnauthorized as kut:
- logger.exception('Keystone unauthorized admin')
- return (False, 'Keystone admin credentials are invalid: %s' % kut.message)
- except ClientException as cle:
- logger.exception('Keystone unauthorized admin')
- return (False, 'Keystone ping check failed: %s' % cle.message)
+ sess = session.Session(auth=auth)
+ client = client_v3.Client(session=sess, timeout=self.timeout, debug=self.debug)
+ return client, sess
- return (True, None)
+ def ping(self):
+ try:
+ _, sess = self._get_client(self.admin_username, self.admin_password)
+ assert sess.get_user_id() # Make sure we loaded a valid user.
+ except KeystoneUnauthorized as kut:
+ logger.exception("Keystone unauthorized admin")
+ return (False, "Keystone admin credentials are invalid: %s" % kut.message)
+ except ClientException as cle:
+ logger.exception("Keystone unauthorized admin")
+ return (False, "Keystone ping check failed: %s" % cle.message)
- def at_least_one_user_exists(self):
- logger.debug('Checking if any users exist in admin tenant in Keystone')
- try:
- # Just make sure the admin can connect to the project.
- self._get_client(self.admin_username, self.admin_password, self.admin_tenant)
- return (True, None)
- except ClientException as cle:
- # Catch exceptions to give the user our custom error message
- logger.exception('Unable to list users in Keystone')
- return (False, cle.message)
+ return (True, None)
- def verify_credentials(self, username_or_email, password):
- try:
- keystone_client, sess = self._get_client(username_or_email, password)
- user_id = sess.get_user_id()
- assert user_id
+ def at_least_one_user_exists(self):
+ logger.debug("Checking if any users exist in admin tenant in Keystone")
+ try:
+ # Just make sure the admin can connect to the project.
+ self._get_client(
+ self.admin_username, self.admin_password, self.admin_tenant
+ )
+ return (True, None)
+ except ClientException as cle:
+ # Catch exceptions to give the user our custom error message
+ logger.exception("Unable to list users in Keystone")
+ return (False, cle.message)
- keystone_client, sess = self._get_client(self.admin_username, self.admin_password,
- self.admin_tenant)
- user = keystone_client.users.get(user_id)
- if self.requires_email and not hasattr(user, 'email'):
- return (None, 'Missing email field for user %s' % user_id)
+ def verify_credentials(self, username_or_email, password):
+ try:
+ keystone_client, sess = self._get_client(username_or_email, password)
+ user_id = sess.get_user_id()
+ assert user_id
- return (self._user_info(user), None)
- except KeystoneAuthorizationFailure as kaf:
- logger.exception('Keystone auth failure for user: %s', username_or_email)
- return (None, 'Invalid username or password')
- except KeystoneUnauthorized as kut:
- logger.exception('Keystone unauthorized for user: %s', username_or_email)
- return (None, 'Invalid username or password')
- except ClientException as cle:
- logger.exception('Keystone unauthorized for user: %s', username_or_email)
- return (None, 'Invalid username or password')
+ keystone_client, sess = self._get_client(
+ self.admin_username, self.admin_password, self.admin_tenant
+ )
+ user = keystone_client.users.get(user_id)
+ if self.requires_email and not hasattr(user, "email"):
+ return (None, "Missing email field for user %s" % user_id)
- def get_user(self, username_or_email):
- users_found, _, err_msg = self.query_users(username_or_email)
- if err_msg is not None:
- return (None, err_msg)
+ return (self._user_info(user), None)
+ except KeystoneAuthorizationFailure as kaf:
+ logger.exception("Keystone auth failure for user: %s", username_or_email)
+ return (None, "Invalid username or password")
+ except KeystoneUnauthorized as kut:
+ logger.exception("Keystone unauthorized for user: %s", username_or_email)
+ return (None, "Invalid username or password")
+ except ClientException as cle:
+ logger.exception("Keystone unauthorized for user: %s", username_or_email)
+ return (None, "Invalid username or password")
- if len(users_found) != 1:
- return (None, 'Single user not found')
+ def get_user(self, username_or_email):
+ users_found, _, err_msg = self.query_users(username_or_email)
+ if err_msg is not None:
+ return (None, err_msg)
- user = users_found[0]
- if self.requires_email and not user.email:
- return (None, 'Missing email field for user %s' % user.id)
+ if len(users_found) != 1:
+ return (None, "Single user not found")
- return (user, None)
+ user = users_found[0]
+ if self.requires_email and not user.email:
+ return (None, "Missing email field for user %s" % user.id)
- def check_group_lookup_args(self, group_lookup_args):
- if not group_lookup_args.get('group_id'):
- return (False, 'Missing group_id')
+ return (user, None)
- group_id = group_lookup_args['group_id']
- return self._check_group(group_id)
+ def check_group_lookup_args(self, group_lookup_args):
+ if not group_lookup_args.get("group_id"):
+ return (False, "Missing group_id")
- def _check_group(self, group_id):
- try:
- admin_client, _ = self._get_client(self.admin_username, self.admin_password,
- self.admin_tenant)
- return (bool(admin_client.groups.get(group_id)), None)
- except KeystoneNotFound:
- return (False, 'Group not found')
- except KeystoneAuthorizationFailure as kaf:
- logger.exception('Keystone auth failure for admin user for group lookup %s', group_id)
- return (False, kaf.message or 'Invalid admin username or password')
- except KeystoneUnauthorized as kut:
- logger.exception('Keystone unauthorized for admin user for group lookup %s', group_id)
- return (False, kut.message or 'Invalid admin username or password')
- except ClientException as cle:
- logger.exception('Keystone unauthorized for admin user for group lookup %s', group_id)
- return (False, cle.message or 'Invalid admin username or password')
+ group_id = group_lookup_args["group_id"]
+ return self._check_group(group_id)
- def iterate_group_members(self, group_lookup_args, page_size=None, disable_pagination=False):
- group_id = group_lookup_args['group_id']
+ def _check_group(self, group_id):
+ try:
+ admin_client, _ = self._get_client(
+ self.admin_username, self.admin_password, self.admin_tenant
+ )
+ return (bool(admin_client.groups.get(group_id)), None)
+ except KeystoneNotFound:
+ return (False, "Group not found")
+ except KeystoneAuthorizationFailure as kaf:
+ logger.exception(
+ "Keystone auth failure for admin user for group lookup %s", group_id
+ )
+ return (False, kaf.message or "Invalid admin username or password")
+ except KeystoneUnauthorized as kut:
+ logger.exception(
+ "Keystone unauthorized for admin user for group lookup %s", group_id
+ )
+ return (False, kut.message or "Invalid admin username or password")
+ except ClientException as cle:
+ logger.exception(
+ "Keystone unauthorized for admin user for group lookup %s", group_id
+ )
+ return (False, cle.message or "Invalid admin username or password")
- (status, err) = self._check_group(group_id)
- if not status:
- return (None, err)
+ def iterate_group_members(
+ self, group_lookup_args, page_size=None, disable_pagination=False
+ ):
+ group_id = group_lookup_args["group_id"]
- try:
- admin_client, _ = self._get_client(self.admin_username, self.admin_password,
- self.admin_tenant)
- user_info_iterator = admin_client.users.list(group=group_id)
- def iterator():
- for user in user_info_iterator:
- yield (self._user_info(user), None)
+ (status, err) = self._check_group(group_id)
+ if not status:
+ return (None, err)
- return (iterator(), None)
- except KeystoneAuthorizationFailure as kaf:
- logger.exception('Keystone auth failure for admin user for group lookup %s', group_id)
- return (False, kaf.message or 'Invalid admin username or password')
- except KeystoneUnauthorized as kut:
- logger.exception('Keystone unauthorized for admin user for group lookup %s', group_id)
- return (False, kut.message or 'Invalid admin username or password')
- except ClientException as cle:
- logger.exception('Keystone unauthorized for admin user for group lookup %s', group_id)
- return (False, cle.message or 'Invalid admin username or password')
+ try:
+ admin_client, _ = self._get_client(
+ self.admin_username, self.admin_password, self.admin_tenant
+ )
+ user_info_iterator = admin_client.users.list(group=group_id)
- @staticmethod
- def _user_info(user):
- email = user.email if hasattr(user, 'email') else None
- return UserInformation(user.name, email, user.id)
+ def iterator():
+ for user in user_info_iterator:
+ yield (self._user_info(user), None)
- def query_users(self, query, limit=20):
- if len(query) < 3:
- return ([], self.federated_service, None)
+ return (iterator(), None)
+ except KeystoneAuthorizationFailure as kaf:
+ logger.exception(
+ "Keystone auth failure for admin user for group lookup %s", group_id
+ )
+ return (False, kaf.message or "Invalid admin username or password")
+ except KeystoneUnauthorized as kut:
+ logger.exception(
+ "Keystone unauthorized for admin user for group lookup %s", group_id
+ )
+ return (False, kut.message or "Invalid admin username or password")
+ except ClientException as cle:
+ logger.exception(
+ "Keystone unauthorized for admin user for group lookup %s", group_id
+ )
+ return (False, cle.message or "Invalid admin username or password")
- try:
- admin_client, _ = self._get_client(self.admin_username, self.admin_password,
- self.admin_tenant)
+ @staticmethod
+ def _user_info(user):
+ email = user.email if hasattr(user, "email") else None
+ return UserInformation(user.name, email, user.id)
- found_users = list(take(limit, admin_client.users.list(name=query)))
- logger.debug('For Keystone query %s found users: %s', query, found_users)
- if not found_users:
- return ([], self.federated_service, None)
+ def query_users(self, query, limit=20):
+ if len(query) < 3:
+ return ([], self.federated_service, None)
- return ([self._user_info(user) for user in found_users], self.federated_service, None)
- except KeystoneAuthorizationFailure as kaf:
- logger.exception('Keystone auth failure for admin user for query %s', query)
- return (None, self.federated_service, kaf.message or 'Invalid admin username or password')
- except KeystoneUnauthorized as kut:
- logger.exception('Keystone unauthorized for admin user for query %s', query)
- return (None, self.federated_service, kut.message or 'Invalid admin username or password')
- except ClientException as cle:
- logger.exception('Keystone unauthorized for admin user for query %s', query)
- return (None, self.federated_service, cle.message or 'Invalid admin username or password')
+ try:
+ admin_client, _ = self._get_client(
+ self.admin_username, self.admin_password, self.admin_tenant
+ )
+
+ found_users = list(take(limit, admin_client.users.list(name=query)))
+ logger.debug("For Keystone query %s found users: %s", query, found_users)
+ if not found_users:
+ return ([], self.federated_service, None)
+
+ return (
+ [self._user_info(user) for user in found_users],
+ self.federated_service,
+ None,
+ )
+ except KeystoneAuthorizationFailure as kaf:
+ logger.exception("Keystone auth failure for admin user for query %s", query)
+ return (
+ None,
+ self.federated_service,
+ kaf.message or "Invalid admin username or password",
+ )
+ except KeystoneUnauthorized as kut:
+ logger.exception("Keystone unauthorized for admin user for query %s", query)
+ return (
+ None,
+ self.federated_service,
+ kut.message or "Invalid admin username or password",
+ )
+ except ClientException as cle:
+ logger.exception("Keystone unauthorized for admin user for query %s", query)
+ return (
+ None,
+ self.federated_service,
+ cle.message or "Invalid admin username or password",
+ )
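
A usage sketch for the Keystone backend above, under assumed configuration; the URL, credentials, tenant and passwords below are placeholders only.

from data.users.keystone import get_keystone_users

users = get_keystone_users(
    auth_version=3,
    auth_url="http://keystone.example.com:5000/v3",
    admin_username="admin",
    admin_password="secret",
    admin_tenant="admin",
    requires_email=False,
)

# ping() and verify_credentials() both return (value, error_message) tuples,
# matching the convention used throughout these backends.
status, err = users.ping()
if not status:
    print("Keystone is unreachable: %s" % err)

user_info, err = users.verify_credentials("alice", "alice-password")
if user_info is None:
    print("Login failed: %s" % err)
else:
    print("Authenticated %s (%s)" % (user_info.username, user_info.id))
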
diff --git a/data/users/shared.py b/data/users/shared.py
index 8f1cc09df..fd507837e 100644
--- a/data/users/shared.py
+++ b/data/users/shared.py
@@ -7,24 +7,24 @@ from data import model
def can_create_user(email_address, blacklisted_domains=None):
- """ Returns true if a user with the specified e-mail address can be created. """
+ """ Returns true if a user with the specified e-mail address can be created. """
- if features.BLACKLISTED_EMAILS and email_address and '@' in email_address:
- blacklisted_domains = blacklisted_domains or []
- _, email_domain = email_address.split('@', 1)
- extracted = tldextract.extract(email_domain)
- if extracted.registered_domain.lower() in blacklisted_domains:
- return False
+ if features.BLACKLISTED_EMAILS and email_address and "@" in email_address:
+ blacklisted_domains = blacklisted_domains or []
+ _, email_domain = email_address.split("@", 1)
+ extracted = tldextract.extract(email_domain)
+ if extracted.registered_domain.lower() in blacklisted_domains:
+ return False
- if not features.USER_CREATION:
- return False
+ if not features.USER_CREATION:
+ return False
- if features.INVITE_ONLY_USER_CREATION:
- if not email_address:
- return False
+ if features.INVITE_ONLY_USER_CREATION:
+ if not email_address:
+ return False
- # Check to see that there is an invite for the e-mail address.
- return bool(model.team.lookup_team_invites_by_email(email_address))
+ # Check to see that there is an invite for the e-mail address.
+ return bool(model.team.lookup_team_invites_by_email(email_address))
- # Otherwise the user can be created (assuming it doesn't already exist, of course)
- return True
+ # Otherwise the user can be created (assuming it doesn't already exist, of course)
+ return True
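
The blacklist comparison above matches on the registered domain that tldextract reports, which is why subdomains of a blacklisted domain are rejected while lookalike domains are not. A small snippet showing what the comparison actually sees; the domains are examples only.

import tldextract

email_domain = "foo@mail.blacklisted.com".split("@", 1)[1]
print(tldextract.extract(email_domain).registered_domain)
# -> "blacklisted.com"  (a subdomain resolves to the registered domain and matches)

print(tldextract.extract("myblacklisted.com").registered_domain)
# -> "myblacklisted.com"  (no partial-string match against "blacklisted.com")
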
diff --git a/data/users/teamsync.py b/data/users/teamsync.py
index 2ab0fea10..d122e6924 100644
--- a/data/users/teamsync.py
+++ b/data/users/teamsync.py
@@ -10,127 +10,180 @@ MAX_TEAMS_PER_ITERATION = 500
def sync_teams_to_groups(authentication, stale_cutoff):
- """ Performs team syncing by looking up any stale team(s) found, and performing the sync
+ """ Performs team syncing by looking up any stale team(s) found, and performing the sync
operation on them.
"""
- logger.debug('Looking up teams to sync to groups')
+ logger.debug("Looking up teams to sync to groups")
- sync_team_tried = set()
- while len(sync_team_tried) < MAX_TEAMS_PER_ITERATION:
- # Find a stale team.
- stale_team_sync = model.team.get_stale_team(stale_cutoff)
- if not stale_team_sync:
- logger.debug('No additional stale team found; sleeping')
- return
+ sync_team_tried = set()
+ while len(sync_team_tried) < MAX_TEAMS_PER_ITERATION:
+ # Find a stale team.
+ stale_team_sync = model.team.get_stale_team(stale_cutoff)
+ if not stale_team_sync:
+ logger.debug("No additional stale team found; sleeping")
+ return
- # Make sure we don't try to reprocess a team on this iteration.
- if stale_team_sync.id in sync_team_tried:
- break
+ # Make sure we don't try to reprocess a team on this iteration.
+ if stale_team_sync.id in sync_team_tried:
+ break
- sync_team_tried.add(stale_team_sync.id)
+ sync_team_tried.add(stale_team_sync.id)
- # Sync the team.
- sync_successful = sync_team(authentication, stale_team_sync)
- if not sync_successful:
- return
+ # Sync the team.
+ sync_successful = sync_team(authentication, stale_team_sync)
+ if not sync_successful:
+ return
def sync_team(authentication, stale_team_sync):
- """ Performs synchronization of a team (as referenced by the TeamSync stale_team_sync).
+ """ Performs synchronization of a team (as referenced by the TeamSync stale_team_sync).
Returns True on success and False otherwise.
"""
- sync_config = json.loads(stale_team_sync.config)
- logger.info('Syncing team `%s` under organization %s via %s (#%s)', stale_team_sync.team.name,
- stale_team_sync.team.organization.username, sync_config, stale_team_sync.team_id,
- extra={'team': stale_team_sync.team_id, 'sync_config': sync_config})
+ sync_config = json.loads(stale_team_sync.config)
+ logger.info(
+ "Syncing team `%s` under organization %s via %s (#%s)",
+ stale_team_sync.team.name,
+ stale_team_sync.team.organization.username,
+ sync_config,
+ stale_team_sync.team_id,
+ extra={"team": stale_team_sync.team_id, "sync_config": sync_config},
+ )
- # Load all the existing members of the team in Quay that are bound to the auth service.
- existing_users = model.team.get_federated_team_member_mapping(stale_team_sync.team,
- authentication.federated_service)
+ # Load all the existing members of the team in Quay that are bound to the auth service.
+ existing_users = model.team.get_federated_team_member_mapping(
+ stale_team_sync.team, authentication.federated_service
+ )
- logger.debug('Existing membership of %s for team `%s` under organization %s via %s (#%s)',
- len(existing_users), stale_team_sync.team.name,
- stale_team_sync.team.organization.username, sync_config, stale_team_sync.team_id,
- extra={'team': stale_team_sync.team_id, 'sync_config': sync_config,
- 'existing_member_count': len(existing_users)})
+ logger.debug(
+ "Existing membership of %s for team `%s` under organization %s via %s (#%s)",
+ len(existing_users),
+ stale_team_sync.team.name,
+ stale_team_sync.team.organization.username,
+ sync_config,
+ stale_team_sync.team_id,
+ extra={
+ "team": stale_team_sync.team_id,
+ "sync_config": sync_config,
+ "existing_member_count": len(existing_users),
+ },
+ )
- # Load all the members of the team from the authenication system.
- (member_iterator, err) = authentication.iterate_group_members(sync_config)
- if err is not None:
- logger.error('Got error when trying to iterate group members with config %s: %s',
- sync_config, err)
- return False
-
- # Collect all the members currently found in the group, adding them to the team as we go
- # along.
- group_membership = set()
- for (member_info, err) in member_iterator:
+    # Load all the members of the team from the authentication system.
+ (member_iterator, err) = authentication.iterate_group_members(sync_config)
if err is not None:
- logger.error('Got error when trying to construct a member: %s', err)
- continue
+ logger.error(
+ "Got error when trying to iterate group members with config %s: %s",
+ sync_config,
+ err,
+ )
+ return False
- # If the member is already in the team, nothing more to do.
- if member_info.username in existing_users:
- logger.debug('Member %s already in team `%s` under organization %s via %s (#%s)',
- member_info.username, stale_team_sync.team.name,
- stale_team_sync.team.organization.username, sync_config,
- stale_team_sync.team_id,
- extra={'team': stale_team_sync.team_id, 'sync_config': sync_config,
- 'member': member_info.username})
+ # Collect all the members currently found in the group, adding them to the team as we go
+ # along.
+ group_membership = set()
+ for (member_info, err) in member_iterator:
+ if err is not None:
+ logger.error("Got error when trying to construct a member: %s", err)
+ continue
- group_membership.add(existing_users[member_info.username])
- continue
+ # If the member is already in the team, nothing more to do.
+ if member_info.username in existing_users:
+ logger.debug(
+ "Member %s already in team `%s` under organization %s via %s (#%s)",
+ member_info.username,
+ stale_team_sync.team.name,
+ stale_team_sync.team.organization.username,
+ sync_config,
+ stale_team_sync.team_id,
+ extra={
+ "team": stale_team_sync.team_id,
+ "sync_config": sync_config,
+ "member": member_info.username,
+ },
+ )
- # Retrieve the Quay user associated with the member info.
- (quay_user, err) = authentication.get_and_link_federated_user_info(member_info,
- internal_create=True)
- if err is not None:
- logger.error('Could not link external user %s to an internal user: %s',
- member_info.username, err,
- extra={'team': stale_team_sync.team_id, 'sync_config': sync_config,
- 'member': member_info.username, 'error': err})
- continue
+ group_membership.add(existing_users[member_info.username])
+ continue
- # Add the user to the membership set.
- group_membership.add(quay_user.id)
+ # Retrieve the Quay user associated with the member info.
+ (quay_user, err) = authentication.get_and_link_federated_user_info(
+ member_info, internal_create=True
+ )
+ if err is not None:
+ logger.error(
+ "Could not link external user %s to an internal user: %s",
+ member_info.username,
+ err,
+ extra={
+ "team": stale_team_sync.team_id,
+ "sync_config": sync_config,
+ "member": member_info.username,
+ "error": err,
+ },
+ )
+ continue
- # Add the user to the team.
- try:
- logger.info('Adding member %s to team `%s` under organization %s via %s (#%s)',
- quay_user.username, stale_team_sync.team.name,
- stale_team_sync.team.organization.username, sync_config,
- stale_team_sync.team_id,
- extra={'team': stale_team_sync.team_id, 'sync_config': sync_config,
- 'member': quay_user.username})
+ # Add the user to the membership set.
+ group_membership.add(quay_user.id)
- model.team.add_user_to_team(quay_user, stale_team_sync.team)
- except model.UserAlreadyInTeam:
- # If the user is already present, nothing more to do for them.
- pass
+ # Add the user to the team.
+ try:
+ logger.info(
+ "Adding member %s to team `%s` under organization %s via %s (#%s)",
+ quay_user.username,
+ stale_team_sync.team.name,
+ stale_team_sync.team.organization.username,
+ sync_config,
+ stale_team_sync.team_id,
+ extra={
+ "team": stale_team_sync.team_id,
+ "sync_config": sync_config,
+ "member": quay_user.username,
+ },
+ )
- # Update the transaction and last_updated time of the team sync. Only if it matches
- # the current value will we then perform the deletion step.
- got_transaction_handle = model.team.update_sync_status(stale_team_sync)
- if not got_transaction_handle:
- # Another worker updated this team. Nothing more to do.
- logger.debug('Another worker synced team `%s` under organization %s via %s (#%s)',
- stale_team_sync.team.name,
- stale_team_sync.team.organization.username, sync_config,
- stale_team_sync.team_id,
- extra={'team': stale_team_sync.team_id, 'sync_config': sync_config})
+ model.team.add_user_to_team(quay_user, stale_team_sync.team)
+ except model.UserAlreadyInTeam:
+ # If the user is already present, nothing more to do for them.
+ pass
+
+ # Update the transaction and last_updated time of the team sync. Only if it matches
+ # the current value will we then perform the deletion step.
+ got_transaction_handle = model.team.update_sync_status(stale_team_sync)
+ if not got_transaction_handle:
+ # Another worker updated this team. Nothing more to do.
+ logger.debug(
+ "Another worker synced team `%s` under organization %s via %s (#%s)",
+ stale_team_sync.team.name,
+ stale_team_sync.team.organization.username,
+ sync_config,
+ stale_team_sync.team_id,
+ extra={"team": stale_team_sync.team_id, "sync_config": sync_config},
+ )
+ return True
+
+ # Delete any team members not found in the backing auth system.
+ logger.debug(
+ "Deleting stale members for team `%s` under organization %s via %s (#%s)",
+ stale_team_sync.team.name,
+ stale_team_sync.team.organization.username,
+ sync_config,
+ stale_team_sync.team_id,
+ extra={"team": stale_team_sync.team_id, "sync_config": sync_config},
+ )
+
+ deleted = model.team.delete_members_not_present(
+ stale_team_sync.team, group_membership
+ )
+
+ # Done!
+ logger.info(
+ "Finishing sync for team `%s` under organization %s via %s (#%s): %s deleted",
+ stale_team_sync.team.name,
+ stale_team_sync.team.organization.username,
+ sync_config,
+ stale_team_sync.team_id,
+ deleted,
+ extra={"team": stale_team_sync.team_id, "sync_config": sync_config},
+ )
return True
-
- # Delete any team members not found in the backing auth system.
- logger.debug('Deleting stale members for team `%s` under organization %s via %s (#%s)',
- stale_team_sync.team.name, stale_team_sync.team.organization.username,
- sync_config, stale_team_sync.team_id,
- extra={'team': stale_team_sync.team_id, 'sync_config': sync_config})
-
- deleted = model.team.delete_members_not_present(stale_team_sync.team, group_membership)
-
- # Done!
- logger.info('Finishing sync for team `%s` under organization %s via %s (#%s): %s deleted',
- stale_team_sync.team.name, stale_team_sync.team.organization.username,
- sync_config, stale_team_sync.team_id, deleted,
- extra={'team': stale_team_sync.team_id, 'sync_config': sync_config})
- return True
diff --git a/data/users/test/test_shared.py b/data/users/test/test_shared.py
index d211fb485..18d2eb9cb 100644
--- a/data/users/test/test_shared.py
+++ b/data/users/test/test_shared.py
@@ -7,49 +7,67 @@ from data.users.shared import can_create_user
from test.fixtures import *
-@pytest.mark.parametrize('open_creation, invite_only, email, has_invite, can_create', [
- # Open user creation => always allowed.
- (True, False, None, False, True),
- # Open user creation => always allowed.
- (True, False, 'foo@example.com', False, True),
+@pytest.mark.parametrize(
+ "open_creation, invite_only, email, has_invite, can_create",
+ [
+ # Open user creation => always allowed.
+ (True, False, None, False, True),
+ # Open user creation => always allowed.
+ (True, False, "foo@example.com", False, True),
+ # Invite only user creation + no invite => disallowed.
+ (True, True, None, False, False),
+ # Invite only user creation + no invite => disallowed.
+ (True, True, "foo@example.com", False, False),
+ # Invite only user creation + invite => allowed.
+ (True, True, "foo@example.com", True, True),
+ # No open creation => Disallowed.
+ (False, True, "foo@example.com", False, False),
+ (False, True, "foo@example.com", True, False),
+ # Blacklisted emails => Disallowed.
+ (True, False, "foo@blacklisted.com", False, False),
+ (True, False, "foo@blacklisted.org", False, False),
+ (True, False, "foo@BlAcKlIsTeD.CoM", False, False), # Verify Capitalization
+ (True, False, u"foo@mail.bLacklisted.Com", False, False), # Verify unicode
+ (True, False, "foo@blacklisted.net", False, True), # Avoid False Positives
+ (
+ True,
+ False,
+ "foo@myblacklisted.com",
+ False,
+ True,
+ ), # Avoid partial domain matches
+ (
+ True,
+ False,
+ "fooATblacklisted.com",
+ False,
+ True,
+ ), # Ignore invalid email addresses
+ ],
+)
+@pytest.mark.parametrize("blacklisting_enabled", [True, False])
+def test_can_create_user(
+ open_creation, invite_only, email, has_invite, can_create, blacklisting_enabled, app
+):
- # Invite only user creation + no invite => disallowed.
- (True, True, None, False, False),
+ # Mock list of blacklisted domains
+ blacklisted_domains = ["blacklisted.com", "blacklisted.org"]
- # Invite only user creation + no invite => disallowed.
- (True, True, 'foo@example.com', False, False),
+ if has_invite:
+ inviter = model.user.get_user("devtable")
+ team = model.team.get_organization_team("buynlarge", "owners")
+ model.team.add_or_invite_to_team(inviter, team, email=email)
- # Invite only user creation + invite => allowed.
- (True, True, 'foo@example.com', True, True),
-
- # No open creation => Disallowed.
- (False, True, 'foo@example.com', False, False),
- (False, True, 'foo@example.com', True, False),
-
- # Blacklisted emails => Disallowed.
- (True, False, 'foo@blacklisted.com', False, False),
- (True, False, 'foo@blacklisted.org', False, False),
- (True, False, 'foo@BlAcKlIsTeD.CoM', False, False), # Verify Capitalization
- (True, False, u'foo@mail.bLacklisted.Com', False, False), # Verify unicode
- (True, False, 'foo@blacklisted.net', False, True), # Avoid False Positives
- (True, False, 'foo@myblacklisted.com', False, True), # Avoid partial domain matches
- (True, False, 'fooATblacklisted.com', False, True), # Ignore invalid email addresses
-])
-@pytest.mark.parametrize('blacklisting_enabled', [True, False])
-def test_can_create_user(open_creation, invite_only, email, has_invite, can_create, blacklisting_enabled, app):
-
- # Mock list of blacklisted domains
- blacklisted_domains = ['blacklisted.com', 'blacklisted.org']
-
- if has_invite:
- inviter = model.user.get_user('devtable')
- team = model.team.get_organization_team('buynlarge', 'owners')
- model.team.add_or_invite_to_team(inviter, team, email=email)
-
- with patch('features.USER_CREATION', open_creation):
- with patch('features.INVITE_ONLY_USER_CREATION', invite_only):
- with patch('features.BLACKLISTED_EMAILS', blacklisting_enabled):
- if email and any(domain in email.lower() for domain in blacklisted_domains) and not blacklisting_enabled:
- can_create = True # blacklisted domains can be used, if blacklisting is disabled
- assert can_create_user(email, blacklisted_domains) == can_create
+ with patch("features.USER_CREATION", open_creation):
+ with patch("features.INVITE_ONLY_USER_CREATION", invite_only):
+ with patch("features.BLACKLISTED_EMAILS", blacklisting_enabled):
+ if (
+ email
+ and any(domain in email.lower() for domain in blacklisted_domains)
+ and not blacklisting_enabled
+ ):
+ can_create = (
+ True
+ ) # blacklisted domains can be used, if blacklisting is disabled
+ assert can_create_user(email, blacklisted_domains) == can_create
diff --git a/data/users/test/test_teamsync.py b/data/users/test/test_teamsync.py
index 470c31707..da531e547 100644
--- a/data/users/test/test_teamsync.py
+++ b/data/users/test/test_teamsync.py
@@ -15,318 +15,365 @@ from util.names import parse_robot_username
from test.fixtures import *
-_FAKE_AUTH = 'fake'
+_FAKE_AUTH = "fake"
+
class FakeUsers(FederatedUsers):
- def __init__(self, group_members):
- super(FakeUsers, self).__init__(_FAKE_AUTH, False)
- self.group_tuples = [(m, None) for m in group_members]
+ def __init__(self, group_members):
+ super(FakeUsers, self).__init__(_FAKE_AUTH, False)
+ self.group_tuples = [(m, None) for m in group_members]
- def iterate_group_members(self, group_lookup_args, page_size=None, disable_pagination=False):
- return (self.group_tuples, None)
+ def iterate_group_members(
+ self, group_lookup_args, page_size=None, disable_pagination=False
+ ):
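+        # Return every member in a single page, with no continuation token.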
+ return (self.group_tuples, None)
@pytest.fixture(params=[True, False])
def user_creation(request):
- with patch('features.USER_CREATION', request.param):
- yield
+ with patch("features.USER_CREATION", request.param):
+ yield
@pytest.fixture(params=[True, False])
def invite_only_user_creation(request):
- with patch('features.INVITE_ONLY_USER_CREATION', request.param):
- yield
+ with patch("features.INVITE_ONLY_USER_CREATION", request.param):
+ yield
@pytest.fixture(params=[True, False])
def blacklisted_emails(request):
- mock_blacklisted_domains = {'BLACKLISTED_EMAIL_DOMAINS': ['blacklisted.com', 'blacklisted.net']}
- with patch('features.BLACKLISTED_EMAILS', request.param):
- with patch.dict('data.model.config.app_config', mock_blacklisted_domains):
- yield
+ mock_blacklisted_domains = {
+ "BLACKLISTED_EMAIL_DOMAINS": ["blacklisted.com", "blacklisted.net"]
+ }
+ with patch("features.BLACKLISTED_EMAILS", request.param):
+ with patch.dict("data.model.config.app_config", mock_blacklisted_domains):
+ yield
-@pytest.mark.skipif(os.environ.get('TEST_DATABASE_URI', '').find('postgres') >= 0,
- reason="Postgres fails when existing members are added under the savepoint")
-@pytest.mark.parametrize('starting_membership,group_membership,expected_membership', [
- # Empty team + single member in group => Single member in team.
- ([],
- [
- UserInformation('someuser', 'someuser', 'someuser@devtable.com'),
- ],
- ['someuser']),
+@pytest.mark.skipif(
+ os.environ.get("TEST_DATABASE_URI", "").find("postgres") >= 0,
+ reason="Postgres fails when existing members are added under the savepoint",
+)
+@pytest.mark.parametrize(
+ "starting_membership,group_membership,expected_membership",
+ [
+ # Empty team + single member in group => Single member in team.
+ (
+ [],
+ [UserInformation("someuser", "someuser", "someuser@devtable.com")],
+ ["someuser"],
+ ),
+ # Team with a Quay user + empty group => empty team.
+ ([("someuser", None)], [], []),
+ # Team with an existing external user + user is in the group => no changes.
+ (
+ [("someuser", "someuser")],
+ [UserInformation("someuser", "someuser", "someuser@devtable.com")],
+ ["someuser"],
+ ),
+ # Team with an existing external user (with a different Quay username) + user is in the group.
+ # => no changes
+ (
+ [("anotherquayname", "someuser")],
+ [UserInformation("someuser", "someuser", "someuser@devtable.com")],
+ ["someuser"],
+ ),
+ # Team missing a few members that are in the group => members added.
+ (
+ [("someuser", "someuser")],
+ [
+ UserInformation(
+ "anotheruser", "anotheruser", "anotheruser@devtable.com"
+ ),
+ UserInformation("someuser", "someuser", "someuser@devtable.com"),
+ UserInformation("thirduser", "thirduser", "thirduser@devtable.com"),
+ ],
+ ["anotheruser", "someuser", "thirduser"],
+ ),
+ # Team has a few extra members no longer in the group => members removed.
+ (
+ [
+ ("anotheruser", "anotheruser"),
+ ("someuser", "someuser"),
+ ("thirduser", "thirduser"),
+ ("nontestuser", None),
+ ],
+ [UserInformation("thirduser", "thirduser", "thirduser@devtable.com")],
+ ["thirduser"],
+ ),
+ # Team has different membership than the group => members added and removed.
+ (
+ [
+ ("anotheruser", "anotheruser"),
+ ("someuser", "someuser"),
+ ("nontestuser", None),
+ ],
+ [
+ UserInformation(
+ "anotheruser", "anotheruser", "anotheruser@devtable.com"
+ ),
+ UserInformation(
+ "missinguser", "missinguser", "missinguser@devtable.com"
+ ),
+ ],
+ ["anotheruser", "missinguser"],
+ ),
+ # Team has same membership but some robots => robots remain and no other changes.
+ (
+ [
+ ("someuser", "someuser"),
+ ("buynlarge+anotherbot", None),
+ ("buynlarge+somerobot", None),
+ ],
+ [UserInformation("someuser", "someuser", "someuser@devtable.com")],
+ ["someuser", "buynlarge+somerobot", "buynlarge+anotherbot"],
+ ),
+ # Team has an extra member and some robots => member removed and robots remain.
+ (
+ [
+ ("someuser", "someuser"),
+ ("buynlarge+anotherbot", None),
+ ("buynlarge+somerobot", None),
+ ],
+ [
+ # No members.
+ ],
+ ["buynlarge+somerobot", "buynlarge+anotherbot"],
+ ),
+ # Team has a different member and some robots => member changed and robots remain.
+ (
+ [
+ ("someuser", "someuser"),
+ ("buynlarge+anotherbot", None),
+ ("buynlarge+somerobot", None),
+ ],
+ [UserInformation("anotheruser", "anotheruser", "anotheruser@devtable.com")],
+ ["anotheruser", "buynlarge+somerobot", "buynlarge+anotherbot"],
+ ),
+ # Team with an existing external user (with a different Quay username) + user is in the group.
+ # => no changes and robots remain.
+ (
+ [("anotherquayname", "someuser"), ("buynlarge+anotherbot", None)],
+ [UserInformation("someuser", "someuser", "someuser@devtable.com")],
+ ["someuser", "buynlarge+anotherbot"],
+ ),
+ # Team which returns the same member twice, as pagination in some engines (like LDAP) is not
+ # stable.
+ (
+ [],
+ [
+ UserInformation("someuser", "someuser", "someuser@devtable.com"),
+ UserInformation(
+ "anotheruser", "anotheruser", "anotheruser@devtable.com"
+ ),
+ UserInformation("someuser", "someuser", "someuser@devtable.com"),
+ ],
+ ["anotheruser", "someuser"],
+ ),
+ ],
+)
+def test_syncing(
+ user_creation,
+ invite_only_user_creation,
+ starting_membership,
+ group_membership,
+ expected_membership,
+ blacklisted_emails,
+ app,
+):
+ org = model.organization.get_organization("buynlarge")
- # Team with a Quay user + empty group => empty team.
- ([('someuser', None)],
- [],
- []),
+ # Necessary for the fake auth entries to be created in FederatedLogin.
+ database.LoginService.create(name=_FAKE_AUTH)
- # Team with an existing external user + user is in the group => no changes.
- ([
- ('someuser', 'someuser'),
- ],
- [
- UserInformation('someuser', 'someuser', 'someuser@devtable.com'),
- ],
- ['someuser']),
+ # Assert the team is empty, so we have a clean slate.
+ sync_team_info = model.team.get_team_sync_information("buynlarge", "synced")
+ assert len(list(model.team.list_team_users(sync_team_info.team))) == 0
- # Team with an existing external user (with a different Quay username) + user is in the group.
- # => no changes
- ([
- ('anotherquayname', 'someuser'),
- ],
- [
- UserInformation('someuser', 'someuser', 'someuser@devtable.com'),
- ],
- ['someuser']),
+ # Add the existing starting members to the team.
+ for starting_member in starting_membership:
+ (quay_username, fakeauth_username) = starting_member
+ if "+" in quay_username:
+ # Add a robot.
+ (_, shortname) = parse_robot_username(quay_username)
+ robot, _ = model.user.create_robot(shortname, org)
+ model.team.add_user_to_team(robot, sync_team_info.team)
+ else:
+ email = quay_username + "@devtable.com"
- # Team missing a few members that are in the group => members added.
- ([('someuser', 'someuser')],
- [
- UserInformation('anotheruser', 'anotheruser', 'anotheruser@devtable.com'),
- UserInformation('someuser', 'someuser', 'someuser@devtable.com'),
- UserInformation('thirduser', 'thirduser', 'thirduser@devtable.com'),
- ],
- ['anotheruser', 'someuser', 'thirduser']),
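+            # Create either an unlinked Quay user or one federated against the fake auth service.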
+ if fakeauth_username is None:
+ quay_user = model.user.create_user_noverify(quay_username, email)
+ else:
+ quay_user = model.user.create_federated_user(
+ quay_username, email, _FAKE_AUTH, fakeauth_username, False
+ )
- # Team has a few extra members no longer in the group => members removed.
- ([
- ('anotheruser', 'anotheruser'),
- ('someuser', 'someuser'),
- ('thirduser', 'thirduser'),
- ('nontestuser', None),
- ],
- [
- UserInformation('thirduser', 'thirduser', 'thirduser@devtable.com'),
- ],
- ['thirduser']),
+ model.team.add_user_to_team(quay_user, sync_team_info.team)
- # Team has different membership than the group => members added and removed.
- ([
- ('anotheruser', 'anotheruser'),
- ('someuser', 'someuser'),
- ('nontestuser', None),
- ],
- [
- UserInformation('anotheruser', 'anotheruser', 'anotheruser@devtable.com'),
- UserInformation('missinguser', 'missinguser', 'missinguser@devtable.com'),
- ],
- ['anotheruser', 'missinguser']),
+ # Call syncing on the team.
+ fake_auth = FakeUsers(group_membership)
+ assert sync_team(fake_auth, sync_team_info)
- # Team has same membership but some robots => robots remain and no other changes.
- ([
- ('someuser', 'someuser'),
- ('buynlarge+anotherbot', None),
- ('buynlarge+somerobot', None),
- ],
- [
- UserInformation('someuser', 'someuser', 'someuser@devtable.com'),
- ],
- ['someuser', 'buynlarge+somerobot', 'buynlarge+anotherbot']),
+    # Ensure the last_updated time and transaction_id have changed.
+ updated_sync_info = model.team.get_team_sync_information("buynlarge", "synced")
+ assert updated_sync_info.last_updated is not None
+ assert updated_sync_info.transaction_id != sync_team_info.transaction_id
- # Team has an extra member and some robots => member removed and robots remain.
- ([
- ('someuser', 'someuser'),
- ('buynlarge+anotherbot', None),
- ('buynlarge+somerobot', None),
- ],
- [
- # No members.
- ],
- ['buynlarge+somerobot', 'buynlarge+anotherbot']),
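+    # Split the expected membership into regular users and robot accounts (named "org+robot").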
+ users_expected = set([name for name in expected_membership if "+" not in name])
+ robots_expected = set([name for name in expected_membership if "+" in name])
+ assert len(users_expected) + len(robots_expected) == len(expected_membership)
- # Team has a different member and some robots => member changed and robots remain.
- ([
- ('someuser', 'someuser'),
- ('buynlarge+anotherbot', None),
- ('buynlarge+somerobot', None),
- ],
- [
- UserInformation('anotheruser', 'anotheruser', 'anotheruser@devtable.com'),
- ],
- ['anotheruser', 'buynlarge+somerobot', 'buynlarge+anotherbot']),
+ # Check that the team's users match those expected.
+ service_user_map = model.team.get_federated_team_member_mapping(
+ sync_team_info.team, _FAKE_AUTH
+ )
+ assert set(service_user_map.keys()) == users_expected
- # Team with an existing external user (with a different Quay username) + user is in the group.
- # => no changes and robots remain.
- ([
- ('anotherquayname', 'someuser'),
- ('buynlarge+anotherbot', None),
- ],
- [
- UserInformation('someuser', 'someuser', 'someuser@devtable.com'),
- ],
- ['someuser', 'buynlarge+anotherbot']),
+ quay_users = model.team.list_team_users(sync_team_info.team)
+ assert len(quay_users) == len(users_expected)
- # Team which returns the same member twice, as pagination in some engines (like LDAP) is not
- # stable.
- ([],
- [
- UserInformation('someuser', 'someuser', 'someuser@devtable.com'),
- UserInformation('anotheruser', 'anotheruser', 'anotheruser@devtable.com'),
- UserInformation('someuser', 'someuser', 'someuser@devtable.com'),
- ],
- ['anotheruser', 'someuser']),
-])
-def test_syncing(user_creation, invite_only_user_creation, starting_membership, group_membership,
- expected_membership, blacklisted_emails, app):
- org = model.organization.get_organization('buynlarge')
+ for quay_user in quay_users:
+ fakeauth_record = model.user.lookup_federated_login(quay_user, _FAKE_AUTH)
+ assert fakeauth_record is not None
+ assert fakeauth_record.service_ident in users_expected
+ assert service_user_map[fakeauth_record.service_ident] == quay_user.id
- # Necessary for the fake auth entries to be created in FederatedLogin.
- database.LoginService.create(name=_FAKE_AUTH)
-
- # Assert the team is empty, so we have a clean slate.
- sync_team_info = model.team.get_team_sync_information('buynlarge', 'synced')
- assert len(list(model.team.list_team_users(sync_team_info.team))) == 0
-
- # Add the existing starting members to the team.
- for starting_member in starting_membership:
- (quay_username, fakeauth_username) = starting_member
- if '+' in quay_username:
- # Add a robot.
- (_, shortname) = parse_robot_username(quay_username)
- robot, _ = model.user.create_robot(shortname, org)
- model.team.add_user_to_team(robot, sync_team_info.team)
- else:
- email = quay_username + '@devtable.com'
-
- if fakeauth_username is None:
- quay_user = model.user.create_user_noverify(quay_username, email)
- else:
- quay_user = model.user.create_federated_user(quay_username, email, _FAKE_AUTH,
- fakeauth_username, False)
-
- model.team.add_user_to_team(quay_user, sync_team_info.team)
-
- # Call syncing on the team.
- fake_auth = FakeUsers(group_membership)
- assert sync_team(fake_auth, sync_team_info)
-
- # Ensure the last updated time and transaction_id's have changed.
- updated_sync_info = model.team.get_team_sync_information('buynlarge', 'synced')
- assert updated_sync_info.last_updated is not None
- assert updated_sync_info.transaction_id != sync_team_info.transaction_id
-
- users_expected = set([name for name in expected_membership if '+' not in name])
- robots_expected = set([name for name in expected_membership if '+' in name])
- assert len(users_expected) + len(robots_expected) == len(expected_membership)
-
- # Check that the team's users match those expected.
- service_user_map = model.team.get_federated_team_member_mapping(sync_team_info.team,
- _FAKE_AUTH)
- assert set(service_user_map.keys()) == users_expected
-
- quay_users = model.team.list_team_users(sync_team_info.team)
- assert len(quay_users) == len(users_expected)
-
- for quay_user in quay_users:
- fakeauth_record = model.user.lookup_federated_login(quay_user, _FAKE_AUTH)
- assert fakeauth_record is not None
- assert fakeauth_record.service_ident in users_expected
- assert service_user_map[fakeauth_record.service_ident] == quay_user.id
-
- # Check that the team's robots match those expected.
- robots_found = set([r.username for r in model.team.list_team_robots(sync_team_info.team)])
- assert robots_expected == robots_found
+ # Check that the team's robots match those expected.
+ robots_found = set(
+ [r.username for r in model.team.list_team_robots(sync_team_info.team)]
+ )
+ assert robots_expected == robots_found
-def test_sync_teams_to_groups(user_creation, invite_only_user_creation, blacklisted_emails, app):
- # Necessary for the fake auth entries to be created in FederatedLogin.
- database.LoginService.create(name=_FAKE_AUTH)
+def test_sync_teams_to_groups(
+ user_creation, invite_only_user_creation, blacklisted_emails, app
+):
+ # Necessary for the fake auth entries to be created in FederatedLogin.
+ database.LoginService.create(name=_FAKE_AUTH)
- # Assert the team has not yet been updated.
- sync_team_info = model.team.get_team_sync_information('buynlarge', 'synced')
- assert sync_team_info.last_updated is None
+ # Assert the team has not yet been updated.
+ sync_team_info = model.team.get_team_sync_information("buynlarge", "synced")
+ assert sync_team_info.last_updated is None
- # Call to sync all teams.
- fake_auth = FakeUsers([])
- sync_teams_to_groups(fake_auth, timedelta(seconds=1))
+ # Call to sync all teams.
+ fake_auth = FakeUsers([])
+ sync_teams_to_groups(fake_auth, timedelta(seconds=1))
- # Ensure the team was synced.
- updated_sync_info = model.team.get_team_sync_information('buynlarge', 'synced')
- assert updated_sync_info.last_updated is not None
- assert updated_sync_info.transaction_id != sync_team_info.transaction_id
+ # Ensure the team was synced.
+ updated_sync_info = model.team.get_team_sync_information("buynlarge", "synced")
+ assert updated_sync_info.last_updated is not None
+ assert updated_sync_info.transaction_id != sync_team_info.transaction_id
- # Set the stale threshold to a high amount and ensure the team is not resynced.
- current_info = model.team.get_team_sync_information('buynlarge', 'synced')
- current_info.last_updated = datetime.now() - timedelta(seconds=2)
- current_info.save()
+ # Set the stale threshold to a high amount and ensure the team is not resynced.
+ current_info = model.team.get_team_sync_information("buynlarge", "synced")
+ current_info.last_updated = datetime.now() - timedelta(seconds=2)
+ current_info.save()
- sync_teams_to_groups(fake_auth, timedelta(seconds=120))
+ sync_teams_to_groups(fake_auth, timedelta(seconds=120))
- third_sync_info = model.team.get_team_sync_information('buynlarge', 'synced')
- assert third_sync_info.transaction_id == updated_sync_info.transaction_id
+ third_sync_info = model.team.get_team_sync_information("buynlarge", "synced")
+ assert third_sync_info.transaction_id == updated_sync_info.transaction_id
- # Set the stale threshold to 10 seconds, and ensure the team is resynced, after making it
- # "updated" 20s ago.
- current_info = model.team.get_team_sync_information('buynlarge', 'synced')
- current_info.last_updated = datetime.now() - timedelta(seconds=20)
- current_info.save()
+    # Set the stale threshold to 10 seconds and ensure the team is resynced after making it
+    # appear "updated" 20s ago.
+ current_info = model.team.get_team_sync_information("buynlarge", "synced")
+ current_info.last_updated = datetime.now() - timedelta(seconds=20)
+ current_info.save()
- sync_teams_to_groups(fake_auth, timedelta(seconds=10))
+ sync_teams_to_groups(fake_auth, timedelta(seconds=10))
- fourth_sync_info = model.team.get_team_sync_information('buynlarge', 'synced')
- assert fourth_sync_info.transaction_id != updated_sync_info.transaction_id
+ fourth_sync_info = model.team.get_team_sync_information("buynlarge", "synced")
+ assert fourth_sync_info.transaction_id != updated_sync_info.transaction_id
-@pytest.mark.parametrize('auth_system_builder,config', [
- (mock_ldap, {'group_dn': 'cn=AwesomeFolk'}),
- (fake_keystone, {'group_id': 'somegroupid'}),
-])
-def test_teamsync_end_to_end(user_creation, invite_only_user_creation, auth_system_builder, config,
- blacklisted_emails, app):
- with auth_system_builder() as auth:
- # Create an new team to sync.
- org = model.organization.get_organization('buynlarge')
- new_synced_team = model.team.create_team('synced2', org, 'member', 'Some synced team.')
- sync_team_info = model.team.set_team_syncing(new_synced_team, auth.federated_service, config)
+@pytest.mark.parametrize(
+ "auth_system_builder,config",
+ [
+ (mock_ldap, {"group_dn": "cn=AwesomeFolk"}),
+ (fake_keystone, {"group_id": "somegroupid"}),
+ ],
+)
+def test_teamsync_end_to_end(
+ user_creation,
+ invite_only_user_creation,
+ auth_system_builder,
+ config,
+ blacklisted_emails,
+ app,
+):
+ with auth_system_builder() as auth:
+        # Create a new team to sync.
+ org = model.organization.get_organization("buynlarge")
+ new_synced_team = model.team.create_team(
+ "synced2", org, "member", "Some synced team."
+ )
+ sync_team_info = model.team.set_team_syncing(
+ new_synced_team, auth.federated_service, config
+ )
- # Sync the team.
- assert sync_team(auth, sync_team_info)
+ # Sync the team.
+ assert sync_team(auth, sync_team_info)
- # Ensure we now have members.
- msg = 'Auth system: %s' % auth.federated_service
- sync_team_info = model.team.get_team_sync_information('buynlarge', 'synced2')
- team_members = list(model.team.list_team_users(sync_team_info.team))
- assert len(team_members) > 1, msg
+ # Ensure we now have members.
+ msg = "Auth system: %s" % auth.federated_service
+ sync_team_info = model.team.get_team_sync_information("buynlarge", "synced2")
+ team_members = list(model.team.list_team_users(sync_team_info.team))
+ assert len(team_members) > 1, msg
- it, _ = auth.iterate_group_members(config)
- assert len(team_members) == len(list(it)), msg
+ it, _ = auth.iterate_group_members(config)
+ assert len(team_members) == len(list(it)), msg
- sync_team_info.last_updated = datetime.now() - timedelta(hours=6)
- sync_team_info.save()
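+        # Backdate the last sync so the team registers as stale before syncing again.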
+ sync_team_info.last_updated = datetime.now() - timedelta(hours=6)
+ sync_team_info.save()
- # Remove one of the members and force a sync again to ensure we re-link the correct users.
- first_member = team_members[0]
- model.team.remove_user_from_team('buynlarge', 'synced2', first_member.username, 'devtable')
+ # Remove one of the members and force a sync again to ensure we re-link the correct users.
+ first_member = team_members[0]
+ model.team.remove_user_from_team(
+ "buynlarge", "synced2", first_member.username, "devtable"
+ )
- team_members2 = list(model.team.list_team_users(sync_team_info.team))
- assert len(team_members2) == 1, msg
- assert sync_team(auth, sync_team_info)
+ team_members2 = list(model.team.list_team_users(sync_team_info.team))
+ assert len(team_members2) == 1, msg
+ assert sync_team(auth, sync_team_info)
- team_members3 = list(model.team.list_team_users(sync_team_info.team))
- assert len(team_members3) > 1, msg
- assert set([m.id for m in team_members]) == set([m.id for m in team_members3])
+ team_members3 = list(model.team.list_team_users(sync_team_info.team))
+ assert len(team_members3) > 1, msg
+ assert set([m.id for m in team_members]) == set([m.id for m in team_members3])
-@pytest.mark.parametrize('auth_system_builder,config', [
- (mock_ldap, {'group_dn': 'cn=AwesomeFolk'}),
- (fake_keystone, {'group_id': 'somegroupid'}),
-])
-def test_teamsync_existing_email(user_creation, invite_only_user_creation, auth_system_builder,
- blacklisted_emails, config, app):
- with auth_system_builder() as auth:
- # Create an new team to sync.
- org = model.organization.get_organization('buynlarge')
- new_synced_team = model.team.create_team('synced2', org, 'member', 'Some synced team.')
- sync_team_info = model.team.set_team_syncing(new_synced_team, auth.federated_service, config)
+@pytest.mark.parametrize(
+ "auth_system_builder,config",
+ [
+ (mock_ldap, {"group_dn": "cn=AwesomeFolk"}),
+ (fake_keystone, {"group_id": "somegroupid"}),
+ ],
+)
+def test_teamsync_existing_email(
+ user_creation,
+ invite_only_user_creation,
+ auth_system_builder,
+ blacklisted_emails,
+ config,
+ app,
+):
+ with auth_system_builder() as auth:
+        # Create a new team to sync.
+ org = model.organization.get_organization("buynlarge")
+ new_synced_team = model.team.create_team(
+ "synced2", org, "member", "Some synced team."
+ )
+ sync_team_info = model.team.set_team_syncing(
+ new_synced_team, auth.federated_service, config
+ )
- # Add a new *unlinked* user with the same email address as one of the team members.
- it, _ = auth.iterate_group_members(config)
- members = list(it)
- model.user.create_user_noverify('someusername', members[0][0].email)
+ # Add a new *unlinked* user with the same email address as one of the team members.
+ it, _ = auth.iterate_group_members(config)
+ members = list(it)
+ model.user.create_user_noverify("someusername", members[0][0].email)
- # Sync the team and ensure it doesn't fail.
- assert sync_team(auth, sync_team_info)
+ # Sync the team and ensure it doesn't fail.
+ assert sync_team(auth, sync_team_info)
- team_members = list(model.team.list_team_users(sync_team_info.team))
- assert len(team_members) > 0
+ team_members = list(model.team.list_team_users(sync_team_info.team))
+ assert len(team_members) > 0
diff --git a/data/users/test/test_users.py b/data/users/test/test_users.py
index 81f6660bd..6dbe30f24 100644
--- a/data/users/test/test_users.py
+++ b/data/users/test/test_users.py
@@ -11,89 +11,95 @@ from test.test_external_jwt_authn import fake_jwt
from test.fixtures import *
-@pytest.mark.parametrize('auth_system_builder, user1, user2', [
- (mock_ldap, ('someuser', 'somepass'), ('testy', 'password')),
- (fake_keystone, ('cool.user', 'password'), ('some.neat.user', 'foobar')),
-])
+
+@pytest.mark.parametrize(
+ "auth_system_builder, user1, user2",
+ [
+ (mock_ldap, ("someuser", "somepass"), ("testy", "password")),
+ (fake_keystone, ("cool.user", "password"), ("some.neat.user", "foobar")),
+ ],
+)
def test_auth_createuser(auth_system_builder, user1, user2, config, app):
- with auth_system_builder() as auth:
- # Login as a user and ensure a row in the database is created for them.
- user, err = auth.verify_and_link_user(*user1)
- assert err is None
- assert user
+ with auth_system_builder() as auth:
+ # Login as a user and ensure a row in the database is created for them.
+ user, err = auth.verify_and_link_user(*user1)
+ assert err is None
+ assert user
- federated_info = model.user.lookup_federated_login(user, auth.federated_service)
- assert federated_info is not None
+ federated_info = model.user.lookup_federated_login(user, auth.federated_service)
+ assert federated_info is not None
- # Disable user creation.
- with patch('features.USER_CREATION', False):
- # Ensure that the existing user can login.
- user_again, err = auth.verify_and_link_user(*user1)
- assert err is None
- assert user_again.id == user.id
+ # Disable user creation.
+ with patch("features.USER_CREATION", False):
+ # Ensure that the existing user can login.
+ user_again, err = auth.verify_and_link_user(*user1)
+ assert err is None
+ assert user_again.id == user.id
- # Ensure that a new user cannot.
- new_user, err = auth.verify_and_link_user(*user2)
- assert new_user is None
- assert err == DISABLED_MESSAGE
+ # Ensure that a new user cannot.
+ new_user, err = auth.verify_and_link_user(*user2)
+ assert new_user is None
+ assert err == DISABLED_MESSAGE
@pytest.mark.parametrize(
- 'email, blacklisting_enabled, can_create',
- [
- # Blacklisting Enabled, Blacklisted Domain => Blocked
- ('foo@blacklisted.net', True, False),
- ('foo@blacklisted.com', True, False),
-
- # Blacklisting Enabled, similar to blacklisted domain => Allowed
- ('foo@notblacklisted.com', True, True),
- ('foo@blacklisted.org', True, True),
-
- # Blacklisting *Disabled*, Blacklisted Domain => Allowed
- ('foo@blacklisted.com', False, True),
- ('foo@blacklisted.net', False, True),
- ]
+ "email, blacklisting_enabled, can_create",
+ [
+ # Blacklisting Enabled, Blacklisted Domain => Blocked
+ ("foo@blacklisted.net", True, False),
+ ("foo@blacklisted.com", True, False),
+ # Blacklisting Enabled, similar to blacklisted domain => Allowed
+ ("foo@notblacklisted.com", True, True),
+ ("foo@blacklisted.org", True, True),
+ # Blacklisting *Disabled*, Blacklisted Domain => Allowed
+ ("foo@blacklisted.com", False, True),
+ ("foo@blacklisted.net", False, True),
+ ],
)
-@pytest.mark.parametrize('auth_system_builder', [mock_ldap, fake_keystone, fake_jwt])
-def test_createuser_with_blacklist(auth_system_builder, email, blacklisting_enabled, can_create, config, app):
- """Verify email blacklisting with User Creation"""
+@pytest.mark.parametrize("auth_system_builder", [mock_ldap, fake_keystone, fake_jwt])
+def test_createuser_with_blacklist(
+ auth_system_builder, email, blacklisting_enabled, can_create, config, app
+):
+ """Verify email blacklisting with User Creation"""
- MOCK_CONFIG = {'BLACKLISTED_EMAIL_DOMAINS': ['blacklisted.com', 'blacklisted.net']}
- MOCK_PASSWORD = 'somepass'
+ MOCK_CONFIG = {"BLACKLISTED_EMAIL_DOMAINS": ["blacklisted.com", "blacklisted.net"]}
+ MOCK_PASSWORD = "somepass"
- with auth_system_builder() as auth:
- with patch('features.BLACKLISTED_EMAILS', blacklisting_enabled):
- with patch.dict('data.model.config.app_config', MOCK_CONFIG):
- with patch('features.USER_CREATION', True):
- new_user, err = auth.verify_and_link_user(email, MOCK_PASSWORD)
- if can_create:
- assert err is None
- assert new_user
- else:
- assert err
- assert new_user is None
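+    # With user creation forced on, whether sign-up succeeds should depend only on the blacklist configuration.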
+ with auth_system_builder() as auth:
+ with patch("features.BLACKLISTED_EMAILS", blacklisting_enabled):
+ with patch.dict("data.model.config.app_config", MOCK_CONFIG):
+ with patch("features.USER_CREATION", True):
+ new_user, err = auth.verify_and_link_user(email, MOCK_PASSWORD)
+ if can_create:
+ assert err is None
+ assert new_user
+ else:
+ assert err
+ assert new_user is None
-@pytest.mark.parametrize('auth_system_builder,auth_kwargs', [
- (mock_ldap, {}),
- (fake_keystone, {'version': 3}),
- (fake_keystone, {'version': 2}),
- (fake_jwt, {}),
-])
+@pytest.mark.parametrize(
+ "auth_system_builder,auth_kwargs",
+ [
+ (mock_ldap, {}),
+ (fake_keystone, {"version": 3}),
+ (fake_keystone, {"version": 2}),
+ (fake_jwt, {}),
+ ],
+)
def test_ping(auth_system_builder, auth_kwargs, app):
- with auth_system_builder(**auth_kwargs) as auth:
- status, err = auth.ping()
- assert status
- assert err is None
+ with auth_system_builder(**auth_kwargs) as auth:
+ status, err = auth.ping()
+ assert status
+ assert err is None
-@pytest.mark.parametrize('auth_system_builder,auth_kwargs', [
- (mock_ldap, {}),
- (fake_keystone, {'version': 3}),
- (fake_keystone, {'version': 2}),
-])
+@pytest.mark.parametrize(
+ "auth_system_builder,auth_kwargs",
+ [(mock_ldap, {}), (fake_keystone, {"version": 3}), (fake_keystone, {"version": 2})],
+)
def test_at_least_one_user_exists(auth_system_builder, auth_kwargs, app):
- with auth_system_builder(**auth_kwargs) as auth:
- status, err = auth.at_least_one_user_exists()
- assert status
- assert err is None
+ with auth_system_builder(**auth_kwargs) as auth:
+ status, err = auth.at_least_one_user_exists()
+ assert status
+ assert err is None
diff --git a/digest/digest_tools.py b/digest/digest_tools.py
index 212088236..3e0fc3d9e 100644
--- a/digest/digest_tools.py
+++ b/digest/digest_tools.py
@@ -3,80 +3,85 @@ import os.path
import hashlib
-DIGEST_PATTERN = r'([A-Za-z0-9_+.-]+):([A-Fa-f0-9]+)'
-REPLACE_WITH_PATH = re.compile(r'[+.]')
-REPLACE_DOUBLE_SLASHES = re.compile(r'/+')
+DIGEST_PATTERN = r"([A-Za-z0-9_+.-]+):([A-Fa-f0-9]+)"
+REPLACE_WITH_PATH = re.compile(r"[+.]")
+REPLACE_DOUBLE_SLASHES = re.compile(r"/+")
+
class InvalidDigestException(RuntimeError):
- pass
+ pass
class Digest(object):
- DIGEST_REGEX = re.compile(DIGEST_PATTERN)
+ DIGEST_REGEX = re.compile(DIGEST_PATTERN)
- def __init__(self, hash_alg, hash_bytes):
- self._hash_alg = hash_alg
- self._hash_bytes = hash_bytes
+ def __init__(self, hash_alg, hash_bytes):
+ self._hash_alg = hash_alg
+ self._hash_bytes = hash_bytes
- def __str__(self):
- return '{0}:{1}'.format(self._hash_alg, self._hash_bytes)
+ def __str__(self):
+ return "{0}:{1}".format(self._hash_alg, self._hash_bytes)
- def __eq__(self, rhs):
- return isinstance(rhs, Digest) and str(self) == str(rhs)
+ def __eq__(self, rhs):
+ return isinstance(rhs, Digest) and str(self) == str(rhs)
- @staticmethod
- def parse_digest(digest):
- """ Returns the digest parsed out to its components. """
- match = Digest.DIGEST_REGEX.match(digest)
- if match is None or match.end() != len(digest):
- raise InvalidDigestException('Not a valid digest: %s', digest)
+ @staticmethod
+ def parse_digest(digest):
+ """ Returns the digest parsed out to its components. """
+ match = Digest.DIGEST_REGEX.match(digest)
+ if match is None or match.end() != len(digest):
+            raise InvalidDigestException("Not a valid digest: %s" % digest)
- return Digest(match.group(1), match.group(2))
+ return Digest(match.group(1), match.group(2))
- @property
- def hash_alg(self):
- return self._hash_alg
+ @property
+ def hash_alg(self):
+ return self._hash_alg
- @property
- def hash_bytes(self):
- return self._hash_bytes
+ @property
+ def hash_bytes(self):
+ return self._hash_bytes
def content_path(digest):
- """ Returns a relative path to the parsed digest. """
- parsed = Digest.parse_digest(digest)
- components = []
+ """ Returns a relative path to the parsed digest. """
+ parsed = Digest.parse_digest(digest)
+ components = []
- # Generate a prefix which is always two characters, and which will be filled with leading zeros
- # if the input does not contain at least two characters. e.g. ABC -> AB, A -> 0A
- prefix = parsed.hash_bytes[0:2].zfill(2)
- pathish = REPLACE_WITH_PATH.sub('/', parsed.hash_alg)
- normalized = REPLACE_DOUBLE_SLASHES.sub('/', pathish).lstrip('/')
- components.extend([normalized, prefix, parsed.hash_bytes])
- return os.path.join(*components)
+ # Generate a prefix which is always two characters, and which will be filled with leading zeros
+ # if the input does not contain at least two characters. e.g. ABC -> AB, A -> 0A
+ prefix = parsed.hash_bytes[0:2].zfill(2)
+ pathish = REPLACE_WITH_PATH.sub("/", parsed.hash_alg)
+ normalized = REPLACE_DOUBLE_SLASHES.sub("/", pathish).lstrip("/")
+ components.extend([normalized, prefix, parsed.hash_bytes])
+ return os.path.join(*components)
def sha256_digest(content):
- """ Returns a sha256 hash of the content bytes in digest form. """
- def single_chunk_generator():
- yield content
- return sha256_digest_from_generator(single_chunk_generator())
+ """ Returns a sha256 hash of the content bytes in digest form. """
+
+ def single_chunk_generator():
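+        # Wrap the single content blob in a generator so the streaming helper below can be reused.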
+ yield content
+
+ return sha256_digest_from_generator(single_chunk_generator())
def sha256_digest_from_generator(content_generator):
- """ Reads all of the data from the iterator and creates a sha256 digest from the content
+ """ Reads all of the data from the iterator and creates a sha256 digest from the content
"""
- digest = hashlib.sha256()
- for chunk in content_generator:
- digest.update(chunk)
- return 'sha256:{0}'.format(digest.hexdigest())
+ digest = hashlib.sha256()
+ for chunk in content_generator:
+ digest.update(chunk)
+ return "sha256:{0}".format(digest.hexdigest())
def sha256_digest_from_hashlib(sha256_hash_obj):
- return 'sha256:{0}'.format(sha256_hash_obj.hexdigest())
+ return "sha256:{0}".format(sha256_hash_obj.hexdigest())
def digests_equal(lhs_digest_string, rhs_digest_string):
- """ Parse and compare the two digests, returns True if the digests are equal, False otherwise.
+ """ Parse and compare the two digests, returns True if the digests are equal, False otherwise.
"""
- return Digest.parse_digest(lhs_digest_string) == Digest.parse_digest(rhs_digest_string)
+ return Digest.parse_digest(lhs_digest_string) == Digest.parse_digest(
+ rhs_digest_string
+ )
diff --git a/digest/test/test_digest_tools.py b/digest/test/test_digest_tools.py
index b04f64c6f..1e3792ff3 100644
--- a/digest/test/test_digest_tools.py
+++ b/digest/test/test_digest_tools.py
@@ -2,42 +2,52 @@ import pytest
from digest.digest_tools import Digest, content_path, InvalidDigestException
-@pytest.mark.parametrize('digest, output_args', [
- ('tarsum.v123123+sha1:123deadbeef', ('tarsum.v123123+sha1', '123deadbeef')),
- ('tarsum.v1+sha256:123123', ('tarsum.v1+sha256', '123123')),
- ('tarsum.v0+md5:abc', ('tarsum.v0+md5', 'abc')),
- ('tarsum+sha1:abc', ('tarsum+sha1', 'abc')),
- ('sha1:123deadbeef', ('sha1', '123deadbeef')),
- ('sha256:123123', ('sha256', '123123')),
- ('md5:abc', ('md5', 'abc')),
-])
+
+@pytest.mark.parametrize(
+ "digest, output_args",
+ [
+ ("tarsum.v123123+sha1:123deadbeef", ("tarsum.v123123+sha1", "123deadbeef")),
+ ("tarsum.v1+sha256:123123", ("tarsum.v1+sha256", "123123")),
+ ("tarsum.v0+md5:abc", ("tarsum.v0+md5", "abc")),
+ ("tarsum+sha1:abc", ("tarsum+sha1", "abc")),
+ ("sha1:123deadbeef", ("sha1", "123deadbeef")),
+ ("sha256:123123", ("sha256", "123123")),
+ ("md5:abc", ("md5", "abc")),
+ ],
+)
def test_parse_good(digest, output_args):
- assert Digest.parse_digest(digest) == Digest(*output_args)
- assert str(Digest.parse_digest(digest)) == digest
+ assert Digest.parse_digest(digest) == Digest(*output_args)
+ assert str(Digest.parse_digest(digest)) == digest
-@pytest.mark.parametrize('bad_digest', [
- 'tarsum.v+md5:abc:',
- 'sha1:123deadbeefzxczxv',
- 'sha256123123',
- 'tarsum.v1+',
- 'tarsum.v1123+sha1:',
-])
+@pytest.mark.parametrize(
+ "bad_digest",
+ [
+ "tarsum.v+md5:abc:",
+ "sha1:123deadbeefzxczxv",
+ "sha256123123",
+ "tarsum.v1+",
+ "tarsum.v1123+sha1:",
+ ],
+)
def test_parse_fail(bad_digest):
- with pytest.raises(InvalidDigestException):
- Digest.parse_digest(bad_digest)
+ with pytest.raises(InvalidDigestException):
+ Digest.parse_digest(bad_digest)
-@pytest.mark.parametrize('digest, path', [
- ('tarsum.v123123+sha1:123deadbeef', 'tarsum/v123123/sha1/12/123deadbeef'),
- ('tarsum.v1+sha256:123123', 'tarsum/v1/sha256/12/123123'),
- ('tarsum.v0+md5:abc', 'tarsum/v0/md5/ab/abc'),
- ('sha1:123deadbeef', 'sha1/12/123deadbeef'),
- ('sha256:123123', 'sha256/12/123123'),
- ('md5:abc', 'md5/ab/abc'),
- ('md5:1', 'md5/01/1'),
- ('md5.....+++:1', 'md5/01/1'),
- ('.md5.:1', 'md5/01/1'),
-])
+@pytest.mark.parametrize(
+ "digest, path",
+ [
+ ("tarsum.v123123+sha1:123deadbeef", "tarsum/v123123/sha1/12/123deadbeef"),
+ ("tarsum.v1+sha256:123123", "tarsum/v1/sha256/12/123123"),
+ ("tarsum.v0+md5:abc", "tarsum/v0/md5/ab/abc"),
+ ("sha1:123deadbeef", "sha1/12/123deadbeef"),
+ ("sha256:123123", "sha256/12/123123"),
+ ("md5:abc", "md5/ab/abc"),
+ ("md5:1", "md5/01/1"),
+ ("md5.....+++:1", "md5/01/1"),
+ (".md5.:1", "md5/01/1"),
+ ],
+)
def test_paths(digest, path):
- assert content_path(digest) == path
+ assert content_path(digest) == path
diff --git a/endpoints/api/__init__.py b/endpoints/api/__init__.py
index 8dcabe6a3..a263a37a3 100644
--- a/endpoints/api/__init__.py
+++ b/endpoints/api/__init__.py
@@ -11,20 +11,36 @@ from flask_restful.utils.cors import crossdomain
from jsonschema import validate, ValidationError
from app import app, metric_queue, authentication
-from auth.permissions import (ReadRepositoryPermission, ModifyRepositoryPermission,
- AdministerRepositoryPermission, UserReadPermission,
- UserAdminPermission)
+from auth.permissions import (
+ ReadRepositoryPermission,
+ ModifyRepositoryPermission,
+ AdministerRepositoryPermission,
+ UserReadPermission,
+ UserAdminPermission,
+)
from auth import scopes
-from auth.auth_context import (get_authenticated_context, get_authenticated_user,
- get_validated_oauth_token)
+from auth.auth_context import (
+ get_authenticated_context,
+ get_authenticated_user,
+ get_validated_oauth_token,
+)
from auth.decorators import process_oauth
from data import model as data_model
from data.logs_model import logs_model
from data.database import RepositoryState
from endpoints.csrf import csrf_protect
-from endpoints.exception import (Unauthorized, InvalidRequest, InvalidResponse,
- FreshLoginRequired, NotFound)
-from endpoints.decorators import check_anon_protection, require_xhr_from_browser, check_readonly
+from endpoints.exception import (
+ Unauthorized,
+ InvalidRequest,
+ InvalidResponse,
+ FreshLoginRequired,
+ NotFound,
+)
+from endpoints.decorators import (
+ check_anon_protection,
+ require_xhr_from_browser,
+ check_readonly,
+)
from util.metrics.metricqueue import time_decorator
from util.names import parse_namespace_repository
from util.pagination import encrypt_page_token, decrypt_page_token
@@ -33,253 +49,297 @@ from __init__models_pre_oci import pre_oci_model as model
logger = logging.getLogger(__name__)
-api_bp = Blueprint('api', __name__)
+api_bp = Blueprint("api", __name__)
-CROSS_DOMAIN_HEADERS = ['Authorization', 'Content-Type', 'X-Requested-With']
+CROSS_DOMAIN_HEADERS = ["Authorization", "Content-Type", "X-Requested-With"]
+
class ApiExceptionHandlingApi(Api):
- @crossdomain(origin='*', headers=CROSS_DOMAIN_HEADERS)
- def handle_error(self, error):
- return super(ApiExceptionHandlingApi, self).handle_error(error)
+ @crossdomain(origin="*", headers=CROSS_DOMAIN_HEADERS)
+ def handle_error(self, error):
+ return super(ApiExceptionHandlingApi, self).handle_error(error)
api = ApiExceptionHandlingApi()
api.init_app(api_bp)
-api.decorators = [csrf_protect(),
- crossdomain(origin='*', headers=CROSS_DOMAIN_HEADERS),
- process_oauth, time_decorator(api_bp.name, metric_queue),
- require_xhr_from_browser]
+api.decorators = [
+ csrf_protect(),
+ crossdomain(origin="*", headers=CROSS_DOMAIN_HEADERS),
+ process_oauth,
+ time_decorator(api_bp.name, metric_queue),
+ require_xhr_from_browser,
+]
def resource(*urls, **kwargs):
- def wrapper(api_resource):
- if not api_resource:
- return None
+ def wrapper(api_resource):
+ if not api_resource:
+ return None
- api_resource.registered = True
- api.add_resource(api_resource, *urls, **kwargs)
- return api_resource
- return wrapper
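+        # Mark the class as registered so show_if/hide_if can detect decorator-ordering mistakes.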
+ api_resource.registered = True
+ api.add_resource(api_resource, *urls, **kwargs)
+ return api_resource
+
+ return wrapper
def show_if(value):
- def f(inner):
- if hasattr(inner, 'registered') and inner.registered:
- msg = ('API endpoint %s is already registered; please switch the ' +
- '@show_if to be *below* the @resource decorator')
- raise Exception(msg % inner)
+ def f(inner):
+ if hasattr(inner, "registered") and inner.registered:
+ msg = (
+ "API endpoint %s is already registered; please switch the "
+ + "@show_if to be *below* the @resource decorator"
+ )
+ raise Exception(msg % inner)
- if not value:
- return None
+ if not value:
+ return None
- return inner
- return f
+ return inner
+
+ return f
def hide_if(value):
- def f(inner):
- if hasattr(inner, 'registered') and inner.registered:
- msg = ('API endpoint %s is already registered; please switch the ' +
- '@hide_if to be *below* the @resource decorator')
- raise Exception(msg % inner)
+ def f(inner):
+ if hasattr(inner, "registered") and inner.registered:
+ msg = (
+ "API endpoint %s is already registered; please switch the "
+ + "@hide_if to be *below* the @resource decorator"
+ )
+ raise Exception(msg % inner)
- if value:
- return None
+ if value:
+ return None
- return inner
- return f
+ return inner
+
+ return f
def truthy_bool(param):
- return param not in {False, 'false', 'False', '0', 'FALSE', '', 'null'}
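+    # Treat common textual spellings of false/empty as falsy; any other value counts as true.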
+ return param not in {False, "false", "False", "0", "FALSE", "", "null"}
def format_date(date):
- """ Output an RFC822 date format. """
- if date is None:
- return None
- return formatdate(timegm(date.utctimetuple()))
+ """ Output an RFC822 date format. """
+ if date is None:
+ return None
+ return formatdate(timegm(date.utctimetuple()))
def add_method_metadata(name, value):
- def modifier(func):
- if func is None:
- return None
+ def modifier(func):
+ if func is None:
+ return None
- if '__api_metadata' not in dir(func):
- func.__api_metadata = {}
- func.__api_metadata[name] = value
- return func
- return modifier
+ if "__api_metadata" not in dir(func):
+ func.__api_metadata = {}
+ func.__api_metadata[name] = value
+ return func
+
+ return modifier
def method_metadata(func, name):
- if func is None:
+ if func is None:
+ return None
+
+ if "__api_metadata" in dir(func):
+ return func.__api_metadata.get(name, None)
return None
- if '__api_metadata' in dir(func):
- return func.__api_metadata.get(name, None)
- return None
-
-
-nickname = partial(add_method_metadata, 'nickname')
-related_user_resource = partial(add_method_metadata, 'related_user_resource')
-internal_only = add_method_metadata('internal', True)
+nickname = partial(add_method_metadata, "nickname")
+related_user_resource = partial(add_method_metadata, "related_user_resource")
+internal_only = add_method_metadata("internal", True)
def path_param(name, description):
- def add_param(func):
- if not func:
- return func
+ def add_param(func):
+ if not func:
+ return func
- if '__api_path_params' not in dir(func):
- func.__api_path_params = {}
- func.__api_path_params[name] = {
- 'name': name,
- 'description': description
- }
- return func
- return add_param
+ if "__api_path_params" not in dir(func):
+ func.__api_path_params = {}
+ func.__api_path_params[name] = {"name": name, "description": description}
+ return func
+
+ return add_param
-def query_param(name, help_str, type=reqparse.text_type, default=None,
- choices=(), required=False):
- def add_param(func):
- if '__api_query_params' not in dir(func):
- func.__api_query_params = []
- func.__api_query_params.append({
- 'name': name,
- 'type': type,
- 'help': help_str,
- 'default': default,
- 'choices': choices,
- 'required': required,
- 'location': ('args')
- })
- return func
- return add_param
+def query_param(
+ name, help_str, type=reqparse.text_type, default=None, choices=(), required=False
+):
+ def add_param(func):
+ if "__api_query_params" not in dir(func):
+ func.__api_query_params = []
+ func.__api_query_params.append(
+ {
+ "name": name,
+ "type": type,
+ "help": help_str,
+ "default": default,
+ "choices": choices,
+ "required": required,
+ "location": ("args"),
+ }
+ )
+ return func
-def page_support(page_token_kwarg='page_token', parsed_args_kwarg='parsed_args'):
- def inner(func):
- """ Adds pagination support to an API endpoint. The decorated API will have an
+ return add_param
+
+
+def page_support(page_token_kwarg="page_token", parsed_args_kwarg="parsed_args"):
+ def inner(func):
+ """ Adds pagination support to an API endpoint. The decorated API will have an
added query parameter named 'next_page'. Works in tandem with the
modelutil paginate method.
"""
- @wraps(func)
- @query_param('next_page', 'The page token for the next page', type=str)
- def wrapper(self, *args, **kwargs):
- # Note: if page_token is None, we'll receive the first page of results back.
- page_token = decrypt_page_token(kwargs[parsed_args_kwarg]['next_page'])
- kwargs[page_token_kwarg] = page_token
- (result, next_page_token) = func(self, *args, **kwargs)
- if next_page_token is not None:
- result['next_page'] = encrypt_page_token(next_page_token)
+ @wraps(func)
+ @query_param("next_page", "The page token for the next page", type=str)
+ def wrapper(self, *args, **kwargs):
+ # Note: if page_token is None, we'll receive the first page of results back.
+ page_token = decrypt_page_token(kwargs[parsed_args_kwarg]["next_page"])
+ kwargs[page_token_kwarg] = page_token
- return result
- return wrapper
- return inner
+ (result, next_page_token) = func(self, *args, **kwargs)
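+            # Encrypt the token before returning it so clients only ever see an opaque next_page value.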
+ if next_page_token is not None:
+ result["next_page"] = encrypt_page_token(next_page_token)
-def parse_args(kwarg_name='parsed_args'):
- def inner(func):
- @wraps(func)
- def wrapper(self, *args, **kwargs):
- if '__api_query_params' not in dir(func):
- abort(500)
+ return result
- parser = reqparse.RequestParser()
- for arg_spec in func.__api_query_params:
- parser.add_argument(**arg_spec)
- kwargs[kwarg_name] = parser.parse_args()
+ return wrapper
+
+ return inner
+
+
+def parse_args(kwarg_name="parsed_args"):
+ def inner(func):
+ @wraps(func)
+ def wrapper(self, *args, **kwargs):
+ if "__api_query_params" not in dir(func):
+ abort(500)
+
+ parser = reqparse.RequestParser()
+ for arg_spec in func.__api_query_params:
+ parser.add_argument(**arg_spec)
+ kwargs[kwarg_name] = parser.parse_args()
+
+ return func(self, *args, **kwargs)
+
+ return wrapper
+
+ return inner
- return func(self, *args, **kwargs)
- return wrapper
- return inner
def parse_repository_name(func):
- @wraps(func)
- def wrapper(repository, *args, **kwargs):
- (namespace, repository) = parse_namespace_repository(repository, app.config['LIBRARY_NAMESPACE'])
- return func(namespace, repository, *args, **kwargs)
- return wrapper
+ @wraps(func)
+ def wrapper(repository, *args, **kwargs):
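+        # Split the repository path into namespace and name, defaulting to the configured library namespace.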
+ (namespace, repository) = parse_namespace_repository(
+ repository, app.config["LIBRARY_NAMESPACE"]
+ )
+ return func(namespace, repository, *args, **kwargs)
+
+ return wrapper
class ApiResource(Resource):
- registered = False
- method_decorators = [check_anon_protection, check_readonly]
+ registered = False
+ method_decorators = [check_anon_protection, check_readonly]
- def options(self):
- return None, 200
+ def options(self):
+ return None, 200
class RepositoryParamResource(ApiResource):
- method_decorators = [check_anon_protection, parse_repository_name, check_readonly]
+ method_decorators = [check_anon_protection, parse_repository_name, check_readonly]
def disallow_for_app_repositories(func):
- @wraps(func)
- def wrapped(self, namespace_name, repository_name, *args, **kwargs):
- # Lookup the repository with the given namespace and name and ensure it is not an application
- # repository.
- if model.is_app_repository(namespace_name, repository_name):
- abort(501)
+ @wraps(func)
+ def wrapped(self, namespace_name, repository_name, *args, **kwargs):
+ # Lookup the repository with the given namespace and name and ensure it is not an application
+ # repository.
+ if model.is_app_repository(namespace_name, repository_name):
+ abort(501)
- return func(self, namespace_name, repository_name, *args, **kwargs)
+ return func(self, namespace_name, repository_name, *args, **kwargs)
- return wrapped
+ return wrapped
def disallow_for_non_normal_repositories(func):
- @wraps(func)
- def wrapped(self, namespace_name, repository_name, *args, **kwargs):
- repo = data_model.repository.get_repository(namespace_name, repository_name)
- if repo and repo.state != RepositoryState.NORMAL:
- abort(503, message='Repository is in read only or mirror mode: %s' % repo.state)
+ @wraps(func)
+ def wrapped(self, namespace_name, repository_name, *args, **kwargs):
+ repo = data_model.repository.get_repository(namespace_name, repository_name)
+ if repo and repo.state != RepositoryState.NORMAL:
+ abort(
+ 503,
+ message="Repository is in read only or mirror mode: %s" % repo.state,
+ )
- return func(self, namespace_name, repository_name, *args, **kwargs)
- return wrapped
+ return func(self, namespace_name, repository_name, *args, **kwargs)
+
+ return wrapped
def require_repo_permission(permission_class, scope, allow_public=False):
- def wrapper(func):
- @add_method_metadata('oauth2_scope', scope)
- @wraps(func)
- def wrapped(self, namespace, repository, *args, **kwargs):
- logger.debug('Checking permission %s for repo: %s/%s', permission_class, namespace,
- repository)
- permission = permission_class(namespace, repository)
- if (permission.can() or
- (allow_public and
- model.repository_is_public(namespace, repository))):
- return func(self, namespace, repository, *args, **kwargs)
- raise Unauthorized()
- return wrapped
- return wrapper
+ def wrapper(func):
+ @add_method_metadata("oauth2_scope", scope)
+ @wraps(func)
+ def wrapped(self, namespace, repository, *args, **kwargs):
+ logger.debug(
+ "Checking permission %s for repo: %s/%s",
+ permission_class,
+ namespace,
+ repository,
+ )
+ permission = permission_class(namespace, repository)
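+            # Grant access when the caller holds the permission, or when public access is allowed and the repo is public.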
+ if permission.can() or (
+ allow_public and model.repository_is_public(namespace, repository)
+ ):
+ return func(self, namespace, repository, *args, **kwargs)
+ raise Unauthorized()
+
+ return wrapped
+
+ return wrapper
-require_repo_read = require_repo_permission(ReadRepositoryPermission, scopes.READ_REPO, True)
-require_repo_write = require_repo_permission(ModifyRepositoryPermission, scopes.WRITE_REPO)
-require_repo_admin = require_repo_permission(AdministerRepositoryPermission, scopes.ADMIN_REPO)
+require_repo_read = require_repo_permission(
+ ReadRepositoryPermission, scopes.READ_REPO, True
+)
+require_repo_write = require_repo_permission(
+ ModifyRepositoryPermission, scopes.WRITE_REPO
+)
+require_repo_admin = require_repo_permission(
+ AdministerRepositoryPermission, scopes.ADMIN_REPO
+)
def require_user_permission(permission_class, scope=None):
- def wrapper(func):
- @add_method_metadata('oauth2_scope', scope)
- @wraps(func)
- def wrapped(self, *args, **kwargs):
- user = get_authenticated_user()
- if not user:
- raise Unauthorized()
+ def wrapper(func):
+ @add_method_metadata("oauth2_scope", scope)
+ @wraps(func)
+ def wrapped(self, *args, **kwargs):
+ user = get_authenticated_user()
+ if not user:
+ raise Unauthorized()
- logger.debug('Checking permission %s for user %s', permission_class, user.username)
- permission = permission_class(user.username)
- if permission.can():
- return func(self, *args, **kwargs)
- raise Unauthorized()
- return wrapped
- return wrapper
+ logger.debug(
+ "Checking permission %s for user %s", permission_class, user.username
+ )
+ permission = permission_class(user.username)
+ if permission.can():
+ return func(self, *args, **kwargs)
+ raise Unauthorized()
+
+ return wrapped
+
+ return wrapper
require_user_read = require_user_permission(UserReadPermission, scopes.READ_USER)
@@ -287,136 +347,151 @@ require_user_admin = require_user_permission(UserAdminPermission, scopes.ADMIN_U
def verify_not_prod(func):
- @add_method_metadata('enterprise_only', True)
- @wraps(func)
- def wrapped(*args, **kwargs):
- # Verify that we are not running on a production (i.e. hosted) stack. If so, we fail.
- # This should never happen (because of the feature-flag on SUPER_USERS), but we want to be
- # absolutely sure.
- if app.config['SERVER_HOSTNAME'].find('quay.io') >= 0:
- logger.error('!!! Super user method called IN PRODUCTION !!!')
- raise NotFound()
+ @add_method_metadata("enterprise_only", True)
+ @wraps(func)
+ def wrapped(*args, **kwargs):
+ # Verify that we are not running on a production (i.e. hosted) stack. If so, we fail.
+ # This should never happen (because of the feature-flag on SUPER_USERS), but we want to be
+ # absolutely sure.
+ if app.config["SERVER_HOSTNAME"].find("quay.io") >= 0:
+ logger.error("!!! Super user method called IN PRODUCTION !!!")
+ raise NotFound()
- return func(*args, **kwargs)
+ return func(*args, **kwargs)
- return wrapped
+ return wrapped
def require_fresh_login(func):
- @add_method_metadata('requires_fresh_login', True)
- @wraps(func)
- def wrapped(*args, **kwargs):
- user = get_authenticated_user()
- if not user:
- raise Unauthorized()
+ @add_method_metadata("requires_fresh_login", True)
+ @wraps(func)
+ def wrapped(*args, **kwargs):
+ user = get_authenticated_user()
+ if not user:
+ raise Unauthorized()
- if get_validated_oauth_token():
- return func(*args, **kwargs)
+ if get_validated_oauth_token():
+ return func(*args, **kwargs)
- logger.debug('Checking fresh login for user %s', user.username)
+ logger.debug("Checking fresh login for user %s", user.username)
- last_login = session.get('login_time', datetime.datetime.min)
- valid_span = datetime.datetime.now() - datetime.timedelta(minutes=10)
+ last_login = session.get("login_time", datetime.datetime.min)
+ valid_span = datetime.datetime.now() - datetime.timedelta(minutes=10)
- if (not user.password_hash or last_login >= valid_span or
- not authentication.supports_fresh_login):
- return func(*args, **kwargs)
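+        # Skip the fresh-login check when the user has no password (external
+        # auth), logged in within the last 10 minutes, or the auth engine does
+        # not support fresh logins.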
+ if (
+ not user.password_hash
+ or last_login >= valid_span
+ or not authentication.supports_fresh_login
+ ):
+ return func(*args, **kwargs)
- raise FreshLoginRequired()
- return wrapped
+ raise FreshLoginRequired()
+
+ return wrapped
def require_scope(scope_object):
- def wrapper(func):
- @add_method_metadata('oauth2_scope', scope_object)
- @wraps(func)
- def wrapped(*args, **kwargs):
- return func(*args, **kwargs)
- return wrapped
- return wrapper
+ def wrapper(func):
+ @add_method_metadata("oauth2_scope", scope_object)
+ @wraps(func)
+ def wrapped(*args, **kwargs):
+ return func(*args, **kwargs)
+
+ return wrapped
+
+ return wrapper
def max_json_size(max_size):
- def wrapper(func):
- @wraps(func)
- def wrapped(self, *args, **kwargs):
- if request.is_json and len(request.get_data()) > max_size:
- raise InvalidRequest()
-
- return func(self, *args, **kwargs)
- return wrapped
- return wrapper
+ def wrapper(func):
+ @wraps(func)
+ def wrapped(self, *args, **kwargs):
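+            # Reject JSON bodies larger than max_size bytes before handling the request.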
+ if request.is_json and len(request.get_data()) > max_size:
+ raise InvalidRequest()
+
+ return func(self, *args, **kwargs)
+
+ return wrapped
+
+ return wrapper
def validate_json_request(schema_name, optional=False):
- def wrapper(func):
- @add_method_metadata('request_schema', schema_name)
- @wraps(func)
- def wrapped(self, *args, **kwargs):
- schema = self.schemas[schema_name]
- try:
- json_data = request.get_json()
- if json_data is None:
- if not optional:
- raise InvalidRequest('Missing JSON body')
- else:
- validate(json_data, schema)
- return func(self, *args, **kwargs)
- except ValidationError as ex:
- raise InvalidRequest(str(ex))
- return wrapped
- return wrapper
+ def wrapper(func):
+ @add_method_metadata("request_schema", schema_name)
+ @wraps(func)
+ def wrapped(self, *args, **kwargs):
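+            # Validate the request body against the resource's named JSON schema,
+            # surfacing any validation failure as an InvalidRequest error.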
+ schema = self.schemas[schema_name]
+ try:
+ json_data = request.get_json()
+ if json_data is None:
+ if not optional:
+ raise InvalidRequest("Missing JSON body")
+ else:
+ validate(json_data, schema)
+ return func(self, *args, **kwargs)
+ except ValidationError as ex:
+ raise InvalidRequest(str(ex))
+
+ return wrapped
+
+ return wrapper
def request_error(exception=None, **kwargs):
- data = kwargs.copy()
- message = 'Request error.'
- if exception:
- message = str(exception)
+ data = kwargs.copy()
+ message = "Request error."
+ if exception:
+ message = str(exception)
- message = data.pop('message', message)
- raise InvalidRequest(message, data)
+ message = data.pop("message", message)
+ raise InvalidRequest(message, data)
def log_action(kind, user_or_orgname, metadata=None, repo=None, repo_name=None):
- if not metadata:
- metadata = {}
+ if not metadata:
+ metadata = {}
- oauth_token = get_validated_oauth_token()
- if oauth_token:
- metadata['oauth_token_id'] = oauth_token.id
- metadata['oauth_token_application_id'] = oauth_token.application.client_id
- metadata['oauth_token_application'] = oauth_token.application.name
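+    # If the request was made with an OAuth token, record the token and its
+    # application in the log entry's metadata.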
+ oauth_token = get_validated_oauth_token()
+ if oauth_token:
+ metadata["oauth_token_id"] = oauth_token.id
+ metadata["oauth_token_application_id"] = oauth_token.application.client_id
+ metadata["oauth_token_application"] = oauth_token.application.name
- performer = get_authenticated_user()
+ performer = get_authenticated_user()
- if repo_name is not None:
- repo = data_model.repository.get_repository(user_or_orgname, repo_name)
+ if repo_name is not None:
+ repo = data_model.repository.get_repository(user_or_orgname, repo_name)
- logs_model.log_action(kind, user_or_orgname,
- repository=repo,
- performer=performer,
- ip=get_request_ip(),
- metadata=metadata)
+ logs_model.log_action(
+ kind,
+ user_or_orgname,
+ repository=repo,
+ performer=performer,
+ ip=get_request_ip(),
+ metadata=metadata,
+ )
def define_json_response(schema_name):
- def wrapper(func):
- @add_method_metadata('response_schema', schema_name)
- @wraps(func)
- def wrapped(self, *args, **kwargs):
- schema = self.schemas[schema_name]
- resp = func(self, *args, **kwargs)
+ def wrapper(func):
+ @add_method_metadata("response_schema", schema_name)
+ @wraps(func)
+ def wrapped(self, *args, **kwargs):
+ schema = self.schemas[schema_name]
+ resp = func(self, *args, **kwargs)
- if app.config['TESTING']:
- try:
- validate(resp, schema)
- except ValidationError as ex:
- raise InvalidResponse(str(ex))
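+            # Response schemas are only enforced when running under tests.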
+ if app.config["TESTING"]:
+ try:
+ validate(resp, schema)
+ except ValidationError as ex:
+ raise InvalidResponse(str(ex))
- return resp
- return wrapped
- return wrapper
+ return resp
+
+ return wrapped
+
+ return wrapper
import endpoints.api.appspecifictokens
diff --git a/endpoints/api/__init__models_interface.py b/endpoints/api/__init__models_interface.py
index 974d9e0e1..d8575c283 100644
--- a/endpoints/api/__init__models_interface.py
+++ b/endpoints/api/__init__models_interface.py
@@ -2,16 +2,16 @@ from abc import ABCMeta, abstractmethod
from six import add_metaclass
-
+
@add_metaclass(ABCMeta)
class InitDataInterface(object):
- """
+ """
Interface that represents all data store interactions required by __init__.
"""
- @abstractmethod
- def is_app_repository(self, namespace_name, repository_name):
- """
+ @abstractmethod
+ def is_app_repository(self, namespace_name, repository_name):
+ """
Args:
namespace_name: namespace or user
@@ -20,11 +20,11 @@ class InitDataInterface(object):
Returns:
Boolean
"""
- pass
-
- @abstractmethod
- def repository_is_public(self, namespace_name, repository_name):
- """
+ pass
+
+ @abstractmethod
+ def repository_is_public(self, namespace_name, repository_name):
+ """
Args:
namespace_name: namespace or user
@@ -33,11 +33,13 @@ class InitDataInterface(object):
Returns:
Boolean
"""
- pass
-
- @abstractmethod
- def log_action(self, kind, namespace_name, repository_name, performer, ip, metadata):
- """
+ pass
+
+ @abstractmethod
+ def log_action(
+ self, kind, namespace_name, repository_name, performer, ip, metadata
+ ):
+ """
Args:
kind: type of log
@@ -50,5 +52,4 @@ class InitDataInterface(object):
Returns:
None
"""
- pass
-
+ pass
diff --git a/endpoints/api/__init__models_pre_oci.py b/endpoints/api/__init__models_pre_oci.py
index f14e7267c..146d6f4eb 100644
--- a/endpoints/api/__init__models_pre_oci.py
+++ b/endpoints/api/__init__models_pre_oci.py
@@ -3,17 +3,31 @@ from __init__models_interface import InitDataInterface
from data import model
from data.logs_model import logs_model
+
class PreOCIModel(InitDataInterface):
- def is_app_repository(self, namespace_name, repository_name):
- return model.repository.get_repository(namespace_name, repository_name,
- kind_filter='application') is not None
+ def is_app_repository(self, namespace_name, repository_name):
+ return (
+ model.repository.get_repository(
+ namespace_name, repository_name, kind_filter="application"
+ )
+ is not None
+ )
- def repository_is_public(self, namespace_name, repository_name):
- return model.repository.repository_is_public(namespace_name, repository_name)
+ def repository_is_public(self, namespace_name, repository_name):
+ return model.repository.repository_is_public(namespace_name, repository_name)
+
+ def log_action(
+ self, kind, namespace_name, repository_name, performer, ip, metadata
+ ):
+ repository = model.repository.get_repository(namespace_name, repository_name)
+ logs_model.log_action(
+ kind,
+ namespace_name,
+ performer=performer,
+ ip=ip,
+ metadata=metadata,
+ repository=repository,
+ )
- def log_action(self, kind, namespace_name, repository_name, performer, ip, metadata):
- repository = model.repository.get_repository(namespace_name, repository_name)
- logs_model.log_action(kind, namespace_name, performer=performer, ip=ip, metadata=metadata,
- repository=repository)
pre_oci_model = PreOCIModel()
diff --git a/endpoints/api/appspecifictokens.py b/endpoints/api/appspecifictokens.py
index 1e886c385..a5c5d4fd3 100644
--- a/endpoints/api/appspecifictokens.py
+++ b/endpoints/api/appspecifictokens.py
@@ -11,123 +11,143 @@ import features
from app import app
from auth.auth_context import get_authenticated_user
from data import model
-from endpoints.api import (ApiResource, nickname, resource, validate_json_request,
- log_action, require_user_admin, require_fresh_login,
- path_param, NotFound, format_date, show_if, query_param, parse_args,
- truthy_bool)
+from endpoints.api import (
+ ApiResource,
+ nickname,
+ resource,
+ validate_json_request,
+ log_action,
+ require_user_admin,
+ require_fresh_login,
+ path_param,
+ NotFound,
+ format_date,
+ show_if,
+ query_param,
+ parse_args,
+ truthy_bool,
+)
from util.timedeltastring import convert_to_timedelta
logger = logging.getLogger(__name__)
def token_view(token, include_code=False):
- data = {
- 'uuid': token.uuid,
- 'title': token.title,
- 'last_accessed': format_date(token.last_accessed),
- 'created': format_date(token.created),
- 'expiration': format_date(token.expiration),
- }
+ data = {
+ "uuid": token.uuid,
+ "title": token.title,
+ "last_accessed": format_date(token.last_accessed),
+ "created": format_date(token.created),
+ "expiration": format_date(token.expiration),
+ }
- if include_code:
- data.update({
- 'token_code': model.appspecifictoken.get_full_token_string(token),
- })
+ if include_code:
+ data.update({"token_code": model.appspecifictoken.get_full_token_string(token)})
- return data
+ return data
# The default window to use when looking up tokens that will be expiring.
-_DEFAULT_TOKEN_EXPIRATION_WINDOW = '4w'
+_DEFAULT_TOKEN_EXPIRATION_WINDOW = "4w"
-@resource('/v1/user/apptoken')
+@resource("/v1/user/apptoken")
@show_if(features.APP_SPECIFIC_TOKENS)
class AppTokens(ApiResource):
- """ Lists all app specific tokens for a user """
- schemas = {
- 'NewToken': {
- 'type': 'object',
- 'required': [
- 'title',
- ],
- 'properties': {
- 'title': {
- 'type': 'string',
- 'description': 'The user-defined title for the token',
- },
- }
- },
- }
+ """ Lists all app specific tokens for a user """
- @require_user_admin
- @nickname('listAppTokens')
- @parse_args()
- @query_param('expiring', 'If true, only returns those tokens expiring soon', type=truthy_bool)
- def get(self, parsed_args):
- """ Lists the app specific tokens for the user. """
- expiring = parsed_args['expiring']
- if expiring:
- expiration = app.config.get('APP_SPECIFIC_TOKEN_EXPIRATION')
- token_expiration = convert_to_timedelta(expiration or _DEFAULT_TOKEN_EXPIRATION_WINDOW)
- seconds = math.ceil(token_expiration.total_seconds() * 0.1) or 1
- soon = timedelta(seconds=seconds)
- tokens = model.appspecifictoken.get_expiring_tokens(get_authenticated_user(), soon)
- else:
- tokens = model.appspecifictoken.list_tokens(get_authenticated_user())
-
- return {
- 'tokens': [token_view(token, include_code=False) for token in tokens],
- 'only_expiring': expiring,
+ schemas = {
+ "NewToken": {
+ "type": "object",
+ "required": ["title"],
+ "properties": {
+ "title": {
+ "type": "string",
+ "description": "The user-defined title for the token",
+ }
+ },
+ }
}
- @require_user_admin
- @require_fresh_login
- @nickname('createAppToken')
- @validate_json_request('NewToken')
- def post(self):
- """ Create a new app specific token for user. """
- title = request.get_json()['title']
- token = model.appspecifictoken.create_token(get_authenticated_user(), title)
+ @require_user_admin
+ @nickname("listAppTokens")
+ @parse_args()
+ @query_param(
+ "expiring", "If true, only returns those tokens expiring soon", type=truthy_bool
+ )
+ def get(self, parsed_args):
+ """ Lists the app specific tokens for the user. """
+ expiring = parsed_args["expiring"]
+ if expiring:
+ expiration = app.config.get("APP_SPECIFIC_TOKEN_EXPIRATION")
+ token_expiration = convert_to_timedelta(
+ expiration or _DEFAULT_TOKEN_EXPIRATION_WINDOW
+ )
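+            # Treat tokens within roughly 10% of the configured expiration window
+            # as "expiring soon".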
+ seconds = math.ceil(token_expiration.total_seconds() * 0.1) or 1
+ soon = timedelta(seconds=seconds)
+ tokens = model.appspecifictoken.get_expiring_tokens(
+ get_authenticated_user(), soon
+ )
+ else:
+ tokens = model.appspecifictoken.list_tokens(get_authenticated_user())
- log_action('create_app_specific_token', get_authenticated_user().username,
- {'app_specific_token_title': token.title,
- 'app_specific_token': token.uuid})
+ return {
+ "tokens": [token_view(token, include_code=False) for token in tokens],
+ "only_expiring": expiring,
+ }
- return {
- 'token': token_view(token, include_code=True),
- }
+ @require_user_admin
+ @require_fresh_login
+ @nickname("createAppToken")
+ @validate_json_request("NewToken")
+ def post(self):
+ """ Create a new app specific token for user. """
+ title = request.get_json()["title"]
+ token = model.appspecifictoken.create_token(get_authenticated_user(), title)
+
+ log_action(
+ "create_app_specific_token",
+ get_authenticated_user().username,
+ {"app_specific_token_title": token.title, "app_specific_token": token.uuid},
+ )
+
+ return {"token": token_view(token, include_code=True)}
-@resource('/v1/user/apptoken/<token_uuid>')
+@resource("/v1/user/apptoken/<token_uuid>")
@show_if(features.APP_SPECIFIC_TOKENS)
-@path_param('token_uuid', 'The uuid of the app specific token')
+@path_param("token_uuid", "The uuid of the app specific token")
class AppToken(ApiResource):
- """ Provides operations on an app specific token """
- @require_user_admin
- @require_fresh_login
- @nickname('getAppToken')
- def get(self, token_uuid):
- """ Returns a specific app token for the user. """
- token = model.appspecifictoken.get_token_by_uuid(token_uuid, owner=get_authenticated_user())
- if token is None:
- raise NotFound()
+ """ Provides operations on an app specific token """
- return {
- 'token': token_view(token, include_code=True),
- }
+ @require_user_admin
+ @require_fresh_login
+ @nickname("getAppToken")
+ def get(self, token_uuid):
+ """ Returns a specific app token for the user. """
+ token = model.appspecifictoken.get_token_by_uuid(
+ token_uuid, owner=get_authenticated_user()
+ )
+ if token is None:
+ raise NotFound()
- @require_user_admin
- @require_fresh_login
- @nickname('revokeAppToken')
- def delete(self, token_uuid):
- """ Revokes a specific app token for the user. """
- token = model.appspecifictoken.revoke_token_by_uuid(token_uuid, owner=get_authenticated_user())
- if token is None:
- raise NotFound()
+ return {"token": token_view(token, include_code=True)}
- log_action('revoke_app_specific_token', get_authenticated_user().username,
- {'app_specific_token_title': token.title,
- 'app_specific_token': token.uuid})
+ @require_user_admin
+ @require_fresh_login
+ @nickname("revokeAppToken")
+ def delete(self, token_uuid):
+ """ Revokes a specific app token for the user. """
+ token = model.appspecifictoken.revoke_token_by_uuid(
+ token_uuid, owner=get_authenticated_user()
+ )
+ if token is None:
+ raise NotFound()
- return '', 204
+ log_action(
+ "revoke_app_specific_token",
+ get_authenticated_user().username,
+ {"app_specific_token_title": token.title, "app_specific_token": token.uuid},
+ )
+
+ return "", 204
diff --git a/endpoints/api/billing.py b/endpoints/api/billing.py
index db7158d12..eda3a301b 100644
--- a/endpoints/api/billing.py
+++ b/endpoints/api/billing.py
@@ -4,9 +4,20 @@ import stripe
from flask import request
from app import billing
-from endpoints.api import (resource, nickname, ApiResource, validate_json_request, log_action,
- related_user_resource, internal_only, require_user_admin, show_if,
- path_param, require_scope, abort)
+from endpoints.api import (
+ resource,
+ nickname,
+ ApiResource,
+ validate_json_request,
+ log_action,
+ related_user_resource,
+ internal_only,
+ require_user_admin,
+ show_if,
+ path_param,
+ require_scope,
+ abort,
+)
from endpoints.exception import Unauthorized, NotFound
from endpoints.api.subscribe import subscribe, subscription_view
from auth.permissions import AdministerOrganizationPermission
@@ -19,589 +30,587 @@ import features
import uuid
import json
+
def get_namespace_plan(namespace):
- """ Returns the plan of the given namespace. """
- namespace_user = model.user.get_namespace_user(namespace)
- if namespace_user is None:
- return None
+ """ Returns the plan of the given namespace. """
+ namespace_user = model.user.get_namespace_user(namespace)
+ if namespace_user is None:
+ return None
- if not namespace_user.stripe_id:
- return None
+ if not namespace_user.stripe_id:
+ return None
- # Ask Stripe for the subscribed plan.
- # TODO: Can we cache this or make it faster somehow?
- try:
- cus = billing.Customer.retrieve(namespace_user.stripe_id)
- except stripe.error.APIConnectionError:
- abort(503, message='Cannot contact Stripe')
+ # Ask Stripe for the subscribed plan.
+ # TODO: Can we cache this or make it faster somehow?
+ try:
+ cus = billing.Customer.retrieve(namespace_user.stripe_id)
+ except stripe.error.APIConnectionError:
+ abort(503, message="Cannot contact Stripe")
- if not cus.subscription:
- return None
+ if not cus.subscription:
+ return None
- return get_plan(cus.subscription.plan.id)
+ return get_plan(cus.subscription.plan.id)
def lookup_allowed_private_repos(namespace):
- """ Returns false if the given namespace has used its allotment of private repositories. """
- current_plan = get_namespace_plan(namespace)
- if current_plan is None:
- return False
+ """ Returns false if the given namespace has used its allotment of private repositories. """
+ current_plan = get_namespace_plan(namespace)
+ if current_plan is None:
+ return False
- # Find the number of private repositories used by the namespace and compare it to the
- # plan subscribed.
- private_repos = model.user.get_private_repo_count(namespace)
+ # Find the number of private repositories used by the namespace and compare it to the
+ # plan subscribed.
+ private_repos = model.user.get_private_repo_count(namespace)
- return private_repos < current_plan['privateRepos']
+ return private_repos < current_plan["privateRepos"]
def carderror_response(e):
- return {'carderror': str(e)}, 402
+ return {"carderror": str(e)}, 402
def get_card(user):
- card_info = {
- 'is_valid': False
- }
+ card_info = {"is_valid": False}
- if user.stripe_id:
- try:
- cus = billing.Customer.retrieve(user.stripe_id)
- except stripe.error.APIConnectionError as e:
- abort(503, message='Cannot contact Stripe')
+ if user.stripe_id:
+ try:
+ cus = billing.Customer.retrieve(user.stripe_id)
+ except stripe.error.APIConnectionError as e:
+ abort(503, message="Cannot contact Stripe")
- if cus and cus.default_card:
- # Find the default card.
- default_card = None
- for card in cus.cards.data:
- if card.id == cus.default_card:
- default_card = card
- break
+ if cus and cus.default_card:
+ # Find the default card.
+ default_card = None
+ for card in cus.cards.data:
+ if card.id == cus.default_card:
+ default_card = card
+ break
- if default_card:
- card_info = {
- 'owner': default_card.name,
- 'type': default_card.type,
- 'last4': default_card.last4,
- 'exp_month': default_card.exp_month,
- 'exp_year': default_card.exp_year
- }
+ if default_card:
+ card_info = {
+ "owner": default_card.name,
+ "type": default_card.type,
+ "last4": default_card.last4,
+ "exp_month": default_card.exp_month,
+ "exp_year": default_card.exp_year,
+ }
- return {'card': card_info}
+ return {"card": card_info}
def set_card(user, token):
- if user.stripe_id:
- try:
- cus = billing.Customer.retrieve(user.stripe_id)
- except stripe.error.APIConnectionError as e:
- abort(503, message='Cannot contact Stripe')
+ if user.stripe_id:
+ try:
+ cus = billing.Customer.retrieve(user.stripe_id)
+ except stripe.error.APIConnectionError as e:
+ abort(503, message="Cannot contact Stripe")
- if cus:
- try:
- cus.card = token
- cus.save()
- except stripe.error.CardError as exc:
- return carderror_response(exc)
- except stripe.error.InvalidRequestError as exc:
- return carderror_response(exc)
- except stripe.error.APIConnectionError as e:
- return carderror_response(e)
+ if cus:
+ try:
+ cus.card = token
+ cus.save()
+ except stripe.error.CardError as exc:
+ return carderror_response(exc)
+ except stripe.error.InvalidRequestError as exc:
+ return carderror_response(exc)
+ except stripe.error.APIConnectionError as e:
+ return carderror_response(e)
- return get_card(user)
+ return get_card(user)
def get_invoices(customer_id):
- def invoice_view(i):
- return {
- 'id': i.id,
- 'date': i.date,
- 'period_start': i.period_start,
- 'period_end': i.period_end,
- 'paid': i.paid,
- 'amount_due': i.amount_due,
- 'next_payment_attempt': i.next_payment_attempt,
- 'attempted': i.attempted,
- 'closed': i.closed,
- 'total': i.total,
- 'plan': i.lines.data[0].plan.id if i.lines.data[0].plan else None
- }
+ def invoice_view(i):
+ return {
+ "id": i.id,
+ "date": i.date,
+ "period_start": i.period_start,
+ "period_end": i.period_end,
+ "paid": i.paid,
+ "amount_due": i.amount_due,
+ "next_payment_attempt": i.next_payment_attempt,
+ "attempted": i.attempted,
+ "closed": i.closed,
+ "total": i.total,
+ "plan": i.lines.data[0].plan.id if i.lines.data[0].plan else None,
+ }
- try:
- invoices = billing.Invoice.list(customer=customer_id, count=12)
- except stripe.error.APIConnectionError as e:
- abort(503, message='Cannot contact Stripe')
+ try:
+ invoices = billing.Invoice.list(customer=customer_id, count=12)
+ except stripe.error.APIConnectionError as e:
+ abort(503, message="Cannot contact Stripe")
- return {
- 'invoices': [invoice_view(i) for i in invoices.data]
- }
+ return {"invoices": [invoice_view(i) for i in invoices.data]}
def get_invoice_fields(user):
- try:
- cus = billing.Customer.retrieve(user.stripe_id)
- except stripe.error.APIConnectionError:
- abort(503, message='Cannot contact Stripe')
+ try:
+ cus = billing.Customer.retrieve(user.stripe_id)
+ except stripe.error.APIConnectionError:
+ abort(503, message="Cannot contact Stripe")
- if not 'metadata' in cus:
- cus.metadata = {}
+ if not "metadata" in cus:
+ cus.metadata = {}
- return json.loads(cus.metadata.get('invoice_fields') or '[]'), cus
+ return json.loads(cus.metadata.get("invoice_fields") or "[]"), cus
def create_billing_invoice_field(user, title, value):
- new_field = {
- 'uuid': str(uuid.uuid4()).split('-')[0],
- 'title': title,
- 'value': value
- }
+ new_field = {
+ "uuid": str(uuid.uuid4()).split("-")[0],
+ "title": title,
+ "value": value,
+ }
- invoice_fields, cus = get_invoice_fields(user)
- invoice_fields.append(new_field)
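+    # Custom invoice fields are persisted as a JSON list under the
+    # "invoice_fields" key of the Stripe customer's metadata.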
+ invoice_fields, cus = get_invoice_fields(user)
+ invoice_fields.append(new_field)
- if not 'metadata' in cus:
- cus.metadata = {}
+ if not "metadata" in cus:
+ cus.metadata = {}
- cus.metadata['invoice_fields'] = json.dumps(invoice_fields)
- cus.save()
- return new_field
+ cus.metadata["invoice_fields"] = json.dumps(invoice_fields)
+ cus.save()
+ return new_field
def delete_billing_invoice_field(user, field_uuid):
- invoice_fields, cus = get_invoice_fields(user)
- invoice_fields = [field for field in invoice_fields if not field['uuid'] == field_uuid]
+ invoice_fields, cus = get_invoice_fields(user)
+ invoice_fields = [
+ field for field in invoice_fields if not field["uuid"] == field_uuid
+ ]
- if not 'metadata' in cus:
- cus.metadata = {}
+ if not "metadata" in cus:
+ cus.metadata = {}
- cus.metadata['invoice_fields'] = json.dumps(invoice_fields)
- cus.save()
- return True
+ cus.metadata["invoice_fields"] = json.dumps(invoice_fields)
+ cus.save()
+ return True
-@resource('/v1/plans/')
+@resource("/v1/plans/")
@show_if(features.BILLING)
class ListPlans(ApiResource):
- """ Resource for listing the available plans. """
- @nickname('listPlans')
- def get(self):
- """ List the avaialble plans. """
- return {
- 'plans': PLANS,
- }
+ """ Resource for listing the available plans. """
+
+ @nickname("listPlans")
+ def get(self):
+ """ List the avaialble plans. """
+ return {"plans": PLANS}
-@resource('/v1/user/card')
+@resource("/v1/user/card")
@internal_only
@show_if(features.BILLING)
class UserCard(ApiResource):
- """ Resource for managing a user's credit card. """
- schemas = {
- 'UserCard': {
- 'id': 'UserCard',
- 'type': 'object',
- 'description': 'Description of a user card',
- 'required': [
- 'token',
- ],
- 'properties': {
- 'token': {
- 'type': 'string',
- 'description': 'Stripe token that is generated by stripe checkout.js',
- },
- },
- },
- }
+ """ Resource for managing a user's credit card. """
- @require_user_admin
- @nickname('getUserCard')
- def get(self):
- """ Get the user's credit card. """
- user = get_authenticated_user()
- return get_card(user)
+ schemas = {
+ "UserCard": {
+ "id": "UserCard",
+ "type": "object",
+ "description": "Description of a user card",
+ "required": ["token"],
+ "properties": {
+ "token": {
+ "type": "string",
+ "description": "Stripe token that is generated by stripe checkout.js",
+ }
+ },
+ }
+ }
- @require_user_admin
- @nickname('setUserCard')
- @validate_json_request('UserCard')
- def post(self):
- """ Update the user's credit card. """
- user = get_authenticated_user()
- token = request.get_json()['token']
- response = set_card(user, token)
- log_action('account_change_cc', user.username)
- return response
+ @require_user_admin
+ @nickname("getUserCard")
+ def get(self):
+ """ Get the user's credit card. """
+ user = get_authenticated_user()
+ return get_card(user)
+
+ @require_user_admin
+ @nickname("setUserCard")
+ @validate_json_request("UserCard")
+ def post(self):
+ """ Update the user's credit card. """
+ user = get_authenticated_user()
+ token = request.get_json()["token"]
+ response = set_card(user, token)
+ log_action("account_change_cc", user.username)
+ return response
-@resource('/v1/organization/<orgname>/card')
-@path_param('orgname', 'The name of the organization')
+@resource("/v1/organization//card")
+@path_param("orgname", "The name of the organization")
@internal_only
@related_user_resource(UserCard)
@show_if(features.BILLING)
class OrganizationCard(ApiResource):
- """ Resource for managing an organization's credit card. """
- schemas = {
- 'OrgCard': {
- 'id': 'OrgCard',
- 'type': 'object',
- 'description': 'Description of a user card',
- 'required': [
- 'token',
- ],
- 'properties': {
- 'token': {
- 'type': 'string',
- 'description': 'Stripe token that is generated by stripe checkout.js',
- },
- },
- },
- }
+ """ Resource for managing an organization's credit card. """
- @require_scope(scopes.ORG_ADMIN)
- @nickname('getOrgCard')
- def get(self, orgname):
- """ Get the organization's credit card. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- organization = model.organization.get_organization(orgname)
- return get_card(organization)
+ schemas = {
+ "OrgCard": {
+ "id": "OrgCard",
+ "type": "object",
+ "description": "Description of a user card",
+ "required": ["token"],
+ "properties": {
+ "token": {
+ "type": "string",
+ "description": "Stripe token that is generated by stripe checkout.js",
+ }
+ },
+ }
+ }
- raise Unauthorized()
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("getOrgCard")
+ def get(self, orgname):
+ """ Get the organization's credit card. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ organization = model.organization.get_organization(orgname)
+ return get_card(organization)
- @nickname('setOrgCard')
- @validate_json_request('OrgCard')
- def post(self, orgname):
- """ Update the orgnaization's credit card. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- organization = model.organization.get_organization(orgname)
- token = request.get_json()['token']
- response = set_card(organization, token)
- log_action('account_change_cc', orgname)
- return response
+ raise Unauthorized()
- raise Unauthorized()
+ @nickname("setOrgCard")
+ @validate_json_request("OrgCard")
+ def post(self, orgname):
+ """ Update the orgnaization's credit card. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ organization = model.organization.get_organization(orgname)
+ token = request.get_json()["token"]
+ response = set_card(organization, token)
+ log_action("account_change_cc", orgname)
+ return response
+
+ raise Unauthorized()
-@resource('/v1/user/plan')
+@resource("/v1/user/plan")
@internal_only
@show_if(features.BILLING)
class UserPlan(ApiResource):
- """ Resource for managing a user's subscription. """
- schemas = {
- 'UserSubscription': {
- 'id': 'UserSubscription',
- 'type': 'object',
- 'description': 'Description of a user card',
- 'required': [
- 'plan',
- ],
- 'properties': {
- 'token': {
- 'type': 'string',
- 'description': 'Stripe token that is generated by stripe checkout.js',
- },
- 'plan': {
- 'type': 'string',
- 'description': 'Plan name to which the user wants to subscribe',
- },
- },
- },
- }
+ """ Resource for managing a user's subscription. """
- @require_user_admin
- @nickname('updateUserSubscription')
- @validate_json_request('UserSubscription')
- def put(self):
- """ Create or update the user's subscription. """
- request_data = request.get_json()
- plan = request_data['plan']
- token = request_data['token'] if 'token' in request_data else None
- user = get_authenticated_user()
- return subscribe(user, plan, token, False) # Business features not required
-
- @require_user_admin
- @nickname('getUserSubscription')
- def get(self):
- """ Fetch any existing subscription for the user. """
- cus = None
- user = get_authenticated_user()
- private_repos = model.user.get_private_repo_count(user.username)
-
- if user.stripe_id:
- try:
- cus = billing.Customer.retrieve(user.stripe_id)
- except stripe.error.APIConnectionError as e:
- abort(503, message='Cannot contact Stripe')
-
- if cus.subscription:
- return subscription_view(cus.subscription, private_repos)
-
- return {
- 'hasSubscription': False,
- 'isExistingCustomer': cus is not None,
- 'plan': 'free',
- 'usedPrivateRepos': private_repos,
+ schemas = {
+ "UserSubscription": {
+ "id": "UserSubscription",
+ "type": "object",
+ "description": "Description of a user card",
+ "required": ["plan"],
+ "properties": {
+ "token": {
+ "type": "string",
+ "description": "Stripe token that is generated by stripe checkout.js",
+ },
+ "plan": {
+ "type": "string",
+ "description": "Plan name to which the user wants to subscribe",
+ },
+ },
+ }
}
+ @require_user_admin
+ @nickname("updateUserSubscription")
+ @validate_json_request("UserSubscription")
+ def put(self):
+ """ Create or update the user's subscription. """
+ request_data = request.get_json()
+ plan = request_data["plan"]
+ token = request_data["token"] if "token" in request_data else None
+ user = get_authenticated_user()
+ return subscribe(user, plan, token, False) # Business features not required
-@resource('/v1/organization/<orgname>/plan')
-@path_param('orgname', 'The name of the organization')
+ @require_user_admin
+ @nickname("getUserSubscription")
+ def get(self):
+ """ Fetch any existing subscription for the user. """
+ cus = None
+ user = get_authenticated_user()
+ private_repos = model.user.get_private_repo_count(user.username)
+
+ if user.stripe_id:
+ try:
+ cus = billing.Customer.retrieve(user.stripe_id)
+ except stripe.error.APIConnectionError as e:
+ abort(503, message="Cannot contact Stripe")
+
+ if cus.subscription:
+ return subscription_view(cus.subscription, private_repos)
+
+ return {
+ "hasSubscription": False,
+ "isExistingCustomer": cus is not None,
+ "plan": "free",
+ "usedPrivateRepos": private_repos,
+ }
+
+
+@resource("/v1/organization//plan")
+@path_param("orgname", "The name of the organization")
@internal_only
@related_user_resource(UserPlan)
@show_if(features.BILLING)
class OrganizationPlan(ApiResource):
- """ Resource for managing a org's subscription. """
- schemas = {
- 'OrgSubscription': {
- 'id': 'OrgSubscription',
- 'type': 'object',
- 'description': 'Description of a user card',
- 'required': [
- 'plan',
- ],
- 'properties': {
- 'token': {
- 'type': 'string',
- 'description': 'Stripe token that is generated by stripe checkout.js',
- },
- 'plan': {
- 'type': 'string',
- 'description': 'Plan name to which the user wants to subscribe',
- },
- },
- },
- }
+ """ Resource for managing a org's subscription. """
- @require_scope(scopes.ORG_ADMIN)
- @nickname('updateOrgSubscription')
- @validate_json_request('OrgSubscription')
- def put(self, orgname):
- """ Create or update the org's subscription. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- request_data = request.get_json()
- plan = request_data['plan']
- token = request_data['token'] if 'token' in request_data else None
- organization = model.organization.get_organization(orgname)
- return subscribe(organization, plan, token, True) # Business plan required
+ schemas = {
+ "OrgSubscription": {
+ "id": "OrgSubscription",
+ "type": "object",
+ "description": "Description of a user card",
+ "required": ["plan"],
+ "properties": {
+ "token": {
+ "type": "string",
+ "description": "Stripe token that is generated by stripe checkout.js",
+ },
+ "plan": {
+ "type": "string",
+ "description": "Plan name to which the user wants to subscribe",
+ },
+ },
+ }
+ }
- raise Unauthorized()
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("updateOrgSubscription")
+ @validate_json_request("OrgSubscription")
+ def put(self, orgname):
+ """ Create or update the org's subscription. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ request_data = request.get_json()
+ plan = request_data["plan"]
+ token = request_data["token"] if "token" in request_data else None
+ organization = model.organization.get_organization(orgname)
+ return subscribe(organization, plan, token, True) # Business plan required
- @require_scope(scopes.ORG_ADMIN)
- @nickname('getOrgSubscription')
- def get(self, orgname):
- """ Fetch any existing subscription for the org. """
- cus = None
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- private_repos = model.user.get_private_repo_count(orgname)
- organization = model.organization.get_organization(orgname)
- if organization.stripe_id:
- try:
- cus = billing.Customer.retrieve(organization.stripe_id)
- except stripe.error.APIConnectionError as e:
- abort(503, message='Cannot contact Stripe')
+ raise Unauthorized()
- if cus.subscription:
- return subscription_view(cus.subscription, private_repos)
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("getOrgSubscription")
+ def get(self, orgname):
+ """ Fetch any existing subscription for the org. """
+ cus = None
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ private_repos = model.user.get_private_repo_count(orgname)
+ organization = model.organization.get_organization(orgname)
+ if organization.stripe_id:
+ try:
+ cus = billing.Customer.retrieve(organization.stripe_id)
+ except stripe.error.APIConnectionError as e:
+ abort(503, message="Cannot contact Stripe")
- return {
- 'hasSubscription': False,
- 'isExistingCustomer': cus is not None,
- 'plan': 'free',
- 'usedPrivateRepos': private_repos,
- }
+ if cus.subscription:
+ return subscription_view(cus.subscription, private_repos)
- raise Unauthorized()
+ return {
+ "hasSubscription": False,
+ "isExistingCustomer": cus is not None,
+ "plan": "free",
+ "usedPrivateRepos": private_repos,
+ }
+
+ raise Unauthorized()
-@resource('/v1/user/invoices')
+@resource("/v1/user/invoices")
@internal_only
@show_if(features.BILLING)
class UserInvoiceList(ApiResource):
- """ Resource for listing a user's invoices. """
- @require_user_admin
- @nickname('listUserInvoices')
- def get(self):
- """ List the invoices for the current user. """
- user = get_authenticated_user()
- if not user.stripe_id:
- raise NotFound()
+ """ Resource for listing a user's invoices. """
- return get_invoices(user.stripe_id)
+ @require_user_admin
+ @nickname("listUserInvoices")
+ def get(self):
+ """ List the invoices for the current user. """
+ user = get_authenticated_user()
+ if not user.stripe_id:
+ raise NotFound()
+
+ return get_invoices(user.stripe_id)
-@resource('/v1/organization/<orgname>/invoices')
-@path_param('orgname', 'The name of the organization')
+@resource("/v1/organization//invoices")
+@path_param("orgname", "The name of the organization")
@related_user_resource(UserInvoiceList)
@show_if(features.BILLING)
class OrganizationInvoiceList(ApiResource):
- """ Resource for listing an orgnaization's invoices. """
- @require_scope(scopes.ORG_ADMIN)
- @nickname('listOrgInvoices')
- def get(self, orgname):
- """ List the invoices for the specified orgnaization. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- organization = model.organization.get_organization(orgname)
- if not organization.stripe_id:
- raise NotFound()
+ """ Resource for listing an orgnaization's invoices. """
- return get_invoices(organization.stripe_id)
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("listOrgInvoices")
+ def get(self, orgname):
+ """ List the invoices for the specified orgnaization. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ organization = model.organization.get_organization(orgname)
+ if not organization.stripe_id:
+ raise NotFound()
- raise Unauthorized()
+ return get_invoices(organization.stripe_id)
+
+ raise Unauthorized()
-@resource('/v1/user/invoice/fields')
+@resource("/v1/user/invoice/fields")
@internal_only
@show_if(features.BILLING)
class UserInvoiceFieldList(ApiResource):
- """ Resource for listing and creating a user's custom invoice fields. """
- schemas = {
- 'InvoiceField': {
- 'id': 'InvoiceField',
- 'type': 'object',
- 'description': 'Description of an invoice field',
- 'required': [
- 'title', 'value'
- ],
- 'properties': {
- 'title': {
- 'type': 'string',
- 'description': 'The title of the field being added',
- },
- 'value': {
- 'type': 'string',
- 'description': 'The value of the field being added',
- },
- },
- },
- }
+ """ Resource for listing and creating a user's custom invoice fields. """
- @require_user_admin
- @nickname('listUserInvoiceFields')
- def get(self):
- """ List the invoice fields for the current user. """
- user = get_authenticated_user()
- if not user.stripe_id:
- raise NotFound()
+ schemas = {
+ "InvoiceField": {
+ "id": "InvoiceField",
+ "type": "object",
+ "description": "Description of an invoice field",
+ "required": ["title", "value"],
+ "properties": {
+ "title": {
+ "type": "string",
+ "description": "The title of the field being added",
+ },
+ "value": {
+ "type": "string",
+ "description": "The value of the field being added",
+ },
+ },
+ }
+ }
- return {'fields': get_invoice_fields(user)[0]}
+ @require_user_admin
+ @nickname("listUserInvoiceFields")
+ def get(self):
+ """ List the invoice fields for the current user. """
+ user = get_authenticated_user()
+ if not user.stripe_id:
+ raise NotFound()
- @require_user_admin
- @nickname('createUserInvoiceField')
- @validate_json_request('InvoiceField')
- def post(self):
- """ Creates a new invoice field. """
- user = get_authenticated_user()
- if not user.stripe_id:
- raise NotFound()
+ return {"fields": get_invoice_fields(user)[0]}
- data = request.get_json()
- created_field = create_billing_invoice_field(user, data['title'], data['value'])
- return created_field
+ @require_user_admin
+ @nickname("createUserInvoiceField")
+ @validate_json_request("InvoiceField")
+ def post(self):
+ """ Creates a new invoice field. """
+ user = get_authenticated_user()
+ if not user.stripe_id:
+ raise NotFound()
+
+ data = request.get_json()
+ created_field = create_billing_invoice_field(user, data["title"], data["value"])
+ return created_field
-@resource('/v1/user/invoice/field/<field_uuid>')
+@resource("/v1/user/invoice/field/<field_uuid>")
@internal_only
@show_if(features.BILLING)
class UserInvoiceField(ApiResource):
- """ Resource for deleting a user's custom invoice fields. """
- @require_user_admin
- @nickname('deleteUserInvoiceField')
- def delete(self, field_uuid):
- """ Deletes the invoice field for the current user. """
- user = get_authenticated_user()
- if not user.stripe_id:
- raise NotFound()
+ """ Resource for deleting a user's custom invoice fields. """
- result = delete_billing_invoice_field(user, field_uuid)
- if not result:
- abort(404)
+ @require_user_admin
+ @nickname("deleteUserInvoiceField")
+ def delete(self, field_uuid):
+ """ Deletes the invoice field for the current user. """
+ user = get_authenticated_user()
+ if not user.stripe_id:
+ raise NotFound()
- return 'Okay', 201
+ result = delete_billing_invoice_field(user, field_uuid)
+ if not result:
+ abort(404)
+
+ return "Okay", 201
-@resource('/v1/organization/<orgname>/invoice/fields')
-@path_param('orgname', 'The name of the organization')
+@resource("/v1/organization//invoice/fields")
+@path_param("orgname", "The name of the organization")
@related_user_resource(UserInvoiceFieldList)
@internal_only
@show_if(features.BILLING)
class OrganizationInvoiceFieldList(ApiResource):
- """ Resource for listing and creating an organization's custom invoice fields. """
- schemas = {
- 'InvoiceField': {
- 'id': 'InvoiceField',
- 'type': 'object',
- 'description': 'Description of an invoice field',
- 'required': [
- 'title', 'value'
- ],
- 'properties': {
- 'title': {
- 'type': 'string',
- 'description': 'The title of the field being added',
- },
- 'value': {
- 'type': 'string',
- 'description': 'The value of the field being added',
- },
- },
- },
- }
+ """ Resource for listing and creating an organization's custom invoice fields. """
- @require_scope(scopes.ORG_ADMIN)
- @nickname('listOrgInvoiceFields')
- def get(self, orgname):
- """ List the invoice fields for the organization. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- organization = model.organization.get_organization(orgname)
- if not organization.stripe_id:
- raise NotFound()
+ schemas = {
+ "InvoiceField": {
+ "id": "InvoiceField",
+ "type": "object",
+ "description": "Description of an invoice field",
+ "required": ["title", "value"],
+ "properties": {
+ "title": {
+ "type": "string",
+ "description": "The title of the field being added",
+ },
+ "value": {
+ "type": "string",
+ "description": "The value of the field being added",
+ },
+ },
+ }
+ }
- return {'fields': get_invoice_fields(organization)[0]}
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("listOrgInvoiceFields")
+ def get(self, orgname):
+ """ List the invoice fields for the organization. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ organization = model.organization.get_organization(orgname)
+ if not organization.stripe_id:
+ raise NotFound()
- abort(403)
+ return {"fields": get_invoice_fields(organization)[0]}
- @require_scope(scopes.ORG_ADMIN)
- @nickname('createOrgInvoiceField')
- @validate_json_request('InvoiceField')
- def post(self, orgname):
- """ Creates a new invoice field. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- organization = model.organization.get_organization(orgname)
- if not organization.stripe_id:
- raise NotFound()
+ abort(403)
- data = request.get_json()
- created_field = create_billing_invoice_field(organization, data['title'], data['value'])
- return created_field
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("createOrgInvoiceField")
+ @validate_json_request("InvoiceField")
+ def post(self, orgname):
+ """ Creates a new invoice field. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ organization = model.organization.get_organization(orgname)
+ if not organization.stripe_id:
+ raise NotFound()
- abort(403)
+ data = request.get_json()
+ created_field = create_billing_invoice_field(
+ organization, data["title"], data["value"]
+ )
+ return created_field
+
+ abort(403)
-@resource('/v1/organization/<orgname>/invoice/field/<field_uuid>')
-@path_param('orgname', 'The name of the organization')
+@resource("/v1/organization//invoice/field/")
+@path_param("orgname", "The name of the organization")
@related_user_resource(UserInvoiceField)
@internal_only
@show_if(features.BILLING)
class OrganizationInvoiceField(ApiResource):
- """ Resource for deleting an organization's custom invoice fields. """
- @require_scope(scopes.ORG_ADMIN)
- @nickname('deleteOrgInvoiceField')
- def delete(self, orgname, field_uuid):
- """ Deletes the invoice field for the current user. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- organization = model.organization.get_organization(orgname)
- if not organization.stripe_id:
- raise NotFound()
+ """ Resource for deleting an organization's custom invoice fields. """
- result = delete_billing_invoice_field(organization, field_uuid)
- if not result:
- abort(404)
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("deleteOrgInvoiceField")
+ def delete(self, orgname, field_uuid):
+ """ Deletes the invoice field for the current user. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ organization = model.organization.get_organization(orgname)
+ if not organization.stripe_id:
+ raise NotFound()
- return 'Okay', 201
+ result = delete_billing_invoice_field(organization, field_uuid)
+ if not result:
+ abort(404)
- abort(403)
+ return "Okay", 201
+
+ abort(403)
diff --git a/endpoints/api/build.py b/endpoints/api/build.py
index d7fb55ae1..d3f88f119 100644
--- a/endpoints/api/build.py
+++ b/endpoints/api/build.py
@@ -11,20 +11,42 @@ from urlparse import urlparse
import features
from app import userfiles as user_files, build_logs, log_archive, dockerfile_build_queue
-from auth.permissions import (ReadRepositoryPermission, ModifyRepositoryPermission,
- AdministerRepositoryPermission, AdministerOrganizationPermission,
- SuperUserPermission)
+from auth.permissions import (
+ ReadRepositoryPermission,
+ ModifyRepositoryPermission,
+ AdministerRepositoryPermission,
+ AdministerOrganizationPermission,
+ SuperUserPermission,
+)
from buildtrigger.basehandler import BuildTriggerHandler
from data import database
from data import model
from data.buildlogs import BuildStatusRetrievalError
-from endpoints.api import (RepositoryParamResource, parse_args, query_param, nickname, resource,
- require_repo_read, require_repo_write, validate_json_request,
- ApiResource, internal_only, format_date, api, path_param,
- require_repo_admin, abort, disallow_for_app_repositories,
- disallow_for_non_normal_repositories)
-from endpoints.building import (start_build, PreparedBuild, MaximumBuildsQueuedException,
- BuildTriggerDisabledException)
+from endpoints.api import (
+ RepositoryParamResource,
+ parse_args,
+ query_param,
+ nickname,
+ resource,
+ require_repo_read,
+ require_repo_write,
+ validate_json_request,
+ ApiResource,
+ internal_only,
+ format_date,
+ api,
+ path_param,
+ require_repo_admin,
+ abort,
+ disallow_for_app_repositories,
+ disallow_for_non_normal_repositories,
+)
+from endpoints.building import (
+ start_build,
+ PreparedBuild,
+ MaximumBuildsQueuedException,
+ BuildTriggerDisabledException,
+)
from endpoints.exception import Unauthorized, NotFound, InvalidRequest
from util.names import parse_robot_username
from util.request import get_request_ip
@@ -33,453 +55,477 @@ logger = logging.getLogger(__name__)
def get_trigger_config(trigger):
- try:
- return json.loads(trigger.config)
- except:
- return {}
+ try:
+ return json.loads(trigger.config)
+ except:
+ return {}
def get_job_config(build_obj):
- try:
- return json.loads(build_obj.job_config)
- except:
- return {}
+ try:
+ return json.loads(build_obj.job_config)
+ except:
+ return {}
def user_view(user):
- return {
- 'name': user.username,
- 'kind': 'user',
- 'is_robot': user.robot,
- }
+ return {"name": user.username, "kind": "user", "is_robot": user.robot}
def trigger_view(trigger, can_read=False, can_admin=False, for_build=False):
- if trigger and trigger.uuid:
- build_trigger = BuildTriggerHandler.get_handler(trigger)
- build_source = build_trigger.config.get('build_source')
+ if trigger and trigger.uuid:
+ build_trigger = BuildTriggerHandler.get_handler(trigger)
+ build_source = build_trigger.config.get("build_source")
- repo_url = build_trigger.get_repository_url() if build_source else None
- can_read = can_read or can_admin
+ repo_url = build_trigger.get_repository_url() if build_source else None
+ can_read = can_read or can_admin
- trigger_data = {
- 'id': trigger.uuid,
- 'service': trigger.service.name,
- 'is_active': build_trigger.is_active(),
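+        # Only expose the build source, repository URL and trigger config to
+        # callers with the corresponding read/admin permissions.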
+ trigger_data = {
+ "id": trigger.uuid,
+ "service": trigger.service.name,
+ "is_active": build_trigger.is_active(),
+ "build_source": build_source if can_read else None,
+ "repository_url": repo_url if can_read else None,
+ "config": build_trigger.config if can_admin else {},
+ "can_invoke": can_admin,
+ "enabled": trigger.enabled,
+ "disabled_reason": trigger.disabled_reason.name
+ if trigger.disabled_reason
+ else None,
+ }
- 'build_source': build_source if can_read else None,
- 'repository_url': repo_url if can_read else None,
+ if not for_build and can_admin and trigger.pull_robot:
+ trigger_data["pull_robot"] = user_view(trigger.pull_robot)
- 'config': build_trigger.config if can_admin else {},
- 'can_invoke': can_admin,
- 'enabled': trigger.enabled,
- 'disabled_reason': trigger.disabled_reason.name if trigger.disabled_reason else None,
- }
+ return trigger_data
- if not for_build and can_admin and trigger.pull_robot:
- trigger_data['pull_robot'] = user_view(trigger.pull_robot)
-
- return trigger_data
-
- return None
+ return None
def _get_build_status(build_obj):
- """ Returns the updated build phase, status and (if any) error for the build object. """
- phase = build_obj.phase
- status = {}
- error = None
+ """ Returns the updated build phase, status and (if any) error for the build object. """
+ phase = build_obj.phase
+ status = {}
+ error = None
- # If the build is currently running, then load its "real-time" status from Redis.
- if not database.BUILD_PHASE.is_terminal_phase(phase):
- try:
- status = build_logs.get_status(build_obj.uuid)
- except BuildStatusRetrievalError as bsre:
- phase = 'cannot_load'
- if SuperUserPermission().can():
- error = str(bsre)
- else:
- error = 'Redis may be down. Please contact support.'
+ # If the build is currently running, then load its "real-time" status from Redis.
+ if not database.BUILD_PHASE.is_terminal_phase(phase):
+ try:
+ status = build_logs.get_status(build_obj.uuid)
+ except BuildStatusRetrievalError as bsre:
+ phase = "cannot_load"
+ if SuperUserPermission().can():
+ error = str(bsre)
+ else:
+ error = "Redis may be down. Please contact support."
- if phase != 'cannot_load':
- # If the status contains a heartbeat, then check to see if has been written in the last few
- # minutes. If not, then the build timed out.
- if status is not None and 'heartbeat' in status and status['heartbeat']:
- heartbeat = datetime.datetime.utcfromtimestamp(status['heartbeat'])
- if datetime.datetime.utcnow() - heartbeat > datetime.timedelta(minutes=1):
- phase = database.BUILD_PHASE.INTERNAL_ERROR
+ if phase != "cannot_load":
+        # If the status contains a heartbeat, then check to see if it has been written in the
+        # last few minutes. If not, then the build timed out.
+ if status is not None and "heartbeat" in status and status["heartbeat"]:
+ heartbeat = datetime.datetime.utcfromtimestamp(status["heartbeat"])
+ if datetime.datetime.utcnow() - heartbeat > datetime.timedelta(
+ minutes=1
+ ):
+ phase = database.BUILD_PHASE.INTERNAL_ERROR
- # If the phase is internal error, return 'expired' instead if the number of retries
- # on the queue item is 0.
- if phase == database.BUILD_PHASE.INTERNAL_ERROR:
- retry = (build_obj.queue_id and
- dockerfile_build_queue.has_retries_remaining(build_obj.queue_id))
- if not retry:
- phase = 'expired'
+ # If the phase is internal error, return 'expired' instead if the number of retries
+ # on the queue item is 0.
+ if phase == database.BUILD_PHASE.INTERNAL_ERROR:
+ retry = build_obj.queue_id and dockerfile_build_queue.has_retries_remaining(
+ build_obj.queue_id
+ )
+ if not retry:
+ phase = "expired"
- return (phase, status, error)
+ return (phase, status, error)
def build_status_view(build_obj):
- phase, status, error = _get_build_status(build_obj)
- repo_namespace = build_obj.repository.namespace_user.username
- repo_name = build_obj.repository.name
+ phase, status, error = _get_build_status(build_obj)
+ repo_namespace = build_obj.repository.namespace_user.username
+ repo_name = build_obj.repository.name
- can_read = ReadRepositoryPermission(repo_namespace, repo_name).can()
- can_write = ModifyRepositoryPermission(repo_namespace, repo_name).can()
- can_admin = AdministerRepositoryPermission(repo_namespace, repo_name).can()
+ can_read = ReadRepositoryPermission(repo_namespace, repo_name).can()
+ can_write = ModifyRepositoryPermission(repo_namespace, repo_name).can()
+ can_admin = AdministerRepositoryPermission(repo_namespace, repo_name).can()
- job_config = get_job_config(build_obj)
+ job_config = get_job_config(build_obj)
- resp = {
- 'id': build_obj.uuid,
- 'phase': phase,
- 'started': format_date(build_obj.started),
- 'display_name': build_obj.display_name,
- 'status': status or {},
- 'subdirectory': job_config.get('build_subdir', ''),
- 'dockerfile_path': job_config.get('build_subdir', ''),
- 'context': job_config.get('context', ''),
- 'tags': job_config.get('docker_tags', []),
- 'manual_user': job_config.get('manual_user', None),
- 'is_writer': can_write,
- 'trigger': trigger_view(build_obj.trigger, can_read, can_admin, for_build=True),
- 'trigger_metadata': job_config.get('trigger_metadata', None) if can_read else None,
- 'resource_key': build_obj.resource_key,
- 'pull_robot': user_view(build_obj.pull_robot) if build_obj.pull_robot else None,
- 'repository': {
- 'namespace': repo_namespace,
- 'name': repo_name
- },
- 'error': error,
- }
+ resp = {
+ "id": build_obj.uuid,
+ "phase": phase,
+ "started": format_date(build_obj.started),
+ "display_name": build_obj.display_name,
+ "status": status or {},
+ "subdirectory": job_config.get("build_subdir", ""),
+ "dockerfile_path": job_config.get("build_subdir", ""),
+ "context": job_config.get("context", ""),
+ "tags": job_config.get("docker_tags", []),
+ "manual_user": job_config.get("manual_user", None),
+ "is_writer": can_write,
+ "trigger": trigger_view(build_obj.trigger, can_read, can_admin, for_build=True),
+ "trigger_metadata": job_config.get("trigger_metadata", None)
+ if can_read
+ else None,
+ "resource_key": build_obj.resource_key,
+ "pull_robot": user_view(build_obj.pull_robot) if build_obj.pull_robot else None,
+ "repository": {"namespace": repo_namespace, "name": repo_name},
+ "error": error,
+ }
- if can_write or features.READER_BUILD_LOGS:
- if build_obj.resource_key is not None:
- resp['archive_url'] = user_files.get_file_url(build_obj.resource_key,
- get_request_ip(), requires_cors=True)
- elif job_config.get('archive_url', None):
- resp['archive_url'] = job_config['archive_url']
+ if can_write or features.READER_BUILD_LOGS:
+ if build_obj.resource_key is not None:
+ resp["archive_url"] = user_files.get_file_url(
+ build_obj.resource_key, get_request_ip(), requires_cors=True
+ )
+ elif job_config.get("archive_url", None):
+ resp["archive_url"] = job_config["archive_url"]
- return resp
+ return resp
-@resource('/v1/repository/<apirepopath:repository>/build/')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
+@resource("/v1/repository/<apirepopath:repository>/build/")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
class RepositoryBuildList(RepositoryParamResource):
- """ Resource related to creating and listing repository builds. """
- schemas = {
- 'RepositoryBuildRequest': {
- 'type': 'object',
- 'description': 'Description of a new repository build.',
- 'properties': {
- 'file_id': {
- 'type': 'string',
- 'description': 'The file id that was generated when the build spec was uploaded',
- },
- 'archive_url': {
- 'type': 'string',
- 'description': 'The URL of the .tar.gz to build. Must start with "http" or "https".',
- },
- 'subdirectory': {
- 'type': 'string',
- 'description': 'Subdirectory in which the Dockerfile can be found. You can only specify this or dockerfile_path',
- },
- 'dockerfile_path': {
- 'type': 'string',
- 'description': 'Path to a dockerfile. You can only specify this or subdirectory.',
- },
- 'context': {
- 'type': 'string',
- 'description': 'Pass in the context for the dockerfile. This is optional.',
- },
- 'pull_robot': {
- 'type': 'string',
- 'description': 'Username of a Quay robot account to use as pull credentials',
- },
- 'docker_tags': {
- 'type': 'array',
- 'description': 'The tags to which the built images will be pushed. ' +
- 'If none specified, "latest" is used.',
- 'items': {
- 'type': 'string'
- },
- 'minItems': 1,
- 'uniqueItems': True
+ """ Resource related to creating and listing repository builds. """
+
+ schemas = {
+ "RepositoryBuildRequest": {
+ "type": "object",
+ "description": "Description of a new repository build.",
+ "properties": {
+ "file_id": {
+ "type": "string",
+ "description": "The file id that was generated when the build spec was uploaded",
+ },
+ "archive_url": {
+ "type": "string",
+ "description": 'The URL of the .tar.gz to build. Must start with "http" or "https".',
+ },
+ "subdirectory": {
+ "type": "string",
+ "description": "Subdirectory in which the Dockerfile can be found. You can only specify this or dockerfile_path",
+ },
+ "dockerfile_path": {
+ "type": "string",
+ "description": "Path to a dockerfile. You can only specify this or subdirectory.",
+ },
+ "context": {
+ "type": "string",
+ "description": "Pass in the context for the dockerfile. This is optional.",
+ },
+ "pull_robot": {
+ "type": "string",
+ "description": "Username of a Quay robot account to use as pull credentials",
+ },
+ "docker_tags": {
+ "type": "array",
+ "description": "The tags to which the built images will be pushed. "
+ + 'If none specified, "latest" is used.',
+ "items": {"type": "string"},
+ "minItems": 1,
+ "uniqueItems": True,
+ },
+ },
}
- },
- },
- }
-
- @require_repo_read
- @parse_args()
- @query_param('limit', 'The maximum number of builds to return', type=int, default=5)
-  @query_param('since', 'Returns all builds since the given unix timestamp', type=int, default=None)
- @nickname('getRepoBuilds')
- @disallow_for_app_repositories
- def get(self, namespace, repository, parsed_args):
- """ Get the list of repository builds. """
- limit = parsed_args.get('limit', 5)
- since = parsed_args.get('since', None)
-
- if since is not None:
- since = datetime.datetime.utcfromtimestamp(since)
-
- builds = model.build.list_repository_builds(namespace, repository, limit, since=since)
- return {
- 'builds': [build_status_view(build) for build in builds]
}
- @require_repo_write
- @nickname('requestRepoBuild')
- @disallow_for_app_repositories
- @disallow_for_non_normal_repositories
- @validate_json_request('RepositoryBuildRequest')
- def post(self, namespace, repository):
- """ Request that a repository be built and pushed from the specified input. """
- logger.debug('User requested repository initialization.')
- request_json = request.get_json()
+ @require_repo_read
+ @parse_args()
+ @query_param("limit", "The maximum number of builds to return", type=int, default=5)
+ @query_param(
+ "since",
+        "Returns all builds since the given unix timestamp",
+ type=int,
+ default=None,
+ )
+ @nickname("getRepoBuilds")
+ @disallow_for_app_repositories
+ def get(self, namespace, repository, parsed_args):
+ """ Get the list of repository builds. """
+ limit = parsed_args.get("limit", 5)
+ since = parsed_args.get("since", None)
- dockerfile_id = request_json.get('file_id', None)
- archive_url = request_json.get('archive_url', None)
+ if since is not None:
+ since = datetime.datetime.utcfromtimestamp(since)
- if not dockerfile_id and not archive_url:
- raise InvalidRequest('file_id or archive_url required')
+ builds = model.build.list_repository_builds(
+ namespace, repository, limit, since=since
+ )
+ return {"builds": [build_status_view(build) for build in builds]}
- if archive_url:
- archive_match = None
- try:
- archive_match = urlparse(archive_url)
- except ValueError:
- pass
+ @require_repo_write
+ @nickname("requestRepoBuild")
+ @disallow_for_app_repositories
+ @disallow_for_non_normal_repositories
+ @validate_json_request("RepositoryBuildRequest")
+ def post(self, namespace, repository):
+ """ Request that a repository be built and pushed from the specified input. """
+ logger.debug("User requested repository initialization.")
+ request_json = request.get_json()
- if not archive_match:
- raise InvalidRequest('Invalid Archive URL: Must be a valid URI')
+ dockerfile_id = request_json.get("file_id", None)
+ archive_url = request_json.get("archive_url", None)
- scheme = archive_match.scheme
- if scheme != 'http' and scheme != 'https':
- raise InvalidRequest('Invalid Archive URL: Must be http or https')
+ if not dockerfile_id and not archive_url:
+ raise InvalidRequest("file_id or archive_url required")
- context, subdir = self.get_dockerfile_context(request_json)
- tags = request_json.get('docker_tags', ['latest'])
- pull_robot_name = request_json.get('pull_robot', None)
+ if archive_url:
+ archive_match = None
+ try:
+ archive_match = urlparse(archive_url)
+ except ValueError:
+ pass
+
+ if not archive_match:
+ raise InvalidRequest("Invalid Archive URL: Must be a valid URI")
+
+ scheme = archive_match.scheme
+ if scheme != "http" and scheme != "https":
+ raise InvalidRequest("Invalid Archive URL: Must be http or https")
+
+ context, subdir = self.get_dockerfile_context(request_json)
+ tags = request_json.get("docker_tags", ["latest"])
+ pull_robot_name = request_json.get("pull_robot", None)
+
+ # Verify the security behind the pull robot.
+ if pull_robot_name:
+ result = parse_robot_username(pull_robot_name)
+ if result:
+ try:
+ model.user.lookup_robot(pull_robot_name)
+ except model.InvalidRobotException:
+ raise NotFound()
+
+ # Make sure the user has administer permissions for the robot's namespace.
+ (robot_namespace, _) = result
+ if not AdministerOrganizationPermission(robot_namespace).can():
+ raise Unauthorized()
+ else:
+ raise Unauthorized()
+
+ # Check if the dockerfile resource has already been used. If so, then it
+ # can only be reused if the user has access to the repository in which the
+ # dockerfile was previously built.
+ if dockerfile_id:
+ associated_repository = model.build.get_repository_for_resource(
+ dockerfile_id
+ )
+ if associated_repository:
+ if not ModifyRepositoryPermission(
+ associated_repository.namespace_user.username,
+ associated_repository.name,
+ ):
+ raise Unauthorized()
+
+ # Start the build.
+ repo = model.repository.get_repository(namespace, repository)
+ if repo is None:
+ raise NotFound()
- # Verify the security behind the pull robot.
- if pull_robot_name:
- result = parse_robot_username(pull_robot_name)
- if result:
try:
- model.user.lookup_robot(pull_robot_name)
- except model.InvalidRobotException:
- raise NotFound()
+ build_name = (
+ user_files.get_file_checksum(dockerfile_id)
+ if dockerfile_id
+ else hashlib.sha224(archive_url).hexdigest()[0:7]
+ )
+ except IOError:
+ raise InvalidRequest(
+ "File %s could not be found or is invalid" % dockerfile_id
+ )
- # Make sure the user has administer permissions for the robot's namespace.
- (robot_namespace, _) = result
- if not AdministerOrganizationPermission(robot_namespace).can():
- raise Unauthorized()
- else:
- raise Unauthorized()
+ prepared = PreparedBuild()
+ prepared.build_name = build_name
+ prepared.dockerfile_id = dockerfile_id
+ prepared.archive_url = archive_url
+ prepared.tags = tags
+ prepared.subdirectory = subdir
+ prepared.context = context
+ prepared.is_manual = True
+ prepared.metadata = {}
+ try:
+ build_request = start_build(repo, prepared, pull_robot_name=pull_robot_name)
+ except MaximumBuildsQueuedException:
+ abort(429, message="Maximum queued build rate exceeded.")
+ except BuildTriggerDisabledException:
+ abort(400, message="Build trigger is disabled")
- # Check if the dockerfile resource has already been used. If so, then it
- # can only be reused if the user has access to the repository in which the
- # dockerfile was previously built.
- if dockerfile_id:
- associated_repository = model.build.get_repository_for_resource(dockerfile_id)
- if associated_repository:
- if not ModifyRepositoryPermission(associated_repository.namespace_user.username,
- associated_repository.name):
- raise Unauthorized()
+ resp = build_status_view(build_request)
+ repo_string = "%s/%s" % (namespace, repository)
+ headers = {
+ "Location": api.url_for(
+ RepositoryBuildStatus,
+ repository=repo_string,
+ build_uuid=build_request.uuid,
+ )
+ }
+ return resp, 201, headers
- # Start the build.
- repo = model.repository.get_repository(namespace, repository)
- if repo is None:
- raise NotFound()
+ @staticmethod
+ def get_dockerfile_context(request_json):
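+        # Derive the build context directory and the Dockerfile path from the request
+        # body, which supplies either "dockerfile_path" or "subdirectory" (the schema
+        # above documents them as mutually exclusive).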
+ context = request_json["context"] if "context" in request_json else os.path.sep
+ if "dockerfile_path" in request_json:
+ subdir = request_json["dockerfile_path"]
+ if "context" not in request_json:
+ context = os.path.dirname(subdir)
+ return context, subdir
- try:
- build_name = (user_files.get_file_checksum(dockerfile_id)
- if dockerfile_id
- else hashlib.sha224(archive_url).hexdigest()[0:7])
- except IOError:
- raise InvalidRequest('File %s could not be found or is invalid' % dockerfile_id)
+ if "subdirectory" in request_json:
+ subdir = request_json["subdirectory"]
+ context = subdir
+ if not subdir.endswith(os.path.sep):
+ subdir += os.path.sep
- prepared = PreparedBuild()
- prepared.build_name = build_name
- prepared.dockerfile_id = dockerfile_id
- prepared.archive_url = archive_url
- prepared.tags = tags
- prepared.subdirectory = subdir
- prepared.context = context
- prepared.is_manual = True
- prepared.metadata = {}
- try:
- build_request = start_build(repo, prepared, pull_robot_name=pull_robot_name)
- except MaximumBuildsQueuedException:
- abort(429, message='Maximum queued build rate exceeded.')
- except BuildTriggerDisabledException:
- abort(400, message='Build trigger is disabled')
+ subdir += "Dockerfile"
+ else:
+ if context.endswith(os.path.sep):
+ subdir = context + "Dockerfile"
+ else:
+ subdir = context + os.path.sep + "Dockerfile"
- resp = build_status_view(build_request)
- repo_string = '%s/%s' % (namespace, repository)
- headers = {
- 'Location': api.url_for(RepositoryBuildStatus, repository=repo_string,
- build_uuid=build_request.uuid),
- }
- return resp, 201, headers
+ return context, subdir
- @staticmethod
- def get_dockerfile_context(request_json):
- context = request_json['context'] if 'context' in request_json else os.path.sep
- if 'dockerfile_path' in request_json:
- subdir = request_json['dockerfile_path']
- if 'context' not in request_json:
- context = os.path.dirname(subdir)
- return context, subdir
- if 'subdirectory' in request_json:
- subdir = request_json['subdirectory']
- context = subdir
- if not subdir.endswith(os.path.sep):
- subdir += os.path.sep
-
- subdir += 'Dockerfile'
- else:
- if context.endswith(os.path.sep):
- subdir = context + 'Dockerfile'
- else:
- subdir = context + os.path.sep + 'Dockerfile'
-
- return context, subdir
-
-@resource('/v1/repository/<apirepopath:repository>/build/<build_uuid>')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
-@path_param('build_uuid', 'The UUID of the build')
+@resource("/v1/repository/<apirepopath:repository>/build/<build_uuid>")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
+@path_param("build_uuid", "The UUID of the build")
class RepositoryBuildResource(RepositoryParamResource):
- """ Resource for dealing with repository builds. """
- @require_repo_read
- @nickname('getRepoBuild')
- @disallow_for_app_repositories
- def get(self, namespace, repository, build_uuid):
- """ Returns information about a build. """
- try:
- build = model.build.get_repository_build(build_uuid)
- except model.build.InvalidRepositoryBuildException:
- raise NotFound()
+ """ Resource for dealing with repository builds. """
- if build.repository.name != repository or build.repository.namespace_user.username != namespace:
- raise NotFound()
+ @require_repo_read
+ @nickname("getRepoBuild")
+ @disallow_for_app_repositories
+ def get(self, namespace, repository, build_uuid):
+ """ Returns information about a build. """
+ try:
+ build = model.build.get_repository_build(build_uuid)
+ except model.build.InvalidRepositoryBuildException:
+ raise NotFound()
- return build_status_view(build)
+ if (
+ build.repository.name != repository
+ or build.repository.namespace_user.username != namespace
+ ):
+ raise NotFound()
- @require_repo_admin
- @nickname('cancelRepoBuild')
- @disallow_for_app_repositories
- @disallow_for_non_normal_repositories
- def delete(self, namespace, repository, build_uuid):
- """ Cancels a repository build. """
- try:
- build = model.build.get_repository_build(build_uuid)
- except model.build.InvalidRepositoryBuildException:
- raise NotFound()
+ return build_status_view(build)
- if build.repository.name != repository or build.repository.namespace_user.username != namespace:
- raise NotFound()
+ @require_repo_admin
+ @nickname("cancelRepoBuild")
+ @disallow_for_app_repositories
+ @disallow_for_non_normal_repositories
+ def delete(self, namespace, repository, build_uuid):
+ """ Cancels a repository build. """
+ try:
+ build = model.build.get_repository_build(build_uuid)
+ except model.build.InvalidRepositoryBuildException:
+ raise NotFound()
- if model.build.cancel_repository_build(build, dockerfile_build_queue):
- return 'Okay', 201
- else:
- raise InvalidRequest('Build is currently running or has finished')
+ if (
+ build.repository.name != repository
+ or build.repository.namespace_user.username != namespace
+ ):
+ raise NotFound()
+
+ if model.build.cancel_repository_build(build, dockerfile_build_queue):
+ return "Okay", 201
+ else:
+ raise InvalidRequest("Build is currently running or has finished")
-@resource('/v1/repository/<apirepopath:repository>/build/<build_uuid>/status')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
-@path_param('build_uuid', 'The UUID of the build')
+@resource("/v1/repository/<apirepopath:repository>/build/<build_uuid>/status")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
+@path_param("build_uuid", "The UUID of the build")
class RepositoryBuildStatus(RepositoryParamResource):
- """ Resource for dealing with repository build status. """
- @require_repo_read
- @nickname('getRepoBuildStatus')
- @disallow_for_app_repositories
- def get(self, namespace, repository, build_uuid):
-    """ Return the status for the build specified by the build uuid. """
- build = model.build.get_repository_build(build_uuid)
- if (not build or build.repository.name != repository or
- build.repository.namespace_user.username != namespace):
- raise NotFound()
+ """ Resource for dealing with repository build status. """
- return build_status_view(build)
+ @require_repo_read
+ @nickname("getRepoBuildStatus")
+ @disallow_for_app_repositories
+ def get(self, namespace, repository, build_uuid):
+        """ Return the status for the build specified by the build uuid. """
+ build = model.build.get_repository_build(build_uuid)
+ if (
+ not build
+ or build.repository.name != repository
+ or build.repository.namespace_user.username != namespace
+ ):
+ raise NotFound()
+
+ return build_status_view(build)
def get_logs_or_log_url(build):
- # If the logs have been archived, just return a URL of the completed archive
- if build.logs_archived:
- return {
- 'logs_url': log_archive.get_file_url(build.uuid, get_request_ip(), requires_cors=True)
- }
- start = int(request.args.get('start', 0))
+ # If the logs have been archived, just return a URL of the completed archive
+ if build.logs_archived:
+ return {
+ "logs_url": log_archive.get_file_url(
+ build.uuid, get_request_ip(), requires_cors=True
+ )
+ }
+ start = int(request.args.get("start", 0))
- try:
- count, logs = build_logs.get_log_entries(build.uuid, start)
- except BuildStatusRetrievalError:
- count, logs = (0, [])
+ try:
+ count, logs = build_logs.get_log_entries(build.uuid, start)
+ except BuildStatusRetrievalError:
+ count, logs = (0, [])
- response_obj = {}
- response_obj.update({
- 'start': start,
- 'total': count,
- 'logs': [log for log in logs],
- })
+ response_obj = {}
+ response_obj.update({"start": start, "total": count, "logs": [log for log in logs]})
- return response_obj
+ return response_obj
-@resource('/v1/repository/<apirepopath:repository>/build/<build_uuid>/logs')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
-@path_param('build_uuid', 'The UUID of the build')
+@resource("/v1/repository/<apirepopath:repository>/build/<build_uuid>/logs")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
+@path_param("build_uuid", "The UUID of the build")
class RepositoryBuildLogs(RepositoryParamResource):
- """ Resource for loading repository build logs. """
- @require_repo_read
- @nickname('getRepoBuildLogs')
- @disallow_for_app_repositories
- def get(self, namespace, repository, build_uuid):
- """ Return the build logs for the build specified by the build uuid. """
- can_write = ModifyRepositoryPermission(namespace, repository).can()
- if not features.READER_BUILD_LOGS and not can_write:
- raise Unauthorized()
+ """ Resource for loading repository build logs. """
- build = model.build.get_repository_build(build_uuid)
- if (not build or build.repository.name != repository or
- build.repository.namespace_user.username != namespace):
- raise NotFound()
+ @require_repo_read
+ @nickname("getRepoBuildLogs")
+ @disallow_for_app_repositories
+ def get(self, namespace, repository, build_uuid):
+ """ Return the build logs for the build specified by the build uuid. """
+ can_write = ModifyRepositoryPermission(namespace, repository).can()
+ if not features.READER_BUILD_LOGS and not can_write:
+ raise Unauthorized()
- return get_logs_or_log_url(build)
+ build = model.build.get_repository_build(build_uuid)
+ if (
+ not build
+ or build.repository.name != repository
+ or build.repository.namespace_user.username != namespace
+ ):
+ raise NotFound()
+
+ return get_logs_or_log_url(build)
-@resource('/v1/filedrop/')
+@resource("/v1/filedrop/")
@internal_only
class FileDropResource(ApiResource):
- """ Custom verb for setting up a client side file transfer. """
- schemas = {
- 'FileDropRequest': {
- 'type': 'object',
- 'description': 'Description of the file that the user wishes to upload.',
- 'required': [
- 'mimeType',
- ],
- 'properties': {
- 'mimeType': {
- 'type': 'string',
- 'description': 'Type of the file which is about to be uploaded',
- },
- },
- },
- }
+ """ Custom verb for setting up a client side file transfer. """
- @nickname('getFiledropUrl')
- @validate_json_request('FileDropRequest')
- def post(self):
- """ Request a URL to which a file may be uploaded. """
- mime_type = request.get_json()['mimeType']
- (url, file_id) = user_files.prepare_for_drop(mime_type, requires_cors=True)
- return {
- 'url': url,
- 'file_id': str(file_id),
+ schemas = {
+ "FileDropRequest": {
+ "type": "object",
+ "description": "Description of the file that the user wishes to upload.",
+ "required": ["mimeType"],
+ "properties": {
+ "mimeType": {
+ "type": "string",
+ "description": "Type of the file which is about to be uploaded",
+ }
+ },
+ }
}
+
+ @nickname("getFiledropUrl")
+ @validate_json_request("FileDropRequest")
+ def post(self):
+ """ Request a URL to which a file may be uploaded. """
+ mime_type = request.get_json()["mimeType"]
+ (url, file_id) = user_files.prepare_for_drop(mime_type, requires_cors=True)
+ return {"url": url, "file_id": str(file_id)}
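For readers skimming the hunk above, here is a small self-contained sketch (not part of the commit; the function name and sample payloads are illustrative) that restates how RepositoryBuildList.get_dockerfile_context resolves the build context and Dockerfile path, assuming POSIX separators:

import posixpath


def resolve_context(request_json, sep=posixpath.sep):
    # Mirrors the branches of get_dockerfile_context above.
    context = request_json.get("context", sep)
    if "dockerfile_path" in request_json:
        subdir = request_json["dockerfile_path"]
        if "context" not in request_json:
            context = posixpath.dirname(subdir)
        return context, subdir
    if "subdirectory" in request_json:
        subdir = request_json["subdirectory"]
        context = subdir
        if not subdir.endswith(sep):
            subdir += sep
        subdir += "Dockerfile"
    else:
        subdir = context + "Dockerfile" if context.endswith(sep) else context + sep + "Dockerfile"
    return context, subdir


# "dockerfile_path" wins and implies the context when none is given.
assert resolve_context({"dockerfile_path": "/app/Dockerfile"}) == ("/app", "/app/Dockerfile")
# "subdirectory" doubles as the context and gets "Dockerfile" appended.
assert resolve_context({"subdirectory": "service"}) == ("service", "service/Dockerfile")
# With neither field, everything defaults to the root of the uploaded archive.
assert resolve_context({}) == ("/", "/Dockerfile")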
diff --git a/endpoints/api/discovery.py b/endpoints/api/discovery.py
index 66e7c74a3..45f8f476e 100644
--- a/endpoints/api/discovery.py
+++ b/endpoints/api/discovery.py
@@ -11,324 +11,337 @@ from flask_restful import reqparse
from app import app
from auth import scopes
-from endpoints.api import (ApiResource, resource, method_metadata, nickname, truthy_bool,
- parse_args, query_param)
+from endpoints.api import (
+ ApiResource,
+ resource,
+ method_metadata,
+ nickname,
+ truthy_bool,
+ parse_args,
+ query_param,
+)
from endpoints.decorators import anon_allowed
logger = logging.getLogger(__name__)
-PARAM_REGEX = re.compile(r'<([^:>]+:)*([\w]+)>')
+PARAM_REGEX = re.compile(r"<([^:>]+:)*([\w]+)>")
TYPE_CONVERTER = {
- truthy_bool: 'boolean',
- str: 'string',
- basestring: 'string',
- reqparse.text_type: 'string',
- int: 'integer',
+ truthy_bool: "boolean",
+ str: "string",
+ basestring: "string",
+ reqparse.text_type: "string",
+ int: "integer",
}
-PREFERRED_URL_SCHEME = app.config['PREFERRED_URL_SCHEME']
-SERVER_HOSTNAME = app.config['SERVER_HOSTNAME']
+PREFERRED_URL_SCHEME = app.config["PREFERRED_URL_SCHEME"]
+SERVER_HOSTNAME = app.config["SERVER_HOSTNAME"]
def fully_qualified_name(method_view_class):
- return '%s.%s' % (method_view_class.__module__, method_view_class.__name__)
+ return "%s.%s" % (method_view_class.__module__, method_view_class.__name__)
def swagger_route_data(include_internal=False, compact=False):
- def swagger_parameter(name, description, kind='path', param_type='string', required=True,
- enum=None, schema=None):
- # https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#parameterObject
- parameter_info = {
- 'name': name,
- 'in': kind,
- 'required': required
- }
+ def swagger_parameter(
+ name,
+ description,
+ kind="path",
+ param_type="string",
+ required=True,
+ enum=None,
+ schema=None,
+ ):
+ # https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#parameterObject
+ parameter_info = {"name": name, "in": kind, "required": required}
- if not compact:
- parameter_info['description'] = description or ''
+ if not compact:
+ parameter_info["description"] = description or ""
- if schema:
- parameter_info['schema'] = {
- '$ref': '#/definitions/%s' % schema
- }
- else:
- parameter_info['type'] = param_type
-
- if enum is not None and len(list(enum)) > 0:
- parameter_info['enum'] = list(enum)
-
- return parameter_info
-
- paths = {}
- models = {}
- tags = []
- tags_added = set()
- operationIds = set()
-
- for rule in app.url_map.iter_rules():
- endpoint_method = app.view_functions[rule.endpoint]
-
- # Verify that we have a view class for this API method.
- if not 'view_class' in dir(endpoint_method):
- continue
-
- view_class = endpoint_method.view_class
-
- # Hide the class if it is internal.
- internal = method_metadata(view_class, 'internal')
- if not include_internal and internal:
- continue
-
- # Build the tag.
- parts = fully_qualified_name(view_class).split('.')
- tag_name = parts[-2]
- if not tag_name in tags_added:
- tags_added.add(tag_name)
- tags.append({
- 'name': tag_name,
- 'description': (sys.modules[view_class.__module__].__doc__ or '').strip()
- })
-
- # Build the Swagger data for the path.
- swagger_path = PARAM_REGEX.sub(r'{\2}', rule.rule)
- full_name = fully_qualified_name(view_class)
- path_swagger = {
- 'x-name': full_name,
- 'x-path': swagger_path,
- 'x-tag': tag_name
- }
-
- if include_internal:
- related_user_res = method_metadata(view_class, 'related_user_resource')
- if related_user_res is not None:
- path_swagger['x-user-related'] = fully_qualified_name(related_user_res)
-
- paths[swagger_path] = path_swagger
-
- # Add any global path parameters.
- param_data_map = view_class.__api_path_params if '__api_path_params' in dir(view_class) else {}
- if param_data_map:
- path_parameters_swagger = []
- for path_parameter in param_data_map:
- description = param_data_map[path_parameter].get('description')
- path_parameters_swagger.append(swagger_parameter(path_parameter, description))
-
- path_swagger['parameters'] = path_parameters_swagger
-
- # Add the individual HTTP operations.
- method_names = list(rule.methods.difference(['HEAD', 'OPTIONS']))
- for method_name in method_names:
- # https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#operation-object
- method = getattr(view_class, method_name.lower(), None)
- if method is None:
- logger.debug('Unable to find method for %s in class %s', method_name, view_class)
- continue
-
- operationId = method_metadata(method, 'nickname')
- operation_swagger = {
- 'operationId': operationId,
- 'parameters': [],
- }
-
- if operationId is None:
- continue
-
- if operationId in operationIds:
- raise Exception('Duplicate operation Id: %s' % operationId)
-
- operationIds.add(operationId)
-
- if not compact:
- operation_swagger.update({
- 'description': method.__doc__.strip() if method.__doc__ else '',
- 'tags': [tag_name]
- })
-
- # Mark the method as internal.
- internal = method_metadata(method, 'internal')
- if internal is not None:
- operation_swagger['x-internal'] = True
-
- if include_internal:
- requires_fresh_login = method_metadata(method, 'requires_fresh_login')
- if requires_fresh_login is not None:
- operation_swagger['x-requires-fresh-login'] = True
-
- # Add the path parameters.
- if rule.arguments:
- for path_parameter in rule.arguments:
- description = param_data_map.get(path_parameter, {}).get('description')
- operation_swagger['parameters'].append(swagger_parameter(path_parameter, description))
-
- # Add the query parameters.
- if '__api_query_params' in dir(method):
- for query_parameter_info in method.__api_query_params:
- name = query_parameter_info['name']
- description = query_parameter_info['help']
- param_type = TYPE_CONVERTER[query_parameter_info['type']]
- required = query_parameter_info['required']
-
- operation_swagger['parameters'].append(
- swagger_parameter(name, description, kind='query',
- param_type=param_type,
- required=required,
- enum=query_parameter_info['choices']))
-
- # Add the OAuth security block.
- # https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#securityRequirementObject
- scope = method_metadata(method, 'oauth2_scope')
- if scope and not compact:
- operation_swagger['security'] = [{'oauth2_implicit': [scope.scope]}]
-
- # Add the responses block.
- # https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#responsesObject
- response_schema_name = method_metadata(method, 'response_schema')
- if not compact:
- if response_schema_name:
- models[response_schema_name] = view_class.schemas[response_schema_name]
-
- models['ApiError'] = {
- 'type': 'object',
- 'properties': {
- 'status': {
- 'type': 'integer',
- 'description': 'Status code of the response.'
- },
- 'type': {
- 'type': 'string',
- 'description': 'Reference to the type of the error.'
- },
- 'detail': {
- 'type': 'string',
- 'description': 'Details about the specific instance of the error.'
- },
- 'title': {
- 'type': 'string',
- 'description': 'Unique error code to identify the type of error.'
- },
- 'error_message': {
- 'type': 'string',
- 'description': 'Deprecated; alias for detail'
- },
- 'error_type': {
- 'type': 'string',
- 'description': 'Deprecated; alias for detail'
- }
- },
- 'required': [
- 'status',
- 'type',
- 'title',
- ]
- }
-
- responses = {
- '400': {
- 'description': 'Bad Request',
- },
-
- '401': {
- 'description': 'Session required',
- },
-
- '403': {
- 'description': 'Unauthorized access',
- },
-
- '404': {
- 'description': 'Not found',
- },
- }
-
- for _, body in responses.items():
- body['schema'] = {'$ref': '#/definitions/ApiError'}
-
- if method_name == 'DELETE':
- responses['204'] = {
- 'description': 'Deleted'
- }
- elif method_name == 'POST':
- responses['201'] = {
- 'description': 'Successful creation'
- }
+ if schema:
+ parameter_info["schema"] = {"$ref": "#/definitions/%s" % schema}
else:
- responses['200'] = {
- 'description': 'Successful invocation'
- }
+ parameter_info["type"] = param_type
- if response_schema_name:
- responses['200']['schema'] = {
- '$ref': '#/definitions/%s' % response_schema_name
+ if enum is not None and len(list(enum)) > 0:
+ parameter_info["enum"] = list(enum)
+
+ return parameter_info
+
+ paths = {}
+ models = {}
+ tags = []
+ tags_added = set()
+ operationIds = set()
+
+ for rule in app.url_map.iter_rules():
+ endpoint_method = app.view_functions[rule.endpoint]
+
+ # Verify that we have a view class for this API method.
+ if not "view_class" in dir(endpoint_method):
+ continue
+
+ view_class = endpoint_method.view_class
+
+ # Hide the class if it is internal.
+ internal = method_metadata(view_class, "internal")
+ if not include_internal and internal:
+ continue
+
+ # Build the tag.
+ parts = fully_qualified_name(view_class).split(".")
+ tag_name = parts[-2]
+ if not tag_name in tags_added:
+ tags_added.add(tag_name)
+ tags.append(
+ {
+ "name": tag_name,
+ "description": (
+ sys.modules[view_class.__module__].__doc__ or ""
+ ).strip(),
+ }
+ )
+
+ # Build the Swagger data for the path.
+ swagger_path = PARAM_REGEX.sub(r"{\2}", rule.rule)
+ full_name = fully_qualified_name(view_class)
+ path_swagger = {"x-name": full_name, "x-path": swagger_path, "x-tag": tag_name}
+
+ if include_internal:
+ related_user_res = method_metadata(view_class, "related_user_resource")
+ if related_user_res is not None:
+ path_swagger["x-user-related"] = fully_qualified_name(related_user_res)
+
+ paths[swagger_path] = path_swagger
+
+ # Add any global path parameters.
+ param_data_map = (
+ view_class.__api_path_params
+ if "__api_path_params" in dir(view_class)
+ else {}
+ )
+ if param_data_map:
+ path_parameters_swagger = []
+ for path_parameter in param_data_map:
+ description = param_data_map[path_parameter].get("description")
+ path_parameters_swagger.append(
+ swagger_parameter(path_parameter, description)
+ )
+
+ path_swagger["parameters"] = path_parameters_swagger
+
+ # Add the individual HTTP operations.
+ method_names = list(rule.methods.difference(["HEAD", "OPTIONS"]))
+ for method_name in method_names:
+ # https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#operation-object
+ method = getattr(view_class, method_name.lower(), None)
+ if method is None:
+ logger.debug(
+ "Unable to find method for %s in class %s", method_name, view_class
+ )
+ continue
+
+ operationId = method_metadata(method, "nickname")
+ operation_swagger = {"operationId": operationId, "parameters": []}
+
+ if operationId is None:
+ continue
+
+ if operationId in operationIds:
+ raise Exception("Duplicate operation Id: %s" % operationId)
+
+ operationIds.add(operationId)
+
+ if not compact:
+ operation_swagger.update(
+ {
+ "description": method.__doc__.strip() if method.__doc__ else "",
+ "tags": [tag_name],
+ }
+ )
+
+ # Mark the method as internal.
+ internal = method_metadata(method, "internal")
+ if internal is not None:
+ operation_swagger["x-internal"] = True
+
+ if include_internal:
+ requires_fresh_login = method_metadata(method, "requires_fresh_login")
+ if requires_fresh_login is not None:
+ operation_swagger["x-requires-fresh-login"] = True
+
+ # Add the path parameters.
+ if rule.arguments:
+ for path_parameter in rule.arguments:
+ description = param_data_map.get(path_parameter, {}).get(
+ "description"
+ )
+ operation_swagger["parameters"].append(
+ swagger_parameter(path_parameter, description)
+ )
+
+ # Add the query parameters.
+ if "__api_query_params" in dir(method):
+ for query_parameter_info in method.__api_query_params:
+ name = query_parameter_info["name"]
+ description = query_parameter_info["help"]
+ param_type = TYPE_CONVERTER[query_parameter_info["type"]]
+ required = query_parameter_info["required"]
+
+ operation_swagger["parameters"].append(
+ swagger_parameter(
+ name,
+ description,
+ kind="query",
+ param_type=param_type,
+ required=required,
+ enum=query_parameter_info["choices"],
+ )
+ )
+
+ # Add the OAuth security block.
+ # https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#securityRequirementObject
+ scope = method_metadata(method, "oauth2_scope")
+ if scope and not compact:
+ operation_swagger["security"] = [{"oauth2_implicit": [scope.scope]}]
+
+ # Add the responses block.
+ # https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#responsesObject
+ response_schema_name = method_metadata(method, "response_schema")
+ if not compact:
+ if response_schema_name:
+ models[response_schema_name] = view_class.schemas[
+ response_schema_name
+ ]
+
+ models["ApiError"] = {
+ "type": "object",
+ "properties": {
+ "status": {
+ "type": "integer",
+ "description": "Status code of the response.",
+ },
+ "type": {
+ "type": "string",
+ "description": "Reference to the type of the error.",
+ },
+ "detail": {
+ "type": "string",
+ "description": "Details about the specific instance of the error.",
+ },
+ "title": {
+ "type": "string",
+ "description": "Unique error code to identify the type of error.",
+ },
+ "error_message": {
+ "type": "string",
+ "description": "Deprecated; alias for detail",
+ },
+ "error_type": {
+ "type": "string",
+ "description": "Deprecated; alias for detail",
+ },
+ },
+ "required": ["status", "type", "title"],
+ }
+
+ responses = {
+ "400": {"description": "Bad Request"},
+ "401": {"description": "Session required"},
+ "403": {"description": "Unauthorized access"},
+ "404": {"description": "Not found"},
+ }
+
+ for _, body in responses.items():
+ body["schema"] = {"$ref": "#/definitions/ApiError"}
+
+ if method_name == "DELETE":
+ responses["204"] = {"description": "Deleted"}
+ elif method_name == "POST":
+ responses["201"] = {"description": "Successful creation"}
+ else:
+ responses["200"] = {"description": "Successful invocation"}
+
+ if response_schema_name:
+ responses["200"]["schema"] = {
+ "$ref": "#/definitions/%s" % response_schema_name
+ }
+
+ operation_swagger["responses"] = responses
+
+ # Add the request block.
+ request_schema_name = method_metadata(method, "request_schema")
+ if request_schema_name and not compact:
+ models[request_schema_name] = view_class.schemas[request_schema_name]
+
+ operation_swagger["parameters"].append(
+ swagger_parameter(
+ "body",
+ "Request body contents.",
+ kind="body",
+ schema=request_schema_name,
+ )
+ )
+
+ # Add the operation to the parent path.
+ if not internal or (internal and include_internal):
+ path_swagger[method_name.lower()] = operation_swagger
+
+ tags.sort(key=lambda t: t["name"])
+ paths = OrderedDict(sorted(paths.items(), key=lambda p: p[1]["x-tag"]))
+
+ if compact:
+ return {"paths": paths}
+
+ swagger_data = {
+ "swagger": "2.0",
+ "host": SERVER_HOSTNAME,
+ "basePath": "/",
+ "schemes": [PREFERRED_URL_SCHEME],
+ "info": {
+ "version": "v1",
+ "title": "Quay Frontend",
+ "description": (
+ "This API allows you to perform many of the operations required to work "
+ "with Quay repositories, users, and organizations. You can find out more "
+                'at <a href="https://quay.io">Quay</a>.'
+ ),
+ "termsOfService": "https://quay.io/tos",
+ "contact": {"email": "support@quay.io"},
+ },
+ "securityDefinitions": {
+ "oauth2_implicit": {
+ "type": "oauth2",
+ "flow": "implicit",
+ "authorizationUrl": "%s://%s/oauth/authorize"
+ % (PREFERRED_URL_SCHEME, SERVER_HOSTNAME),
+ "scopes": {
+ scope.scope: scope.description
+ for scope in scopes.app_scopes(app.config).values()
+ },
}
+ },
+ "paths": paths,
+ "definitions": models,
+ "tags": tags,
+ }
- operation_swagger['responses'] = responses
+ return swagger_data
- # Add the request block.
- request_schema_name = method_metadata(method, 'request_schema')
- if request_schema_name and not compact:
- models[request_schema_name] = view_class.schemas[request_schema_name]
-
- operation_swagger['parameters'].append(
- swagger_parameter('body', 'Request body contents.', kind='body',
- schema=request_schema_name))
-
- # Add the operation to the parent path.
- if not internal or (internal and include_internal):
- path_swagger[method_name.lower()] = operation_swagger
-
- tags.sort(key=lambda t: t['name'])
- paths = OrderedDict(sorted(paths.items(), key=lambda p: p[1]['x-tag']))
-
- if compact:
- return {'paths': paths}
-
- swagger_data = {
- 'swagger': '2.0',
- 'host': SERVER_HOSTNAME,
- 'basePath': '/',
- 'schemes': [
- PREFERRED_URL_SCHEME
- ],
- 'info': {
- 'version': 'v1',
- 'title': 'Quay Frontend',
- 'description': ('This API allows you to perform many of the operations required to work '
- 'with Quay repositories, users, and organizations. You can find out more '
-                    'at <a href="https://quay.io">Quay</a>.'),
- 'termsOfService': 'https://quay.io/tos',
- 'contact': {
- 'email': 'support@quay.io'
- }
- },
- 'securityDefinitions': {
- 'oauth2_implicit': {
- "type": "oauth2",
- "flow": "implicit",
- "authorizationUrl": "%s://%s/oauth/authorize" % (PREFERRED_URL_SCHEME, SERVER_HOSTNAME),
- 'scopes': {scope.scope:scope.description
- for scope in scopes.app_scopes(app.config).values()},
- },
- },
- 'paths': paths,
- 'definitions': models,
- 'tags': tags
- }
-
- return swagger_data
-
-
-@resource('/v1/discovery')
+@resource("/v1/discovery")
class DiscoveryResource(ApiResource):
- """Ability to inspect the API for usage information and documentation."""
- @parse_args()
- @query_param('internal', 'Whether to include internal APIs.', type=truthy_bool, default=False)
- @nickname('discovery')
- @anon_allowed
- def get(self, parsed_args):
- """ List all of the API endpoints available in the swagger API format."""
- return swagger_route_data(parsed_args['internal'])
+ """Ability to inspect the API for usage information and documentation."""
+
+ @parse_args()
+ @query_param(
+ "internal", "Whether to include internal APIs.", type=truthy_bool, default=False
+ )
+ @nickname("discovery")
+ @anon_allowed
+ def get(self, parsed_args):
+ """ List all of the API endpoints available in the swagger API format."""
+ return swagger_route_data(parsed_args["internal"])
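A quick, runnable illustration (separate from the diff; the sample route string is hypothetical) of how PARAM_REGEX above rewrites Flask-style placeholders into Swagger path templates:

import re

PARAM_REGEX = re.compile(r"<([^:>]+:)*([\w]+)>")

swagger_path = PARAM_REGEX.sub(
    r"{\2}", "/v1/repository/<apirepopath:repository>/build/<build_uuid>"
)
# swagger_path == "/v1/repository/{repository}/build/{build_uuid}"
print(swagger_path)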
diff --git a/endpoints/api/error.py b/endpoints/api/error.py
index bfa80efe2..210f728ed 100644
--- a/endpoints/api/error.py
+++ b/endpoints/api/error.py
@@ -1,61 +1,63 @@
""" Error details API """
from flask import url_for
-from endpoints.api import (resource, nickname, ApiResource, path_param,
- define_json_response)
+from endpoints.api import (
+ resource,
+ nickname,
+ ApiResource,
+ path_param,
+ define_json_response,
+)
from endpoints.exception import NotFound, ApiErrorType, ERROR_DESCRIPTION
+
def error_view(error_type):
- return {
- 'type': url_for('api.error', error_type=error_type, _external=True),
- 'title': error_type,
- 'description': ERROR_DESCRIPTION[error_type]
- }
+ return {
+ "type": url_for("api.error", error_type=error_type, _external=True),
+ "title": error_type,
+ "description": ERROR_DESCRIPTION[error_type],
+ }
-@resource('/v1/error/<error_type>')
-@path_param('error_type', 'The error code identifying the type of error.')
+@resource("/v1/error/<error_type>")
+@path_param("error_type", "The error code identifying the type of error.")
class Error(ApiResource):
- """ Resource for Error Descriptions"""
- schemas = {
- 'ApiErrorDescription': {
- 'type': 'object',
- 'description': 'Description of an error',
- 'required': [
- 'type',
- 'description',
- 'title',
- ],
- 'properties': {
- 'type': {
- 'type': 'string',
- 'description': 'A reference to the error type resource'
- },
- 'title': {
- 'type': 'string',
- 'description': (
- 'The title of the error. Can be used to uniquely identify the kind'
- ' of error.'
- ),
- 'enum': list(ApiErrorType.__members__)
- },
- 'description': {
- 'type': 'string',
- 'description': (
- 'A more detailed description of the error that may include help for'
- ' fixing the issue.'
- )
+ """ Resource for Error Descriptions"""
+
+ schemas = {
+ "ApiErrorDescription": {
+ "type": "object",
+ "description": "Description of an error",
+ "required": ["type", "description", "title"],
+ "properties": {
+ "type": {
+ "type": "string",
+ "description": "A reference to the error type resource",
+ },
+ "title": {
+ "type": "string",
+ "description": (
+ "The title of the error. Can be used to uniquely identify the kind"
+ " of error."
+ ),
+ "enum": list(ApiErrorType.__members__),
+ },
+ "description": {
+ "type": "string",
+ "description": (
+ "A more detailed description of the error that may include help for"
+ " fixing the issue."
+ ),
+ },
+ },
}
- },
- },
- }
+ }
- @define_json_response('ApiErrorDescription')
- @nickname('getErrorDescription')
- def get(self, error_type):
- """ Get a detailed description of the error """
- if error_type in ERROR_DESCRIPTION.keys():
- return error_view(error_type)
-
- raise NotFound()
+ @define_json_response("ApiErrorDescription")
+ @nickname("getErrorDescription")
+ def get(self, error_type):
+ """ Get a detailed description of the error """
+ if error_type in ERROR_DESCRIPTION.keys():
+ return error_view(error_type)
+ raise NotFound()
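For orientation, the JSON that error_view above returns has roughly this shape (the hostname, error type, and description text are placeholders, not values taken from the code):

example_error_description = {
    "type": "https://quay.example.test/api/v1/error/not_found",  # built by url_for("api.error", ...)
    "title": "not_found",
    "description": "The resource was not found.",  # looked up in ERROR_DESCRIPTION
}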
diff --git a/endpoints/api/globalmessages.py b/endpoints/api/globalmessages.py
index 43ea58083..035886857 100644
--- a/endpoints/api/globalmessages.py
+++ b/endpoints/api/globalmessages.py
@@ -6,123 +6,127 @@ from flask import request
import features
from auth import scopes
from auth.permissions import SuperUserPermission
-from endpoints.api import (ApiResource, resource, nickname,
- require_fresh_login, verify_not_prod, validate_json_request,
- require_scope, show_if,)
+from endpoints.api import (
+ ApiResource,
+ resource,
+ nickname,
+ require_fresh_login,
+ verify_not_prod,
+ validate_json_request,
+ require_scope,
+ show_if,
+)
from globalmessages_models_pre_oci import pre_oci_model as model
-@resource('/v1/messages')
+@resource("/v1/messages")
class GlobalUserMessages(ApiResource):
- """ Resource for getting a list of super user messages """
- schemas = {
- 'GetMessage': {
- 'id': 'GetMessage',
- 'type': 'object',
- 'description': 'Messages that a super user has saved in the past',
- 'properties': {
- 'message': {
- 'type': 'array',
- 'description': 'A list of messages',
- 'itemType': {
- 'type': 'object',
- 'properties': {
- 'uuid': {
- 'type': 'string',
- 'description': 'The message id',
- },
- 'content': {
- 'type': 'string',
- 'description': 'The actual message',
- },
- 'media_type': {
- 'type': 'string',
- 'description': 'The media type of the message',
- 'enum': ['text/plain', 'text/markdown'],
- },
- 'severity': {
- 'type': 'string',
- 'description': 'The severity of the message',
- 'enum': ['info', 'warning', 'error'],
- },
- },
- },
- },
- },
- },
- 'CreateMessage': {
- 'id': 'CreateMessage',
- 'type': 'object',
- 'description': 'Create a new message',
- 'properties': {
- 'message': {
- 'type': 'object',
- 'description': 'A single message',
- 'required': [
- 'content',
- 'media_type',
- 'severity',
- ],
- 'properties': {
- 'content': {
- 'type': 'string',
- 'description': 'The actual message',
- },
- 'media_type': {
- 'type': 'string',
- 'description': 'The media type of the message',
- 'enum': ['text/plain', 'text/markdown'],
- },
- 'severity': {
- 'type': 'string',
- 'description': 'The severity of the message',
- 'enum': ['info', 'warning', 'error'],
- },
- },
- },
- },
- }
- }
+ """ Resource for getting a list of super user messages """
- @nickname('getGlobalMessages')
- def get(self):
-    """ Return a super user's messages """
- return {
- 'messages': [m.to_dict() for m in model.get_all_messages()],
+ schemas = {
+ "GetMessage": {
+ "id": "GetMessage",
+ "type": "object",
+ "description": "Messages that a super user has saved in the past",
+ "properties": {
+ "message": {
+ "type": "array",
+ "description": "A list of messages",
+ "itemType": {
+ "type": "object",
+ "properties": {
+ "uuid": {"type": "string", "description": "The message id"},
+ "content": {
+ "type": "string",
+ "description": "The actual message",
+ },
+ "media_type": {
+ "type": "string",
+ "description": "The media type of the message",
+ "enum": ["text/plain", "text/markdown"],
+ },
+ "severity": {
+ "type": "string",
+ "description": "The severity of the message",
+ "enum": ["info", "warning", "error"],
+ },
+ },
+ },
+ }
+ },
+ },
+ "CreateMessage": {
+ "id": "CreateMessage",
+ "type": "object",
+ "description": "Create a new message",
+ "properties": {
+ "message": {
+ "type": "object",
+ "description": "A single message",
+ "required": ["content", "media_type", "severity"],
+ "properties": {
+ "content": {
+ "type": "string",
+ "description": "The actual message",
+ },
+ "media_type": {
+ "type": "string",
+ "description": "The media type of the message",
+ "enum": ["text/plain", "text/markdown"],
+ },
+ "severity": {
+ "type": "string",
+ "description": "The severity of the message",
+ "enum": ["info", "warning", "error"],
+ },
+ },
+ }
+ },
+ },
}
- @require_fresh_login
- @verify_not_prod
- @nickname('createGlobalMessage')
- @validate_json_request('CreateMessage')
- @require_scope(scopes.SUPERUSER)
- def post(self):
- """ Create a message """
- if not features.SUPER_USERS:
- abort(404)
+ @nickname("getGlobalMessages")
+ def get(self):
+        """ Return a super user's messages """
+ return {"messages": [m.to_dict() for m in model.get_all_messages()]}
- if SuperUserPermission().can():
- message_req = request.get_json()['message']
- message = model.create_message(message_req['severity'], message_req['media_type'], message_req['content'])
- if message is None:
- abort(400)
- return make_response('', 201)
+ @require_fresh_login
+ @verify_not_prod
+ @nickname("createGlobalMessage")
+ @validate_json_request("CreateMessage")
+ @require_scope(scopes.SUPERUSER)
+ def post(self):
+ """ Create a message """
+ if not features.SUPER_USERS:
+ abort(404)
- abort(403)
+ if SuperUserPermission().can():
+ message_req = request.get_json()["message"]
+ message = model.create_message(
+ message_req["severity"],
+ message_req["media_type"],
+ message_req["content"],
+ )
+ if message is None:
+ abort(400)
+ return make_response("", 201)
+
+ abort(403)
-@resource('/v1/message/<uuid>')
+@resource("/v1/message/<uuid>")
@show_if(features.SUPER_USERS)
class GlobalUserMessage(ApiResource):
- """ Resource for managing individual messages """
- @require_fresh_login
- @verify_not_prod
- @nickname('deleteGlobalMessage')
- @require_scope(scopes.SUPERUSER)
- def delete(self, uuid):
- """ Delete a message """
- if SuperUserPermission().can():
- model.delete_message(uuid)
- return make_response('', 204)
+ """ Resource for managing individual messages """
- abort(403)
+ @require_fresh_login
+ @verify_not_prod
+ @nickname("deleteGlobalMessage")
+ @require_scope(scopes.SUPERUSER)
+ def delete(self, uuid):
+ """ Delete a message """
+ if SuperUserPermission().can():
+ model.delete_message(uuid)
+ return make_response("", 204)
+
+ abort(403)
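As a concrete reference for the CreateMessage schema above, a request body along these lines (content text invented for the example) would validate against it:

create_message_payload = {
    "message": {
        "content": "Scheduled maintenance at 02:00 UTC",
        "media_type": "text/plain",
        "severity": "warning",
    }
}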
diff --git a/endpoints/api/globalmessages_models_interface.py b/endpoints/api/globalmessages_models_interface.py
index 679462c1d..53e2da412 100644
--- a/endpoints/api/globalmessages_models_interface.py
+++ b/endpoints/api/globalmessages_models_interface.py
@@ -3,52 +3,45 @@ from collections import namedtuple
from six import add_metaclass
-class GlobalMessage(
- namedtuple('GlobalMessage', [
- 'uuid',
- 'content',
- 'severity',
- 'media_type_name',
- ])):
-
- def to_dict(self):
- return {
- 'uuid': self.uuid,
- 'content': self.content,
- 'severity': self.severity,
- 'media_type': self.media_type_name,
- }
+class GlobalMessage(
+ namedtuple("GlobalMessage", ["uuid", "content", "severity", "media_type_name"])
+):
+ def to_dict(self):
+ return {
+ "uuid": self.uuid,
+ "content": self.content,
+ "severity": self.severity,
+ "media_type": self.media_type_name,
+ }
@add_metaclass(ABCMeta)
class GlobalMessageDataInterface(object):
- """
+ """
Data interface for globalmessages API
"""
-
- @abstractmethod
- def get_all_messages(self):
- """
+
+ @abstractmethod
+ def get_all_messages(self):
+ """
Returns:
list(GlobalMessage)
"""
-
- @abstractmethod
- def create_message(self, severity, media_type_name, content):
- """
+
+ @abstractmethod
+ def create_message(self, severity, media_type_name, content):
+ """
Returns:
GlobalMessage or None
"""
-
- @abstractmethod
- def delete_message(self, uuid):
- """
+
+ @abstractmethod
+ def delete_message(self, uuid):
+ """
Returns:
void
"""
-
-
\ No newline at end of file
diff --git a/endpoints/api/globalmessages_models_pre_oci.py b/endpoints/api/globalmessages_models_pre_oci.py
index d9a623f1b..2b0f35444 100644
--- a/endpoints/api/globalmessages_models_pre_oci.py
+++ b/endpoints/api/globalmessages_models_pre_oci.py
@@ -3,31 +3,31 @@ from data import model
class GlobalMessagePreOCI(GlobalMessageDataInterface):
-
- def get_all_messages(self):
- messages = model.message.get_messages()
- return [self._message(m) for m in messages]
-
- def create_message(self, severity, media_type_name, content):
- message = {
- 'severity': severity,
- 'media_type': media_type_name,
- 'content': content
- }
- messages = model.message.create([message])
- return self._message(messages[0])
-
- def delete_message(self, uuid):
- model.message.delete_message([uuid])
-
- def _message(self, message_obj):
- if message_obj is None:
- return None
- return GlobalMessage(
- uuid=message_obj.uuid,
- content=message_obj.content,
- severity=message_obj.severity,
- media_type_name=message_obj.media_type.name,
- )
-
-pre_oci_model = GlobalMessagePreOCI()
\ No newline at end of file
+ def get_all_messages(self):
+ messages = model.message.get_messages()
+ return [self._message(m) for m in messages]
+
+ def create_message(self, severity, media_type_name, content):
+ message = {
+ "severity": severity,
+ "media_type": media_type_name,
+ "content": content,
+ }
+ messages = model.message.create([message])
+ return self._message(messages[0])
+
+ def delete_message(self, uuid):
+ model.message.delete_message([uuid])
+
+ def _message(self, message_obj):
+ if message_obj is None:
+ return None
+ return GlobalMessage(
+ uuid=message_obj.uuid,
+ content=message_obj.content,
+ severity=message_obj.severity,
+ media_type_name=message_obj.media_type.name,
+ )
+
+
+pre_oci_model = GlobalMessagePreOCI()
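Illustrative usage of the pre-OCI model above (a sketch only; it assumes a configured database connection and is not part of the commit):

msg = pre_oci_model.create_message("info", "text/plain", "Welcome to the registry")
print([m.to_dict() for m in pre_oci_model.get_all_messages()])
pre_oci_model.delete_message(msg.uuid)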
diff --git a/endpoints/api/image.py b/endpoints/api/image.py
index 3a9dcd82c..bf24d91e2 100644
--- a/endpoints/api/image.py
+++ b/endpoints/api/image.py
@@ -2,76 +2,85 @@
import json
from data.registry_model import registry_model
-from endpoints.api import (resource, nickname, require_repo_read, RepositoryParamResource,
- path_param, disallow_for_app_repositories, format_date)
+from endpoints.api import (
+ resource,
+ nickname,
+ require_repo_read,
+ RepositoryParamResource,
+ path_param,
+ disallow_for_app_repositories,
+ format_date,
+)
from endpoints.exception import NotFound
def image_dict(image, with_history=False, with_tags=False):
- parsed_command = None
- if image.command:
- try:
- parsed_command = json.loads(image.command)
- except (ValueError, TypeError):
- parsed_command = {'error': 'Could not parse command'}
+ parsed_command = None
+ if image.command:
+ try:
+ parsed_command = json.loads(image.command)
+ except (ValueError, TypeError):
+ parsed_command = {"error": "Could not parse command"}
- image_data = {
- 'id': image.docker_image_id,
- 'created': format_date(image.created),
- 'comment': image.comment,
- 'command': parsed_command,
- 'size': image.image_size,
- 'uploading': image.uploading,
- 'sort_index': len(image.parents),
- }
+ image_data = {
+ "id": image.docker_image_id,
+ "created": format_date(image.created),
+ "comment": image.comment,
+ "command": parsed_command,
+ "size": image.image_size,
+ "uploading": image.uploading,
+ "sort_index": len(image.parents),
+ }
- if with_tags:
- image_data['tags'] = [tag.name for tag in image.tags]
+ if with_tags:
+ image_data["tags"] = [tag.name for tag in image.tags]
- if with_history:
- image_data['history'] = [image_dict(parent) for parent in image.parents]
+ if with_history:
+ image_data["history"] = [image_dict(parent) for parent in image.parents]
- # Calculate the ancestors string, with the DBID's replaced with the docker IDs.
- parent_docker_ids = [parent_image.docker_image_id for parent_image in image.parents]
- image_data['ancestors'] = '/{0}/'.format('/'.join(parent_docker_ids))
- return image_data
+ # Calculate the ancestors string, with the DBID's replaced with the docker IDs.
+ parent_docker_ids = [parent_image.docker_image_id for parent_image in image.parents]
+ image_data["ancestors"] = "/{0}/".format("/".join(parent_docker_ids))
+ return image_data
-@resource('/v1/repository/<apirepopath:repository>/image/')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
+@resource("/v1/repository/<apirepopath:repository>/image/")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
class RepositoryImageList(RepositoryParamResource):
- """ Resource for listing repository images. """
+ """ Resource for listing repository images. """
- @require_repo_read
- @nickname('listRepositoryImages')
- @disallow_for_app_repositories
- def get(self, namespace, repository):
- """ List the images for the specified repository. """
- repo_ref = registry_model.lookup_repository(namespace, repository)
- if repo_ref is None:
- raise NotFound()
+ @require_repo_read
+ @nickname("listRepositoryImages")
+ @disallow_for_app_repositories
+ def get(self, namespace, repository):
+ """ List the images for the specified repository. """
+ repo_ref = registry_model.lookup_repository(namespace, repository)
+ if repo_ref is None:
+ raise NotFound()
- images = registry_model.get_legacy_images(repo_ref)
- return {'images': [image_dict(image, with_tags=True) for image in images]}
+ images = registry_model.get_legacy_images(repo_ref)
+ return {"images": [image_dict(image, with_tags=True) for image in images]}
-@resource('/v1/repository/<apirepopath:repository>/image/<image_id>')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
-@path_param('image_id', 'The Docker image ID')
+@resource("/v1/repository/<apirepopath:repository>/image/<image_id>")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
+@path_param("image_id", "The Docker image ID")
class RepositoryImage(RepositoryParamResource):
- """ Resource for handling repository images. """
+ """ Resource for handling repository images. """
- @require_repo_read
- @nickname('getImage')
- @disallow_for_app_repositories
- def get(self, namespace, repository, image_id):
- """ Get the information available for the specified image. """
- repo_ref = registry_model.lookup_repository(namespace, repository)
- if repo_ref is None:
- raise NotFound()
+ @require_repo_read
+ @nickname("getImage")
+ @disallow_for_app_repositories
+ def get(self, namespace, repository, image_id):
+ """ Get the information available for the specified image. """
+ repo_ref = registry_model.lookup_repository(namespace, repository)
+ if repo_ref is None:
+ raise NotFound()
- image = registry_model.get_legacy_image(repo_ref, image_id, include_parents=True)
- if image is None:
- raise NotFound()
+ image = registry_model.get_legacy_image(
+ repo_ref, image_id, include_parents=True
+ )
+ if image is None:
+ raise NotFound()
- return image_dict(image, with_history=True)
+ return image_dict(image, with_history=True)
diff --git a/endpoints/api/logs.py b/endpoints/api/logs.py
index 1760a2e9b..e4e27b94e 100644
--- a/endpoints/api/logs.py
+++ b/endpoints/api/logs.py
@@ -11,334 +11,440 @@ from auth.auth_context import get_authenticated_user
from auth import scopes
from data.logs_model import logs_model
from data.registry_model import registry_model
-from endpoints.api import (resource, nickname, ApiResource, query_param, parse_args,
- RepositoryParamResource, require_repo_admin, related_user_resource,
- format_date, require_user_admin, path_param, require_scope, page_support,
- validate_json_request, InvalidRequest, show_if)
+from endpoints.api import (
+ resource,
+ nickname,
+ ApiResource,
+ query_param,
+ parse_args,
+ RepositoryParamResource,
+ require_repo_admin,
+ related_user_resource,
+ format_date,
+ require_user_admin,
+ path_param,
+ require_scope,
+ page_support,
+ validate_json_request,
+ InvalidRequest,
+ show_if,
+)
from endpoints.exception import Unauthorized, NotFound
LOGS_PER_PAGE = 20
-SERVICE_LEVEL_LOG_KINDS = set(['service_key_create', 'service_key_approve', 'service_key_delete',
- 'service_key_modify', 'service_key_extend', 'service_key_rotate'])
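+# Service-key event kinds; used as the kind filter for user-level log queries.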
+SERVICE_LEVEL_LOG_KINDS = set(
+ [
+ "service_key_create",
+ "service_key_approve",
+ "service_key_delete",
+ "service_key_modify",
+ "service_key_extend",
+ "service_key_rotate",
+ ]
+)
def _parse_datetime(dt_string):
- if not dt_string:
- return None
+ if not dt_string:
+ return None
- try:
- return datetime.strptime(dt_string + ' UTC', '%m/%d/%Y %Z')
- except ValueError:
- return None
+ try:
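+ # Dates arrive as "%m/%d/%Y"; appending " UTC" satisfies the trailing %Z directive.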
+ return datetime.strptime(dt_string + " UTC", "%m/%d/%Y %Z")
+ except ValueError:
+ return None
def _validate_logs_arguments(start_time, end_time):
- start_time = _parse_datetime(start_time) or (datetime.today() - timedelta(days=1))
- end_time = _parse_datetime(end_time) or datetime.today()
- end_time = end_time + timedelta(days=1)
- return start_time, end_time
+ start_time = _parse_datetime(start_time) or (datetime.today() - timedelta(days=1))
+ end_time = _parse_datetime(end_time) or datetime.today()
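+ # Extend the end time by one day so the requested end date is included in the range.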
+ end_time = end_time + timedelta(days=1)
+ return start_time, end_time
-def _get_logs(start_time, end_time, performer_name=None, repository_name=None, namespace_name=None,
- page_token=None, filter_kinds=None):
- (start_time, end_time) = _validate_logs_arguments(start_time, end_time)
- log_entry_page = logs_model.lookup_logs(start_time, end_time, performer_name, repository_name,
- namespace_name, filter_kinds, page_token,
- app.config['ACTION_LOG_MAX_PAGE'])
- include_namespace = namespace_name is None and repository_name is None
- return {
- 'start_time': format_date(start_time),
- 'end_time': format_date(end_time),
- 'logs': [log.to_dict(avatar, include_namespace) for log in log_entry_page.logs],
- }, log_entry_page.next_page_token
+def _get_logs(
+ start_time,
+ end_time,
+ performer_name=None,
+ repository_name=None,
+ namespace_name=None,
+ page_token=None,
+ filter_kinds=None,
+):
+ (start_time, end_time) = _validate_logs_arguments(start_time, end_time)
+ log_entry_page = logs_model.lookup_logs(
+ start_time,
+ end_time,
+ performer_name,
+ repository_name,
+ namespace_name,
+ filter_kinds,
+ page_token,
+ app.config["ACTION_LOG_MAX_PAGE"],
+ )
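+ # Only include the namespace in each log entry when the query is not already scoped to one.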
+ include_namespace = namespace_name is None and repository_name is None
+ return (
+ {
+ "start_time": format_date(start_time),
+ "end_time": format_date(end_time),
+ "logs": [
+ log.to_dict(avatar, include_namespace) for log in log_entry_page.logs
+ ],
+ },
+ log_entry_page.next_page_token,
+ )
-def _get_aggregate_logs(start_time, end_time, performer_name=None, repository=None, namespace=None,
- filter_kinds=None):
- (start_time, end_time) = _validate_logs_arguments(start_time, end_time)
- aggregated_logs = logs_model.get_aggregated_log_counts(start_time, end_time,
- performer_name=performer_name,
- repository_name=repository,
- namespace_name=namespace,
- filter_kinds=filter_kinds)
+def _get_aggregate_logs(
+ start_time,
+ end_time,
+ performer_name=None,
+ repository=None,
+ namespace=None,
+ filter_kinds=None,
+):
+ (start_time, end_time) = _validate_logs_arguments(start_time, end_time)
+ aggregated_logs = logs_model.get_aggregated_log_counts(
+ start_time,
+ end_time,
+ performer_name=performer_name,
+ repository_name=repository,
+ namespace_name=namespace,
+ filter_kinds=filter_kinds,
+ )
- return {
- 'aggregated': [log.to_dict() for log in aggregated_logs]
- }
+ return {"aggregated": [log.to_dict() for log in aggregated_logs]}
-@resource('/v1/repository/<apirepopath:repository>/logs')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
+@resource("/v1/repository//logs")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
class RepositoryLogs(RepositoryParamResource):
- """ Resource for fetching logs for the specific repository. """
+ """ Resource for fetching logs for the specific repository. """
- @require_repo_admin
- @nickname('listRepoLogs')
- @parse_args()
- @query_param('starttime', 'Earliest time for logs. Format: "%m/%d/%Y" in UTC.', type=str)
- @query_param('endtime', 'Latest time for logs. Format: "%m/%d/%Y" in UTC.', type=str)
- @page_support()
- def get(self, namespace, repository, page_token, parsed_args):
- """ List the logs for the specified repository. """
- if registry_model.lookup_repository(namespace, repository) is None:
- raise NotFound()
+ @require_repo_admin
+ @nickname("listRepoLogs")
+ @parse_args()
+ @query_param(
+ "starttime", 'Earliest time for logs. Format: "%m/%d/%Y" in UTC.', type=str
+ )
+ @query_param(
+ "endtime", 'Latest time for logs. Format: "%m/%d/%Y" in UTC.', type=str
+ )
+ @page_support()
+ def get(self, namespace, repository, page_token, parsed_args):
+ """ List the logs for the specified repository. """
+ if registry_model.lookup_repository(namespace, repository) is None:
+ raise NotFound()
- start_time = parsed_args['starttime']
- end_time = parsed_args['endtime']
- return _get_logs(start_time, end_time,
- repository_name=repository,
- page_token=page_token,
- namespace_name=namespace)
+ start_time = parsed_args["starttime"]
+ end_time = parsed_args["endtime"]
+ return _get_logs(
+ start_time,
+ end_time,
+ repository_name=repository,
+ page_token=page_token,
+ namespace_name=namespace,
+ )
-@resource('/v1/user/logs')
+@resource("/v1/user/logs")
class UserLogs(ApiResource):
- """ Resource for fetching logs for the current user. """
+ """ Resource for fetching logs for the current user. """
- @require_user_admin
- @nickname('listUserLogs')
- @parse_args()
- @query_param('starttime', 'Earliest time for logs. Format: "%m/%d/%Y" in UTC.', type=str)
- @query_param('endtime', 'Latest time for logs. Format: "%m/%d/%Y" in UTC.', type=str)
- @query_param('performer', 'Username for which to filter logs.', type=str)
- @page_support()
- def get(self, parsed_args, page_token):
- """ List the logs for the current user. """
- performer_name = parsed_args['performer']
- start_time = parsed_args['starttime']
- end_time = parsed_args['endtime']
+ @require_user_admin
+ @nickname("listUserLogs")
+ @parse_args()
+ @query_param(
+ "starttime", 'Earliest time for logs. Format: "%m/%d/%Y" in UTC.', type=str
+ )
+ @query_param(
+ "endtime", 'Latest time for logs. Format: "%m/%d/%Y" in UTC.', type=str
+ )
+ @query_param("performer", "Username for which to filter logs.", type=str)
+ @page_support()
+ def get(self, parsed_args, page_token):
+ """ List the logs for the current user. """
+ performer_name = parsed_args["performer"]
+ start_time = parsed_args["starttime"]
+ end_time = parsed_args["endtime"]
- user = get_authenticated_user()
- return _get_logs(start_time, end_time,
- performer_name=performer_name,
- namespace_name=user.username,
- page_token=page_token,
- filter_kinds=SERVICE_LEVEL_LOG_KINDS)
+ user = get_authenticated_user()
+ return _get_logs(
+ start_time,
+ end_time,
+ performer_name=performer_name,
+ namespace_name=user.username,
+ page_token=page_token,
+ filter_kinds=SERVICE_LEVEL_LOG_KINDS,
+ )
-@resource('/v1/organization/<orgname>/logs')
-@path_param('orgname', 'The name of the organization')
+@resource("/v1/organization//logs")
+@path_param("orgname", "The name of the organization")
@related_user_resource(UserLogs)
class OrgLogs(ApiResource):
- """ Resource for fetching logs for the entire organization. """
+ """ Resource for fetching logs for the entire organization. """
- @nickname('listOrgLogs')
- @parse_args()
- @query_param('starttime', 'Earliest time for logs. Format: "%m/%d/%Y" in UTC.', type=str)
- @query_param('endtime', 'Latest time for logs. Format: "%m/%d/%Y" in UTC.', type=str)
- @query_param('performer', 'Username for which to filter logs.', type=str)
- @page_support()
- @require_scope(scopes.ORG_ADMIN)
- def get(self, orgname, page_token, parsed_args):
- """ List the logs for the specified organization. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- performer_name = parsed_args['performer']
- start_time = parsed_args['starttime']
- end_time = parsed_args['endtime']
+ @nickname("listOrgLogs")
+ @parse_args()
+ @query_param(
+ "starttime", 'Earliest time for logs. Format: "%m/%d/%Y" in UTC.', type=str
+ )
+ @query_param(
+ "endtime", 'Latest time for logs. Format: "%m/%d/%Y" in UTC.', type=str
+ )
+ @query_param("performer", "Username for which to filter logs.", type=str)
+ @page_support()
+ @require_scope(scopes.ORG_ADMIN)
+ def get(self, orgname, page_token, parsed_args):
+ """ List the logs for the specified organization. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ performer_name = parsed_args["performer"]
+ start_time = parsed_args["starttime"]
+ end_time = parsed_args["endtime"]
- return _get_logs(start_time, end_time,
- namespace_name=orgname,
- performer_name=performer_name,
- page_token=page_token)
+ return _get_logs(
+ start_time,
+ end_time,
+ namespace_name=orgname,
+ performer_name=performer_name,
+ page_token=page_token,
+ )
- raise Unauthorized()
+ raise Unauthorized()
-@resource('/v1/repository/<apirepopath:repository>/aggregatelogs')
+@resource("/v1/repository/<apirepopath:repository>/aggregatelogs")
@show_if(features.AGGREGATED_LOG_COUNT_RETRIEVAL)
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
class RepositoryAggregateLogs(RepositoryParamResource):
- """ Resource for fetching aggregated logs for the specific repository. """
+ """ Resource for fetching aggregated logs for the specific repository. """
- @require_repo_admin
- @nickname('getAggregateRepoLogs')
- @parse_args()
- @query_param('starttime', 'Earliest time for logs. Format: "%m/%d/%Y" in UTC.', type=str)
- @query_param('endtime', 'Latest time for logs. Format: "%m/%d/%Y" in UTC.', type=str)
- def get(self, namespace, repository, parsed_args):
- """ Returns the aggregated logs for the specified repository. """
- if registry_model.lookup_repository(namespace, repository) is None:
- raise NotFound()
+ @require_repo_admin
+ @nickname("getAggregateRepoLogs")
+ @parse_args()
+ @query_param(
+ "starttime", 'Earliest time for logs. Format: "%m/%d/%Y" in UTC.', type=str
+ )
+ @query_param(
+ "endtime", 'Latest time for logs. Format: "%m/%d/%Y" in UTC.', type=str
+ )
+ def get(self, namespace, repository, parsed_args):
+ """ Returns the aggregated logs for the specified repository. """
+ if registry_model.lookup_repository(namespace, repository) is None:
+ raise NotFound()
- start_time = parsed_args['starttime']
- end_time = parsed_args['endtime']
- return _get_aggregate_logs(start_time, end_time,
- repository=repository,
- namespace=namespace)
+ start_time = parsed_args["starttime"]
+ end_time = parsed_args["endtime"]
+ return _get_aggregate_logs(
+ start_time, end_time, repository=repository, namespace=namespace
+ )
-@resource('/v1/user/aggregatelogs')
+@resource("/v1/user/aggregatelogs")
@show_if(features.AGGREGATED_LOG_COUNT_RETRIEVAL)
class UserAggregateLogs(ApiResource):
- """ Resource for fetching aggregated logs for the current user. """
+ """ Resource for fetching aggregated logs for the current user. """
- @require_user_admin
- @nickname('getAggregateUserLogs')
- @parse_args()
- @query_param('starttime', 'Earliest time for logs. Format: "%m/%d/%Y" in UTC.', type=str)
- @query_param('endtime', 'Latest time for logs. Format: "%m/%d/%Y" in UTC.', type=str)
- @query_param('performer', 'Username for which to filter logs.', type=str)
- def get(self, parsed_args):
- """ Returns the aggregated logs for the current user. """
- performer_name = parsed_args['performer']
- start_time = parsed_args['starttime']
- end_time = parsed_args['endtime']
+ @require_user_admin
+ @nickname("getAggregateUserLogs")
+ @parse_args()
+ @query_param(
+ "starttime", 'Earliest time for logs. Format: "%m/%d/%Y" in UTC.', type=str
+ )
+ @query_param(
+ "endtime", 'Latest time for logs. Format: "%m/%d/%Y" in UTC.', type=str
+ )
+ @query_param("performer", "Username for which to filter logs.", type=str)
+ def get(self, parsed_args):
+ """ Returns the aggregated logs for the current user. """
+ performer_name = parsed_args["performer"]
+ start_time = parsed_args["starttime"]
+ end_time = parsed_args["endtime"]
- user = get_authenticated_user()
- return _get_aggregate_logs(start_time, end_time,
- performer_name=performer_name,
- namespace=user.username,
- filter_kinds=SERVICE_LEVEL_LOG_KINDS)
+ user = get_authenticated_user()
+ return _get_aggregate_logs(
+ start_time,
+ end_time,
+ performer_name=performer_name,
+ namespace=user.username,
+ filter_kinds=SERVICE_LEVEL_LOG_KINDS,
+ )
-@resource('/v1/organization/<orgname>/aggregatelogs')
+@resource("/v1/organization/<orgname>/aggregatelogs")
@show_if(features.AGGREGATED_LOG_COUNT_RETRIEVAL)
-@path_param('orgname', 'The name of the organization')
+@path_param("orgname", "The name of the organization")
@related_user_resource(UserLogs)
class OrgAggregateLogs(ApiResource):
- """ Resource for fetching aggregate logs for the entire organization. """
+ """ Resource for fetching aggregate logs for the entire organization. """
- @nickname('getAggregateOrgLogs')
- @parse_args()
- @query_param('starttime', 'Earliest time for logs. Format: "%m/%d/%Y" in UTC.', type=str)
- @query_param('endtime', 'Latest time for logs. Format: "%m/%d/%Y" in UTC.', type=str)
- @query_param('performer', 'Username for which to filter logs.', type=str)
- @require_scope(scopes.ORG_ADMIN)
- def get(self, orgname, parsed_args):
- """ Gets the aggregated logs for the specified organization. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- performer_name = parsed_args['performer']
- start_time = parsed_args['starttime']
- end_time = parsed_args['endtime']
+ @nickname("getAggregateOrgLogs")
+ @parse_args()
+ @query_param(
+ "starttime", 'Earliest time for logs. Format: "%m/%d/%Y" in UTC.', type=str
+ )
+ @query_param(
+ "endtime", 'Latest time for logs. Format: "%m/%d/%Y" in UTC.', type=str
+ )
+ @query_param("performer", "Username for which to filter logs.", type=str)
+ @require_scope(scopes.ORG_ADMIN)
+ def get(self, orgname, parsed_args):
+ """ Gets the aggregated logs for the specified organization. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ performer_name = parsed_args["performer"]
+ start_time = parsed_args["starttime"]
+ end_time = parsed_args["endtime"]
- return _get_aggregate_logs(start_time, end_time,
- namespace=orgname,
- performer_name=performer_name)
+ return _get_aggregate_logs(
+ start_time, end_time, namespace=orgname, performer_name=performer_name
+ )
- raise Unauthorized()
+ raise Unauthorized()
EXPORT_LOGS_SCHEMA = {
- 'type': 'object',
- 'description': 'Configuration for an export logs operation',
- 'properties': {
- 'callback_url': {
- 'type': 'string',
- 'description': 'The callback URL to invoke with a link to the exported logs',
+ "type": "object",
+ "description": "Configuration for an export logs operation",
+ "properties": {
+ "callback_url": {
+ "type": "string",
+ "description": "The callback URL to invoke with a link to the exported logs",
+ },
+ "callback_email": {
+ "type": "string",
+ "description": "The e-mail address at which to e-mail a link to the exported logs",
+ },
},
- 'callback_email': {
- 'type': 'string',
- 'description': 'The e-mail address at which to e-mail a link to the exported logs',
- },
- },
}
-def _queue_logs_export(start_time, end_time, options, namespace_name, repository_name=None):
- callback_url = options.get('callback_url')
- if callback_url:
- if not callback_url.startswith('https://') and not callback_url.startswith('http://'):
- raise InvalidRequest('Invalid callback URL')
+def _queue_logs_export(
+ start_time, end_time, options, namespace_name, repository_name=None
+):
+ callback_url = options.get("callback_url")
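+ # Only http(s) callback URLs are accepted.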
+ if callback_url:
+ if not callback_url.startswith("https://") and not callback_url.startswith(
+ "http://"
+ ):
+ raise InvalidRequest("Invalid callback URL")
- callback_email = options.get('callback_email')
- if callback_email:
- if callback_email.find('@') < 0:
- raise InvalidRequest('Invalid callback e-mail')
+ callback_email = options.get("callback_email")
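+ # Minimal sanity check: a callback e-mail must contain an "@".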
+ if callback_email:
+ if callback_email.find("@") < 0:
+ raise InvalidRequest("Invalid callback e-mail")
- (start_time, end_time) = _validate_logs_arguments(start_time, end_time)
- export_id = logs_model.queue_logs_export(start_time, end_time, export_action_logs_queue,
- namespace_name, repository_name, callback_url,
- callback_email)
- if export_id is None:
- raise InvalidRequest('Invalid export request')
+ (start_time, end_time) = _validate_logs_arguments(start_time, end_time)
+ export_id = logs_model.queue_logs_export(
+ start_time,
+ end_time,
+ export_action_logs_queue,
+ namespace_name,
+ repository_name,
+ callback_url,
+ callback_email,
+ )
+ if export_id is None:
+ raise InvalidRequest("Invalid export request")
- return export_id
+ return export_id
-@resource('/v1/repository/<apirepopath:repository>/exportlogs')
+@resource("/v1/repository/<apirepopath:repository>/exportlogs")
@show_if(features.LOG_EXPORT)
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
class ExportRepositoryLogs(RepositoryParamResource):
- """ Resource for exporting the logs for the specific repository. """
- schemas = {
- 'ExportLogs': EXPORT_LOGS_SCHEMA
- }
+ """ Resource for exporting the logs for the specific repository. """
- @require_repo_admin
- @nickname('exportRepoLogs')
- @parse_args()
- @query_param('starttime', 'Earliest time for logs. Format: "%m/%d/%Y" in UTC.', type=str)
- @query_param('endtime', 'Latest time for logs. Format: "%m/%d/%Y" in UTC.', type=str)
- @validate_json_request('ExportLogs')
- def post(self, namespace, repository, parsed_args):
- """ Queues an export of the logs for the specified repository. """
- if registry_model.lookup_repository(namespace, repository) is None:
- raise NotFound()
+ schemas = {"ExportLogs": EXPORT_LOGS_SCHEMA}
- start_time = parsed_args['starttime']
- end_time = parsed_args['endtime']
- export_id = _queue_logs_export(start_time, end_time, request.get_json(), namespace,
- repository_name=repository)
- return {
- 'export_id': export_id,
- }
+ @require_repo_admin
+ @nickname("exportRepoLogs")
+ @parse_args()
+ @query_param(
+ "starttime", 'Earliest time for logs. Format: "%m/%d/%Y" in UTC.', type=str
+ )
+ @query_param(
+ "endtime", 'Latest time for logs. Format: "%m/%d/%Y" in UTC.', type=str
+ )
+ @validate_json_request("ExportLogs")
+ def post(self, namespace, repository, parsed_args):
+ """ Queues an export of the logs for the specified repository. """
+ if registry_model.lookup_repository(namespace, repository) is None:
+ raise NotFound()
+
+ start_time = parsed_args["starttime"]
+ end_time = parsed_args["endtime"]
+ export_id = _queue_logs_export(
+ start_time,
+ end_time,
+ request.get_json(),
+ namespace,
+ repository_name=repository,
+ )
+ return {"export_id": export_id}
-@resource('/v1/user/exportlogs')
+@resource("/v1/user/exportlogs")
@show_if(features.LOG_EXPORT)
class ExportUserLogs(ApiResource):
- """ Resource for exporting the logs for the current user repository. """
- schemas = {
- 'ExportLogs': EXPORT_LOGS_SCHEMA
- }
+ """ Resource for exporting the logs for the current user repository. """
- @require_user_admin
- @nickname('exportUserLogs')
- @parse_args()
- @query_param('starttime', 'Earliest time for logs. Format: "%m/%d/%Y" in UTC.', type=str)
- @query_param('endtime', 'Latest time for logs. Format: "%m/%d/%Y" in UTC.', type=str)
- @validate_json_request('ExportLogs')
- def post(self, parsed_args):
- """ Returns the aggregated logs for the current user. """
- start_time = parsed_args['starttime']
- end_time = parsed_args['endtime']
+ schemas = {"ExportLogs": EXPORT_LOGS_SCHEMA}
- user = get_authenticated_user()
- export_id = _queue_logs_export(start_time, end_time, request.get_json(), user.username)
- return {
- 'export_id': export_id,
- }
+ @require_user_admin
+ @nickname("exportUserLogs")
+ @parse_args()
+ @query_param(
+ "starttime", 'Earliest time for logs. Format: "%m/%d/%Y" in UTC.', type=str
+ )
+ @query_param(
+ "endtime", 'Latest time for logs. Format: "%m/%d/%Y" in UTC.', type=str
+ )
+ @validate_json_request("ExportLogs")
+ def post(self, parsed_args):
+ """ Returns the aggregated logs for the current user. """
+ start_time = parsed_args["starttime"]
+ end_time = parsed_args["endtime"]
+
+ user = get_authenticated_user()
+ export_id = _queue_logs_export(
+ start_time, end_time, request.get_json(), user.username
+ )
+ return {"export_id": export_id}
-@resource('/v1/organization/<orgname>/exportlogs')
+@resource("/v1/organization/<orgname>/exportlogs")
@show_if(features.LOG_EXPORT)
-@path_param('orgname', 'The name of the organization')
+@path_param("orgname", "The name of the organization")
@related_user_resource(ExportUserLogs)
class ExportOrgLogs(ApiResource):
- """ Resource for exporting the logs for an entire organization. """
- schemas = {
- 'ExportLogs': EXPORT_LOGS_SCHEMA
- }
+ """ Resource for exporting the logs for an entire organization. """
- @nickname('exportOrgLogs')
- @parse_args()
- @query_param('starttime', 'Earliest time for logs. Format: "%m/%d/%Y" in UTC.', type=str)
- @query_param('endtime', 'Latest time for logs. Format: "%m/%d/%Y" in UTC.', type=str)
- @require_scope(scopes.ORG_ADMIN)
- @validate_json_request('ExportLogs')
- def post(self, orgname, parsed_args):
- """ Exports the logs for the specified organization. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- start_time = parsed_args['starttime']
- end_time = parsed_args['endtime']
+ schemas = {"ExportLogs": EXPORT_LOGS_SCHEMA}
- export_id = _queue_logs_export(start_time, end_time, request.get_json(), orgname)
- return {
- 'export_id': export_id,
- }
+ @nickname("exportOrgLogs")
+ @parse_args()
+ @query_param(
+ "starttime", 'Earliest time for logs. Format: "%m/%d/%Y" in UTC.', type=str
+ )
+ @query_param(
+ "endtime", 'Latest time for logs. Format: "%m/%d/%Y" in UTC.', type=str
+ )
+ @require_scope(scopes.ORG_ADMIN)
+ @validate_json_request("ExportLogs")
+ def post(self, orgname, parsed_args):
+ """ Exports the logs for the specified organization. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ start_time = parsed_args["starttime"]
+ end_time = parsed_args["endtime"]
- raise Unauthorized()
+ export_id = _queue_logs_export(
+ start_time, end_time, request.get_json(), orgname
+ )
+ return {"export_id": export_id}
+
+ raise Unauthorized()
diff --git a/endpoints/api/manifest.py b/endpoints/api/manifest.py
index 1370fa743..c3e0aa79a 100644
--- a/endpoints/api/manifest.py
+++ b/endpoints/api/manifest.py
@@ -8,266 +8,291 @@ from app import label_validator, storage
from data.model import InvalidLabelKeyException, InvalidMediaTypeException
from data.registry_model import registry_model
from digest import digest_tools
-from endpoints.api import (resource, nickname, require_repo_read, require_repo_write,
- RepositoryParamResource, log_action, validate_json_request,
- path_param, parse_args, query_param, abort, api,
- disallow_for_app_repositories, format_date,
- disallow_for_non_normal_repositories)
+from endpoints.api import (
+ resource,
+ nickname,
+ require_repo_read,
+ require_repo_write,
+ RepositoryParamResource,
+ log_action,
+ validate_json_request,
+ path_param,
+ parse_args,
+ query_param,
+ abort,
+ api,
+ disallow_for_app_repositories,
+ format_date,
+ disallow_for_non_normal_repositories,
+)
from endpoints.api.image import image_dict
from endpoints.exception import NotFound
from util.validation import VALID_LABEL_KEY_REGEX
-BASE_MANIFEST_ROUTE = '/v1/repository/<apirepopath:repository>/manifest/<regex("{0}"):manifestref>'
+BASE_MANIFEST_ROUTE = (
+ '/v1/repository/<apirepopath:repository>/manifest/<regex("{0}"):manifestref>'
+)
MANIFEST_DIGEST_ROUTE = BASE_MANIFEST_ROUTE.format(digest_tools.DIGEST_PATTERN)
-ALLOWED_LABEL_MEDIA_TYPES = ['text/plain', 'application/json']
+ALLOWED_LABEL_MEDIA_TYPES = ["text/plain", "application/json"]
logger = logging.getLogger(__name__)
+
def _label_dict(label):
- return {
- 'id': label.uuid,
- 'key': label.key,
- 'value': label.value,
- 'source_type': label.source_type_name,
- 'media_type': label.media_type_name,
- }
+ return {
+ "id": label.uuid,
+ "key": label.key,
+ "value": label.value,
+ "source_type": label.source_type_name,
+ "media_type": label.media_type_name,
+ }
def _layer_dict(manifest_layer, index):
- # NOTE: The `command` in the layer is either a JSON string of an array (schema 1) or
- # a single string (schema 2). The block below normalizes it to have the same format.
- command = None
- if manifest_layer.command:
- try:
- command = json.loads(manifest_layer.command)
- except (TypeError, ValueError):
- command = [manifest_layer.command]
+ # NOTE: The `command` in the layer is either a JSON string of an array (schema 1) or
+ # a single string (schema 2). The block below normalizes it to have the same format.
+ command = None
+ if manifest_layer.command:
+ try:
+ command = json.loads(manifest_layer.command)
+ except (TypeError, ValueError):
+ command = [manifest_layer.command]
- return {
- 'index': index,
- 'compressed_size': manifest_layer.compressed_size,
- 'is_remote': manifest_layer.is_remote,
- 'urls': manifest_layer.urls,
- 'command': command,
- 'comment': manifest_layer.comment,
- 'author': manifest_layer.author,
- 'blob_digest': str(manifest_layer.blob_digest),
- 'created_datetime': format_date(manifest_layer.created_datetime),
- }
+ return {
+ "index": index,
+ "compressed_size": manifest_layer.compressed_size,
+ "is_remote": manifest_layer.is_remote,
+ "urls": manifest_layer.urls,
+ "command": command,
+ "comment": manifest_layer.comment,
+ "author": manifest_layer.author,
+ "blob_digest": str(manifest_layer.blob_digest),
+ "created_datetime": format_date(manifest_layer.created_datetime),
+ }
def _manifest_dict(manifest):
- image = None
- if manifest.legacy_image_if_present is not None:
- image = image_dict(manifest.legacy_image, with_history=True)
+ image = None
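+ # Include legacy image details only when the manifest has an associated legacy image.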
+ if manifest.legacy_image_if_present is not None:
+ image = image_dict(manifest.legacy_image, with_history=True)
- layers = None
- if not manifest.is_manifest_list:
- layers = registry_model.list_manifest_layers(manifest, storage)
- if layers is None:
- logger.debug('Missing layers for manifest `%s`', manifest.digest)
- abort(404)
+ layers = None
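+ # Manifest lists have no layers of their own; only load layers for single manifests.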
+ if not manifest.is_manifest_list:
+ layers = registry_model.list_manifest_layers(manifest, storage)
+ if layers is None:
+ logger.debug("Missing layers for manifest `%s`", manifest.digest)
+ abort(404)
- return {
- 'digest': manifest.digest,
- 'is_manifest_list': manifest.is_manifest_list,
- 'manifest_data': manifest.internal_manifest_bytes.as_unicode(),
- 'image': image,
- 'layers': ([_layer_dict(lyr.layer_info, idx) for idx, lyr in enumerate(layers)]
- if layers else None),
- }
+ return {
+ "digest": manifest.digest,
+ "is_manifest_list": manifest.is_manifest_list,
+ "manifest_data": manifest.internal_manifest_bytes.as_unicode(),
+ "image": image,
+ "layers": (
+ [_layer_dict(lyr.layer_info, idx) for idx, lyr in enumerate(layers)]
+ if layers
+ else None
+ ),
+ }
@resource(MANIFEST_DIGEST_ROUTE)
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
-@path_param('manifestref', 'The digest of the manifest')
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
+@path_param("manifestref", "The digest of the manifest")
class RepositoryManifest(RepositoryParamResource):
- """ Resource for retrieving a specific repository manifest. """
- @require_repo_read
- @nickname('getRepoManifest')
- @disallow_for_app_repositories
- def get(self, namespace_name, repository_name, manifestref):
- repo_ref = registry_model.lookup_repository(namespace_name, repository_name)
- if repo_ref is None:
- raise NotFound()
+ """ Resource for retrieving a specific repository manifest. """
- manifest = registry_model.lookup_manifest_by_digest(repo_ref, manifestref,
- include_legacy_image=True)
- if manifest is None:
- raise NotFound()
+ @require_repo_read
+ @nickname("getRepoManifest")
+ @disallow_for_app_repositories
+ def get(self, namespace_name, repository_name, manifestref):
+ repo_ref = registry_model.lookup_repository(namespace_name, repository_name)
+ if repo_ref is None:
+ raise NotFound()
- return _manifest_dict(manifest)
+ manifest = registry_model.lookup_manifest_by_digest(
+ repo_ref, manifestref, include_legacy_image=True
+ )
+ if manifest is None:
+ raise NotFound()
+
+ return _manifest_dict(manifest)
-@resource(MANIFEST_DIGEST_ROUTE + '/labels')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
-@path_param('manifestref', 'The digest of the manifest')
+@resource(MANIFEST_DIGEST_ROUTE + "/labels")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
+@path_param("manifestref", "The digest of the manifest")
class RepositoryManifestLabels(RepositoryParamResource):
- """ Resource for listing the labels on a specific repository manifest. """
- schemas = {
- 'AddLabel': {
- 'type': 'object',
- 'description': 'Adds a label to a manifest',
- 'required': [
- 'key',
- 'value',
- 'media_type',
- ],
- 'properties': {
- 'key': {
- 'type': 'string',
- 'description': 'The key for the label',
- },
- 'value': {
- 'type': 'string',
- 'description': 'The value for the label',
- },
- 'media_type': {
- 'type': ['string', 'null'],
- 'description': 'The media type for this label',
- 'enum': ALLOWED_LABEL_MEDIA_TYPES + [None],
- },
- },
- },
- }
+ """ Resource for listing the labels on a specific repository manifest. """
- @require_repo_read
- @nickname('listManifestLabels')
- @disallow_for_app_repositories
- @parse_args()
- @query_param('filter', 'If specified, only labels matching the given prefix will be returned',
- type=str, default=None)
- def get(self, namespace_name, repository_name, manifestref, parsed_args):
- repo_ref = registry_model.lookup_repository(namespace_name, repository_name)
- if repo_ref is None:
- raise NotFound()
-
- manifest = registry_model.lookup_manifest_by_digest(repo_ref, manifestref)
- if manifest is None:
- raise NotFound()
-
- labels = registry_model.list_manifest_labels(manifest, parsed_args['filter'])
- if labels is None:
- raise NotFound()
-
- return {
- 'labels': [_label_dict(label) for label in labels]
+ schemas = {
+ "AddLabel": {
+ "type": "object",
+ "description": "Adds a label to a manifest",
+ "required": ["key", "value", "media_type"],
+ "properties": {
+ "key": {"type": "string", "description": "The key for the label"},
+ "value": {"type": "string", "description": "The value for the label"},
+ "media_type": {
+ "type": ["string", "null"],
+ "description": "The media type for this label",
+ "enum": ALLOWED_LABEL_MEDIA_TYPES + [None],
+ },
+ },
+ }
}
- @require_repo_write
- @nickname('addManifestLabel')
- @disallow_for_app_repositories
- @disallow_for_non_normal_repositories
- @validate_json_request('AddLabel')
- def post(self, namespace_name, repository_name, manifestref):
- """ Adds a new label into the tag manifest. """
- label_data = request.get_json()
+ @require_repo_read
+ @nickname("listManifestLabels")
+ @disallow_for_app_repositories
+ @parse_args()
+ @query_param(
+ "filter",
+ "If specified, only labels matching the given prefix will be returned",
+ type=str,
+ default=None,
+ )
+ def get(self, namespace_name, repository_name, manifestref, parsed_args):
+ repo_ref = registry_model.lookup_repository(namespace_name, repository_name)
+ if repo_ref is None:
+ raise NotFound()
- # Check for any reserved prefixes.
- if label_validator.has_reserved_prefix(label_data['key']):
- abort(400, message='Label has a reserved prefix')
+ manifest = registry_model.lookup_manifest_by_digest(repo_ref, manifestref)
+ if manifest is None:
+ raise NotFound()
- repo_ref = registry_model.lookup_repository(namespace_name, repository_name)
- if repo_ref is None:
- raise NotFound()
+ labels = registry_model.list_manifest_labels(manifest, parsed_args["filter"])
+ if labels is None:
+ raise NotFound()
- manifest = registry_model.lookup_manifest_by_digest(repo_ref, manifestref)
- if manifest is None:
- raise NotFound()
+ return {"labels": [_label_dict(label) for label in labels]}
- label = None
- try:
- label = registry_model.create_manifest_label(manifest,
- label_data['key'],
- label_data['value'],
- 'api',
- label_data['media_type'])
- except InvalidLabelKeyException:
- message = ('Label is of an invalid format or missing please ' +
- 'use %s format for labels' % VALID_LABEL_KEY_REGEX)
- abort(400, message=message)
- except InvalidMediaTypeException:
- message = 'Media type is invalid please use a valid media type: text/plain, application/json'
- abort(400, message=message)
+ @require_repo_write
+ @nickname("addManifestLabel")
+ @disallow_for_app_repositories
+ @disallow_for_non_normal_repositories
+ @validate_json_request("AddLabel")
+ def post(self, namespace_name, repository_name, manifestref):
+ """ Adds a new label into the tag manifest. """
+ label_data = request.get_json()
- if label is None:
- raise NotFound()
+ # Check for any reserved prefixes.
+ if label_validator.has_reserved_prefix(label_data["key"]):
+ abort(400, message="Label has a reserved prefix")
- metadata = {
- 'id': label.uuid,
- 'key': label.key,
- 'value': label.value,
- 'manifest_digest': manifestref,
- 'media_type': label.media_type_name,
- 'namespace': namespace_name,
- 'repo': repository_name,
- }
+ repo_ref = registry_model.lookup_repository(namespace_name, repository_name)
+ if repo_ref is None:
+ raise NotFound()
- log_action('manifest_label_add', namespace_name, metadata, repo_name=repository_name)
+ manifest = registry_model.lookup_manifest_by_digest(repo_ref, manifestref)
+ if manifest is None:
+ raise NotFound()
- resp = {'label': _label_dict(label)}
- repo_string = '%s/%s' % (namespace_name, repository_name)
- headers = {
- 'Location': api.url_for(ManageRepositoryManifestLabel, repository=repo_string,
- manifestref=manifestref, labelid=label.uuid),
- }
- return resp, 201, headers
+ label = None
+ try:
+ label = registry_model.create_manifest_label(
+ manifest,
+ label_data["key"],
+ label_data["value"],
+ "api",
+ label_data["media_type"],
+ )
+ except InvalidLabelKeyException:
+ message = (
+ "Label is of an invalid format or missing please "
+ + "use %s format for labels" % VALID_LABEL_KEY_REGEX
+ )
+ abort(400, message=message)
+ except InvalidMediaTypeException:
+ message = "Media type is invalid please use a valid media type: text/plain, application/json"
+ abort(400, message=message)
+
+ if label is None:
+ raise NotFound()
+
+ metadata = {
+ "id": label.uuid,
+ "key": label.key,
+ "value": label.value,
+ "manifest_digest": manifestref,
+ "media_type": label.media_type_name,
+ "namespace": namespace_name,
+ "repo": repository_name,
+ }
+
+ log_action(
+ "manifest_label_add", namespace_name, metadata, repo_name=repository_name
+ )
+
+ resp = {"label": _label_dict(label)}
+ repo_string = "%s/%s" % (namespace_name, repository_name)
+ headers = {
+ "Location": api.url_for(
+ ManageRepositoryManifestLabel,
+ repository=repo_string,
+ manifestref=manifestref,
+ labelid=label.uuid,
+ )
+ }
+ return resp, 201, headers
-@resource(MANIFEST_DIGEST_ROUTE + '/labels/<labelid>')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
-@path_param('manifestref', 'The digest of the manifest')
-@path_param('labelid', 'The ID of the label')
+@resource(MANIFEST_DIGEST_ROUTE + "/labels/")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
+@path_param("manifestref", "The digest of the manifest")
+@path_param("labelid", "The ID of the label")
class ManageRepositoryManifestLabel(RepositoryParamResource):
- """ Resource for managing the labels on a specific repository manifest. """
- @require_repo_read
- @nickname('getManifestLabel')
- @disallow_for_app_repositories
- def get(self, namespace_name, repository_name, manifestref, labelid):
- """ Retrieves the label with the specific ID under the manifest. """
- repo_ref = registry_model.lookup_repository(namespace_name, repository_name)
- if repo_ref is None:
- raise NotFound()
+ """ Resource for managing the labels on a specific repository manifest. """
- manifest = registry_model.lookup_manifest_by_digest(repo_ref, manifestref)
- if manifest is None:
- raise NotFound()
+ @require_repo_read
+ @nickname("getManifestLabel")
+ @disallow_for_app_repositories
+ def get(self, namespace_name, repository_name, manifestref, labelid):
+ """ Retrieves the label with the specific ID under the manifest. """
+ repo_ref = registry_model.lookup_repository(namespace_name, repository_name)
+ if repo_ref is None:
+ raise NotFound()
- label = registry_model.get_manifest_label(manifest, labelid)
- if label is None:
- raise NotFound()
+ manifest = registry_model.lookup_manifest_by_digest(repo_ref, manifestref)
+ if manifest is None:
+ raise NotFound()
- return _label_dict(label)
+ label = registry_model.get_manifest_label(manifest, labelid)
+ if label is None:
+ raise NotFound()
+ return _label_dict(label)
- @require_repo_write
- @nickname('deleteManifestLabel')
- @disallow_for_app_repositories
- @disallow_for_non_normal_repositories
- def delete(self, namespace_name, repository_name, manifestref, labelid):
- """ Deletes an existing label from a manifest. """
- repo_ref = registry_model.lookup_repository(namespace_name, repository_name)
- if repo_ref is None:
- raise NotFound()
+ @require_repo_write
+ @nickname("deleteManifestLabel")
+ @disallow_for_app_repositories
+ @disallow_for_non_normal_repositories
+ def delete(self, namespace_name, repository_name, manifestref, labelid):
+ """ Deletes an existing label from a manifest. """
+ repo_ref = registry_model.lookup_repository(namespace_name, repository_name)
+ if repo_ref is None:
+ raise NotFound()
- manifest = registry_model.lookup_manifest_by_digest(repo_ref, manifestref)
- if manifest is None:
- raise NotFound()
+ manifest = registry_model.lookup_manifest_by_digest(repo_ref, manifestref)
+ if manifest is None:
+ raise NotFound()
- deleted = registry_model.delete_manifest_label(manifest, labelid)
- if deleted is None:
- raise NotFound()
+ deleted = registry_model.delete_manifest_label(manifest, labelid)
+ if deleted is None:
+ raise NotFound()
- metadata = {
- 'id': labelid,
- 'key': deleted.key,
- 'value': deleted.value,
- 'manifest_digest': manifestref,
- 'namespace': namespace_name,
- 'repo': repository_name,
- }
+ metadata = {
+ "id": labelid,
+ "key": deleted.key,
+ "value": deleted.value,
+ "manifest_digest": manifestref,
+ "namespace": namespace_name,
+ "repo": repository_name,
+ }
- log_action('manifest_label_delete', namespace_name, metadata, repo_name=repository_name)
- return '', 204
+ log_action(
+ "manifest_label_delete", namespace_name, metadata, repo_name=repository_name
+ )
+ return "", 204
diff --git a/endpoints/api/mirror.py b/endpoints/api/mirror.py
index cac7f9caa..cd782dced 100644
--- a/endpoints/api/mirror.py
+++ b/endpoints/api/mirror.py
@@ -11,387 +11,511 @@ import features
from auth.auth_context import get_authenticated_user
from data import model
-from endpoints.api import (RepositoryParamResource, nickname, path_param, require_repo_admin,
- resource, validate_json_request, define_json_response, show_if,
- format_date)
+from endpoints.api import (
+ RepositoryParamResource,
+ nickname,
+ path_param,
+ require_repo_admin,
+ resource,
+ validate_json_request,
+ define_json_response,
+ show_if,
+ format_date,
+)
from endpoints.exception import NotFound
from util.audit import track_and_log, wrap_repository
from util.names import parse_robot_username
common_properties = {
- 'is_enabled': {
- 'type': 'boolean',
- 'description': 'Used to enable or disable synchronizations.',
- },
- 'external_reference': {
- 'type': 'string',
- 'description': 'Location of the external repository.'
- },
- 'external_registry_username': {
- 'type': ['string', 'null'],
- 'description': 'Username used to authenticate with external registry.',
- },
- 'external_registry_password': {
- 'type': ['string', 'null'],
- 'description': 'Password used to authenticate with external registry.',
- },
- 'sync_start_date': {
- 'type': 'string',
- 'description': 'Determines the next time this repository is ready for synchronization.',
- },
- 'sync_interval': {
- 'type': 'integer',
- 'minimum': 0,
- 'description': 'Number of seconds after next_start_date to begin synchronizing.'
- },
- 'robot_username': {
- 'type': 'string',
- 'description': 'Username of robot which will be used for image pushes.'
- },
- 'root_rule': {
- 'type': 'object',
- 'description': 'Tag mirror rule',
- 'required': [
- 'rule_type',
- 'rule_value'
- ],
- 'properties': {
- 'rule_type': {
- 'type': 'string',
- 'description': 'Rule type must be "TAG_GLOB_CSV"'
- },
- 'rule_value': {
- 'type': 'array',
- 'description': 'Array of tag patterns',
- 'items': {
- 'type': 'string'
- }
- }
+ "is_enabled": {
+ "type": "boolean",
+ "description": "Used to enable or disable synchronizations.",
+ },
+ "external_reference": {
+ "type": "string",
+ "description": "Location of the external repository.",
+ },
+ "external_registry_username": {
+ "type": ["string", "null"],
+ "description": "Username used to authenticate with external registry.",
+ },
+ "external_registry_password": {
+ "type": ["string", "null"],
+ "description": "Password used to authenticate with external registry.",
+ },
+ "sync_start_date": {
+ "type": "string",
+ "description": "Determines the next time this repository is ready for synchronization.",
+ },
+ "sync_interval": {
+ "type": "integer",
+ "minimum": 0,
+ "description": "Number of seconds after next_start_date to begin synchronizing.",
+ },
+ "robot_username": {
+ "type": "string",
+ "description": "Username of robot which will be used for image pushes.",
+ },
+ "root_rule": {
+ "type": "object",
+ "description": "Tag mirror rule",
+ "required": ["rule_type", "rule_value"],
+ "properties": {
+ "rule_type": {
+ "type": "string",
+ "description": 'Rule type must be "TAG_GLOB_CSV"',
+ },
+ "rule_value": {
+ "type": "array",
+ "description": "Array of tag patterns",
+ "items": {"type": "string"},
+ },
+ },
+ "description": "A list of glob-patterns used to determine which tags should be synchronized.",
+ },
+ "external_registry_config": {
+ "type": "object",
+ "properties": {
+ "verify_tls": {
+ "type": "boolean",
+ "description": (
+ "Determines whether HTTPs is required and the certificate is verified when "
+ "communicating with the external repository."
+ ),
+ },
+ "proxy": {
+ "type": "object",
+ "description": "Proxy configuration for use during synchronization.",
+ "properties": {
+ "https_proxy": {
+ "type": ["string", "null"],
+ "description": "Value for HTTPS_PROXY environment variable during sync.",
+ },
+ "http_proxy": {
+ "type": ["string", "null"],
+ "description": "Value for HTTP_PROXY environment variable during sync.",
+ },
+ "no_proxy": {
+ "type": ["string", "null"],
+ "description": "Value for NO_PROXY environment variable during sync.",
+ },
+ },
+ },
+ },
},
- 'description': 'A list of glob-patterns used to determine which tags should be synchronized.'
- },
- 'external_registry_config': {
- 'type': 'object',
- 'properties': {
- 'verify_tls': {
- 'type': 'boolean',
- 'description': (
- 'Determines whether HTTPs is required and the certificate is verified when '
- 'communicating with the external repository.'
- ),
- },
- 'proxy': {
- 'type': 'object',
- 'description': 'Proxy configuration for use during synchronization.',
- 'properties': {
- 'https_proxy': {
- 'type': ['string', 'null'],
- 'description': 'Value for HTTPS_PROXY environment variable during sync.'
- },
- 'http_proxy': {
- 'type': ['string', 'null'],
- 'description': 'Value for HTTP_PROXY environment variable during sync.'
- },
- 'no_proxy': {
- 'type': ['string', 'null'],
- 'description': 'Value for NO_PROXY environment variable during sync.'
- }
- }
- }
- }
- }
}
-@resource('/v1/repository/<apirepopath:repository>/mirror/sync-now')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
+@resource("/v1/repository//mirror/sync-now")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
@show_if(features.REPO_MIRROR)
class RepoMirrorSyncNowResource(RepositoryParamResource):
- """ A resource for managing RepoMirrorConfig.sync_status """
+ """ A resource for managing RepoMirrorConfig.sync_status """
- @require_repo_admin
- @nickname('syncNow')
- def post(self, namespace_name, repository_name):
- """ Update the sync_status for a given Repository's mirroring configuration. """
- repo = model.repository.get_repository(namespace_name, repository_name)
- if not repo:
- raise NotFound()
+ @require_repo_admin
+ @nickname("syncNow")
+ def post(self, namespace_name, repository_name):
+ """ Update the sync_status for a given Repository's mirroring configuration. """
+ repo = model.repository.get_repository(namespace_name, repository_name)
+ if not repo:
+ raise NotFound()
- mirror = model.repo_mirror.get_mirror(repository=repo)
- if not mirror:
- raise NotFound()
+ mirror = model.repo_mirror.get_mirror(repository=repo)
+ if not mirror:
+ raise NotFound()
- if mirror and model.repo_mirror.update_sync_status_to_sync_now(mirror):
- track_and_log('repo_mirror_config_changed', wrap_repository(repo), changed="sync_status", to="SYNC_NOW")
- return '', 204
+ if mirror and model.repo_mirror.update_sync_status_to_sync_now(mirror):
+ track_and_log(
+ "repo_mirror_config_changed",
+ wrap_repository(repo),
+ changed="sync_status",
+ to="SYNC_NOW",
+ )
+ return "", 204
- raise NotFound()
+ raise NotFound()
-@resource('/v1/repository/<apirepopath:repository>/mirror/sync-cancel')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
+@resource("/v1/repository//mirror/sync-cancel")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
@show_if(features.REPO_MIRROR)
class RepoMirrorSyncCancelResource(RepositoryParamResource):
- """ A resource for managing RepoMirrorConfig.sync_status """
+ """ A resource for managing RepoMirrorConfig.sync_status """
- @require_repo_admin
- @nickname('syncCancel')
- def post(self, namespace_name, repository_name):
- """ Update the sync_status for a given Repository's mirroring configuration. """
- repo = model.repository.get_repository(namespace_name, repository_name)
- if not repo:
- raise NotFound()
+ @require_repo_admin
+ @nickname("syncCancel")
+ def post(self, namespace_name, repository_name):
+ """ Update the sync_status for a given Repository's mirroring configuration. """
+ repo = model.repository.get_repository(namespace_name, repository_name)
+ if not repo:
+ raise NotFound()
- mirror = model.repo_mirror.get_mirror(repository=repo)
- if not mirror:
- raise NotFound()
+ mirror = model.repo_mirror.get_mirror(repository=repo)
+ if not mirror:
+ raise NotFound()
- if mirror and model.repo_mirror.update_sync_status_to_cancel(mirror):
- track_and_log('repo_mirror_config_changed', wrap_repository(repo), changed="sync_status", to="SYNC_CANCEL")
- return '', 204
+ if mirror and model.repo_mirror.update_sync_status_to_cancel(mirror):
+ track_and_log(
+ "repo_mirror_config_changed",
+ wrap_repository(repo),
+ changed="sync_status",
+ to="SYNC_CANCEL",
+ )
+ return "", 204
- raise NotFound()
+ raise NotFound()
-@resource('/v1/repository/<apirepopath:repository>/mirror')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
+@resource("/v1/repository//mirror")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
@show_if(features.REPO_MIRROR)
class RepoMirrorResource(RepositoryParamResource):
- """
+ """
Resource for managing repository mirroring.
"""
- schemas = {
- 'CreateMirrorConfig': {
- 'description': 'Create the repository mirroring configuration.',
- 'type': 'object',
- 'required': [
- 'external_reference',
- 'sync_interval',
- 'sync_start_date',
- 'root_rule'
- ],
- 'properties': common_properties
- },
- 'UpdateMirrorConfig': {
- 'description': 'Update the repository mirroring configuration.',
- 'type': 'object',
- 'properties': common_properties
- },
- 'ViewMirrorConfig': {
- 'description': 'View the repository mirroring configuration.',
- 'type': 'object',
- 'required': [
- 'is_enabled',
- 'mirror_type',
- 'external_reference',
- 'external_registry_username',
- 'external_registry_config',
- 'sync_interval',
- 'sync_start_date',
- 'sync_expiration_date',
- 'sync_retries_remaining',
- 'sync_status',
- 'root_rule',
- 'robot_username',
- ],
- 'properties': common_properties
- }
- }
- @require_repo_admin
- @define_json_response('ViewMirrorConfig')
- @nickname('getRepoMirrorConfig')
- def get(self, namespace_name, repository_name):
- """ Return the Mirror configuration for a given Repository. """
- repo = model.repository.get_repository(namespace_name, repository_name)
- if not repo:
- raise NotFound()
-
- mirror = model.repo_mirror.get_mirror(repo)
- if not mirror:
- raise NotFound()
-
- # Transformations
- rules = mirror.root_rule.rule_value
- username = self._decrypt_username(mirror.external_registry_username)
- sync_start_date = self._dt_to_string(mirror.sync_start_date)
- sync_expiration_date = self._dt_to_string(mirror.sync_expiration_date)
- robot = mirror.internal_robot.username if mirror.internal_robot is not None else None
-
- return {
- 'is_enabled': mirror.is_enabled,
- 'mirror_type': mirror.mirror_type.name,
- 'external_reference': mirror.external_reference,
- 'external_registry_username': username,
- 'external_registry_config': mirror.external_registry_config or {},
- 'sync_interval': mirror.sync_interval,
- 'sync_start_date': sync_start_date,
- 'sync_expiration_date': sync_expiration_date,
- 'sync_retries_remaining': mirror.sync_retries_remaining,
- 'sync_status': mirror.sync_status.name,
- 'root_rule': {
- 'rule_type': 'TAG_GLOB_CSV',
- 'rule_value': rules
- },
- 'robot_username': robot,
+ schemas = {
+ "CreateMirrorConfig": {
+ "description": "Create the repository mirroring configuration.",
+ "type": "object",
+ "required": [
+ "external_reference",
+ "sync_interval",
+ "sync_start_date",
+ "root_rule",
+ ],
+ "properties": common_properties,
+ },
+ "UpdateMirrorConfig": {
+ "description": "Update the repository mirroring configuration.",
+ "type": "object",
+ "properties": common_properties,
+ },
+ "ViewMirrorConfig": {
+ "description": "View the repository mirroring configuration.",
+ "type": "object",
+ "required": [
+ "is_enabled",
+ "mirror_type",
+ "external_reference",
+ "external_registry_username",
+ "external_registry_config",
+ "sync_interval",
+ "sync_start_date",
+ "sync_expiration_date",
+ "sync_retries_remaining",
+ "sync_status",
+ "root_rule",
+ "robot_username",
+ ],
+ "properties": common_properties,
+ },
}
- @require_repo_admin
- @nickname('createRepoMirrorConfig')
- @validate_json_request('CreateMirrorConfig')
- def post(self, namespace_name, repository_name):
- """ Create a RepoMirrorConfig for a given Repository. """
- # TODO: Tidy up this function
- # TODO: Specify only the data we want to pass on when creating the RepoMirrorConfig. Avoid
- # the possibility of data injection.
+ @require_repo_admin
+ @define_json_response("ViewMirrorConfig")
+ @nickname("getRepoMirrorConfig")
+ def get(self, namespace_name, repository_name):
+ """ Return the Mirror configuration for a given Repository. """
+ repo = model.repository.get_repository(namespace_name, repository_name)
+ if not repo:
+ raise NotFound()
- repo = model.repository.get_repository(namespace_name, repository_name)
- if not repo:
- raise NotFound()
+ mirror = model.repo_mirror.get_mirror(repo)
+ if not mirror:
+ raise NotFound()
- if model.repo_mirror.get_mirror(repo):
- return {'detail': 'Mirror configuration already exits for repository %s/%s' % (
- namespace_name, repository_name)}, 409
+ # Transformations
+ rules = mirror.root_rule.rule_value
+ username = self._decrypt_username(mirror.external_registry_username)
+ sync_start_date = self._dt_to_string(mirror.sync_start_date)
+ sync_expiration_date = self._dt_to_string(mirror.sync_expiration_date)
+ robot = (
+ mirror.internal_robot.username
+ if mirror.internal_robot is not None
+ else None
+ )
- data = request.get_json()
+ return {
+ "is_enabled": mirror.is_enabled,
+ "mirror_type": mirror.mirror_type.name,
+ "external_reference": mirror.external_reference,
+ "external_registry_username": username,
+ "external_registry_config": mirror.external_registry_config or {},
+ "sync_interval": mirror.sync_interval,
+ "sync_start_date": sync_start_date,
+ "sync_expiration_date": sync_expiration_date,
+ "sync_retries_remaining": mirror.sync_retries_remaining,
+ "sync_status": mirror.sync_status.name,
+ "root_rule": {"rule_type": "TAG_GLOB_CSV", "rule_value": rules},
+ "robot_username": robot,
+ }
- data['sync_start_date'] = self._string_to_dt(data['sync_start_date'])
+ @require_repo_admin
+ @nickname("createRepoMirrorConfig")
+ @validate_json_request("CreateMirrorConfig")
+ def post(self, namespace_name, repository_name):
+ """ Create a RepoMirrorConfig for a given Repository. """
+ # TODO: Tidy up this function
+ # TODO: Specify only the data we want to pass on when creating the RepoMirrorConfig. Avoid
+ # the possibility of data injection.
- rule = model.repo_mirror.create_rule(repo, data['root_rule']['rule_value'])
- del data['root_rule']
+ repo = model.repository.get_repository(namespace_name, repository_name)
+ if not repo:
+ raise NotFound()
- # Verify the robot is part of the Repository's namespace
- robot = self._setup_robot_for_mirroring(namespace_name, repository_name, data['robot_username'])
- del data['robot_username']
+ if model.repo_mirror.get_mirror(repo):
+ return (
+ {
+ "detail": "Mirror configuration already exits for repository %s/%s"
+ % (namespace_name, repository_name)
+ },
+ 409,
+ )
- mirror = model.repo_mirror.enable_mirroring_for_repository(repo, root_rule=rule,
- internal_robot=robot, **data)
- if mirror:
- track_and_log('repo_mirror_config_changed', wrap_repository(repo), changed='external_reference', to=data['external_reference'])
- return '', 201
- else:
- # TODO: Determine appropriate Response
- return {'detail': 'RepoMirrorConfig already exists for this repository.'}, 409
+ data = request.get_json()
- @require_repo_admin
- @validate_json_request('UpdateMirrorConfig')
- @nickname('changeRepoMirrorConfig')
- def put(self, namespace_name, repository_name):
- """ Allow users to modifying the repository's mirroring configuration. """
- values = request.get_json()
+ data["sync_start_date"] = self._string_to_dt(data["sync_start_date"])
- repo = model.repository.get_repository(namespace_name, repository_name)
- if not repo:
- raise NotFound()
+ rule = model.repo_mirror.create_rule(repo, data["root_rule"]["rule_value"])
+ del data["root_rule"]
- mirror = model.repo_mirror.get_mirror(repo)
- if not mirror:
- raise NotFound()
+ # Verify the robot is part of the Repository's namespace
+ robot = self._setup_robot_for_mirroring(
+ namespace_name, repository_name, data["robot_username"]
+ )
+ del data["robot_username"]
- if 'is_enabled' in values:
- if values['is_enabled'] == True:
- if model.repo_mirror.enable_mirror(repo):
- track_and_log('repo_mirror_config_changed', wrap_repository(repo), changed='is_enabled', to=True)
- if values['is_enabled'] == False:
- if model.repo_mirror.disable_mirror(repo):
- track_and_log('repo_mirror_config_changed', wrap_repository(repo), changed='is_enabled', to=False)
-
- if 'external_reference' in values:
- if values['external_reference'] == '':
- return {'detail': 'Empty string is an invalid repository location.'}, 400
- if model.repo_mirror.change_remote(repo, values['external_reference']):
- track_and_log('repo_mirror_config_changed', wrap_repository(repo), changed='external_reference', to=values['external_reference'])
-
- if 'robot_username' in values:
- robot_username = values['robot_username']
- robot = self._setup_robot_for_mirroring(namespace_name, repository_name, robot_username)
- if model.repo_mirror.set_mirroring_robot(repo, robot):
- track_and_log('repo_mirror_config_changed', wrap_repository(repo), changed='robot_username', to=robot_username)
-
- if 'sync_start_date' in values:
- try:
- sync_start_date = self._string_to_dt(values['sync_start_date'])
- except ValueError as e:
- return {'detail': 'Incorrect DateTime format for sync_start_date.'}, 400
- if model.repo_mirror.change_sync_start_date(repo, sync_start_date):
- track_and_log('repo_mirror_config_changed', wrap_repository(repo), changed='sync_start_date', to=sync_start_date)
-
- if 'sync_interval' in values:
- if model.repo_mirror.change_sync_interval(repo, values['sync_interval']):
- track_and_log('repo_mirror_config_changed', wrap_repository(repo), changed='sync_interval', to=values['sync_interval'])
-
- if 'external_registry_username' in values and 'external_registry_password' in values:
- username = values['external_registry_username']
- password = values['external_registry_password']
- if username is None and password is not None:
- return {'detail': 'Unable to delete username while setting a password.'}, 400
- if model.repo_mirror.change_credentials(repo, username, password):
- track_and_log('repo_mirror_config_changed', wrap_repository(repo), changed='external_registry_username', to=username)
- if password is None:
- track_and_log('repo_mirror_config_changed', wrap_repository(repo), changed='external_registry_password', to=None)
+ mirror = model.repo_mirror.enable_mirroring_for_repository(
+ repo, root_rule=rule, internal_robot=robot, **data
+ )
+ if mirror:
+ track_and_log(
+ "repo_mirror_config_changed",
+ wrap_repository(repo),
+ changed="external_reference",
+ to=data["external_reference"],
+ )
+ return "", 201
else:
- track_and_log('repo_mirror_config_changed', wrap_repository(repo), changed='external_registry_password', to="********")
+ # TODO: Determine appropriate Response
+ return (
+ {"detail": "RepoMirrorConfig already exists for this repository."},
+ 409,
+ )
- elif 'external_registry_username' in values:
- username = values['external_registry_username']
- if model.repo_mirror.change_username(repo, username):
- track_and_log('repo_mirror_config_changed', wrap_repository(repo), changed='external_registry_username', to=username)
+ @require_repo_admin
+ @validate_json_request("UpdateMirrorConfig")
+ @nickname("changeRepoMirrorConfig")
+ def put(self, namespace_name, repository_name):
+ """ Allow users to modifying the repository's mirroring configuration. """
+ values = request.get_json()
- # Do not allow specifying a password without setting a username
- if 'external_registry_password' in values and 'external_registry_username' not in values:
- return {'detail': 'Unable to set a new password without also specifying a username.'}, 400
+ repo = model.repository.get_repository(namespace_name, repository_name)
+ if not repo:
+ raise NotFound()
- if 'external_registry_config' in values:
- external_registry_config = values.get('external_registry_config', {})
+ mirror = model.repo_mirror.get_mirror(repo)
+ if not mirror:
+ raise NotFound()
- if 'verify_tls' in external_registry_config:
- updates = {'verify_tls': external_registry_config['verify_tls']}
- if model.repo_mirror.change_external_registry_config(repo, updates):
- track_and_log('repo_mirror_config_changed', wrap_repository(repo), changed='verify_tls', to=external_registry_config['verify_tls'])
+ if "is_enabled" in values:
+ if values["is_enabled"] == True:
+ if model.repo_mirror.enable_mirror(repo):
+ track_and_log(
+ "repo_mirror_config_changed",
+ wrap_repository(repo),
+ changed="is_enabled",
+ to=True,
+ )
+ if values["is_enabled"] == False:
+ if model.repo_mirror.disable_mirror(repo):
+ track_and_log(
+ "repo_mirror_config_changed",
+ wrap_repository(repo),
+ changed="is_enabled",
+ to=False,
+ )
- if 'proxy' in external_registry_config:
- proxy_values = external_registry_config.get('proxy', {})
+ if "external_reference" in values:
+ if values["external_reference"] == "":
+ return (
+ {"detail": "Empty string is an invalid repository location."},
+ 400,
+ )
+ if model.repo_mirror.change_remote(repo, values["external_reference"]):
+ track_and_log(
+ "repo_mirror_config_changed",
+ wrap_repository(repo),
+ changed="external_reference",
+ to=values["external_reference"],
+ )
- if 'http_proxy' in proxy_values:
- updates = {'proxy': {'http_proxy': proxy_values['http_proxy']}}
- if model.repo_mirror.change_external_registry_config(repo, updates):
- track_and_log('repo_mirror_config_changed', wrap_repository(repo), changed='http_proxy', to=proxy_values['http_proxy'])
+ if "robot_username" in values:
+ robot_username = values["robot_username"]
+ robot = self._setup_robot_for_mirroring(
+ namespace_name, repository_name, robot_username
+ )
+ if model.repo_mirror.set_mirroring_robot(repo, robot):
+ track_and_log(
+ "repo_mirror_config_changed",
+ wrap_repository(repo),
+ changed="robot_username",
+ to=robot_username,
+ )
- if 'https_proxy' in proxy_values:
- updates = {'proxy': {'https_proxy': proxy_values['https_proxy']}}
- if model.repo_mirror.change_external_registry_config(repo, updates):
- track_and_log('repo_mirror_config_changed', wrap_repository(repo), changed='https_proxy', to=proxy_values['https_proxy'])
+ if "sync_start_date" in values:
+ try:
+ sync_start_date = self._string_to_dt(values["sync_start_date"])
+ except ValueError:
+ return {"detail": "Incorrect DateTime format for sync_start_date."}, 400
+ if model.repo_mirror.change_sync_start_date(repo, sync_start_date):
+ track_and_log(
+ "repo_mirror_config_changed",
+ wrap_repository(repo),
+ changed="sync_start_date",
+ to=sync_start_date,
+ )
- if 'no_proxy' in proxy_values:
- updates = {'proxy': {'no_proxy': proxy_values['no_proxy']}}
- if model.repo_mirror.change_external_registry_config(repo, updates):
- track_and_log('repo_mirror_config_changed', wrap_repository(repo), changed='no_proxy', to=proxy_values['no_proxy'])
+ if "sync_interval" in values:
+ if model.repo_mirror.change_sync_interval(repo, values["sync_interval"]):
+ track_and_log(
+ "repo_mirror_config_changed",
+ wrap_repository(repo),
+ changed="sync_interval",
+ to=values["sync_interval"],
+ )
- return '', 201
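+ # Username and password must be updated together: the username cannot be
+ # cleared while a password is being set, and the password value is masked
+ # in the audit log below.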
+ if (
+ "external_registry_username" in values
+ and "external_registry_password" in values
+ ):
+ username = values["external_registry_username"]
+ password = values["external_registry_password"]
+ if username is None and password is not None:
+ return (
+ {"detail": "Unable to delete username while setting a password."},
+ 400,
+ )
+ if model.repo_mirror.change_credentials(repo, username, password):
+ track_and_log(
+ "repo_mirror_config_changed",
+ wrap_repository(repo),
+ changed="external_registry_username",
+ to=username,
+ )
+ if password is None:
+ track_and_log(
+ "repo_mirror_config_changed",
+ wrap_repository(repo),
+ changed="external_registry_password",
+ to=None,
+ )
+ else:
+ track_and_log(
+ "repo_mirror_config_changed",
+ wrap_repository(repo),
+ changed="external_registry_password",
+ to="********",
+ )
- def _setup_robot_for_mirroring(self, namespace_name, repo_name, robot_username):
- """ Validate robot exists and give write permissions. """
- robot = model.user.lookup_robot(robot_username)
- assert robot.robot
+ elif "external_registry_username" in values:
+ username = values["external_registry_username"]
+ if model.repo_mirror.change_username(repo, username):
+ track_and_log(
+ "repo_mirror_config_changed",
+ wrap_repository(repo),
+ changed="external_registry_username",
+ to=username,
+ )
- namespace, _ = parse_robot_username(robot_username)
- if namespace != namespace_name:
- raise model.DataModelException('Invalid robot')
+ # Do not allow specifying a password without setting a username
+ if (
+ "external_registry_password" in values
+ and "external_registry_username" not in values
+ ):
+ return (
+ {
+ "detail": "Unable to set a new password without also specifying a username."
+ },
+ 400,
+ )
- # Ensure the robot specified has access to the repository. If not, grant it.
- permissions = model.permission.get_user_repository_permissions(robot, namespace_name, repo_name)
- if not permissions or permissions[0].role.name == 'read':
- model.permission.set_user_repo_permission(robot.username, namespace_name, repo_name, 'write')
+ if "external_registry_config" in values:
+ external_registry_config = values.get("external_registry_config", {})
- return robot
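+ # Only the verify_tls flag and the proxy settings (http_proxy, https_proxy,
+ # no_proxy) are applied here; other keys in external_registry_config are ignored.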
+ if "verify_tls" in external_registry_config:
+ updates = {"verify_tls": external_registry_config["verify_tls"]}
+ if model.repo_mirror.change_external_registry_config(repo, updates):
+ track_and_log(
+ "repo_mirror_config_changed",
+ wrap_repository(repo),
+ changed="verify_tls",
+ to=external_registry_config["verify_tls"],
+ )
- def _string_to_dt(self, string):
- """ Convert String to correct DateTime format. """
- if string is None:
- return None
+ if "proxy" in external_registry_config:
+ proxy_values = external_registry_config.get("proxy", {})
- """
+ if "http_proxy" in proxy_values:
+ updates = {"proxy": {"http_proxy": proxy_values["http_proxy"]}}
+ if model.repo_mirror.change_external_registry_config(repo, updates):
+ track_and_log(
+ "repo_mirror_config_changed",
+ wrap_repository(repo),
+ changed="http_proxy",
+ to=proxy_values["http_proxy"],
+ )
+
+ if "https_proxy" in proxy_values:
+ updates = {"proxy": {"https_proxy": proxy_values["https_proxy"]}}
+ if model.repo_mirror.change_external_registry_config(repo, updates):
+ track_and_log(
+ "repo_mirror_config_changed",
+ wrap_repository(repo),
+ changed="https_proxy",
+ to=proxy_values["https_proxy"],
+ )
+
+ if "no_proxy" in proxy_values:
+ updates = {"proxy": {"no_proxy": proxy_values["no_proxy"]}}
+ if model.repo_mirror.change_external_registry_config(repo, updates):
+ track_and_log(
+ "repo_mirror_config_changed",
+ wrap_repository(repo),
+ changed="no_proxy",
+ to=proxy_values["no_proxy"],
+ )
+
+ return "", 201
+
+ def _setup_robot_for_mirroring(self, namespace_name, repo_name, robot_username):
+ """ Validate robot exists and give write permissions. """
+ robot = model.user.lookup_robot(robot_username)
+ assert robot.robot
+
+ namespace, _ = parse_robot_username(robot_username)
+ if namespace != namespace_name:
+ raise model.DataModelException("Invalid robot")
+
+ # Ensure the robot specified has access to the repository. If not, grant it.
+ permissions = model.permission.get_user_repository_permissions(
+ robot, namespace_name, repo_name
+ )
+ if not permissions or permissions[0].role.name == "read":
+ model.permission.set_user_repo_permission(
+ robot.username, namespace_name, repo_name, "write"
+ )
+
+ return robot
+
+ def _string_to_dt(self, string):
+ """ Convert String to correct DateTime format. """
+ if string is None:
+ return None
+
+ """
# TODO: Use RFC2822. This doesn't work consistently.
# TODO: Move this to same module as `format_date` once fixed.
tup = parsedate_tz(string)
@@ -401,67 +525,71 @@ class RepoMirrorResource(RepositoryParamResource):
dt = datetime.fromtimestamp(ts, pytz.UTC)
return dt
"""
- assert isinstance(string, (str, unicode))
- dt = datetime.strptime(string, "%Y-%m-%dT%H:%M:%SZ")
- return dt
+ assert isinstance(string, (str, unicode))
+ dt = datetime.strptime(string, "%Y-%m-%dT%H:%M:%SZ")
+ return dt
- def _dt_to_string(self, dt):
- """ Convert DateTime to correctly formatted String."""
- if dt is None:
- return None
+ def _dt_to_string(self, dt):
+ """ Convert DateTime to correctly formatted String."""
+ if dt is None:
+ return None
- """
+ """
# TODO: Use RFC2822. Need to make it work bi-directionally.
return format_date(dt)
"""
- assert isinstance(dt, datetime)
- string = dt.isoformat() + 'Z'
- return string
+ assert isinstance(dt, datetime)
+ string = dt.isoformat() + "Z"
+ return string
- def _decrypt_username(self, username):
- if username is None:
- return None
- return username.decrypt()
+ def _decrypt_username(self, username):
+ if username is None:
+ return None
+ return username.decrypt()
-@resource('/v1/repository/<apirepopath:repository>/mirror/rules')
+@resource("/v1/repository/<apirepopath:repository>/mirror/rules")
@show_if(features.REPO_MIRROR)
class ManageRepoMirrorRule(RepositoryParamResource):
- """
+ """
Operations to manage a single Repository Mirroring Rule.
TODO: At the moment, we are only dealing with a single rule associated with the mirror.
This should change to update the rule and address it using its UUID.
"""
- schemas = {
- 'MirrorRule': {
- 'type': 'object',
- 'description': 'A rule used to define how a repository is mirrored.',
- 'required': ['root_rule'],
- 'properties': {
- 'root_rule': common_properties['root_rule']
- }
- }
- }
- @require_repo_admin
- @nickname('changeRepoMirrorRule')
- @validate_json_request('MirrorRule')
- def put(self, namespace_name, repository_name):
- """
+ schemas = {
+ "MirrorRule": {
+ "type": "object",
+ "description": "A rule used to define how a repository is mirrored.",
+ "required": ["root_rule"],
+ "properties": {"root_rule": common_properties["root_rule"]},
+ }
+ }
+
+ @require_repo_admin
+ @nickname("changeRepoMirrorRule")
+ @validate_json_request("MirrorRule")
+ def put(self, namespace_name, repository_name):
+ """
Update an existing RepoMirrorRule
"""
- repo = model.repository.get_repository(namespace_name, repository_name)
- if not repo:
- raise NotFound()
+ repo = model.repository.get_repository(namespace_name, repository_name)
+ if not repo:
+ raise NotFound()
- rule = model.repo_mirror.get_root_rule(repo)
- if not rule:
- return {'detail': 'The rule appears to be missing.'}, 400
+ rule = model.repo_mirror.get_root_rule(repo)
+ if not rule:
+ return {"detail": "The rule appears to be missing."}, 400
- data = request.get_json()
- if model.repo_mirror.change_rule_value(rule, data['root_rule']['rule_value']):
- track_and_log('repo_mirror_config_changed', wrap_repository(repo), changed="mirror_rule", to=data['root_rule']['rule_value'])
- return 200
- else:
- return {'detail': 'Unable to update rule.'}, 400
+ data = request.get_json()
+ if model.repo_mirror.change_rule_value(rule, data["root_rule"]["rule_value"]):
+ track_and_log(
+ "repo_mirror_config_changed",
+ wrap_repository(repo),
+ changed="mirror_rule",
+ to=data["root_rule"]["rule_value"],
+ )
+ return 200
+ else:
+ return {"detail": "Unable to update rule."}, 400
diff --git a/endpoints/api/organization.py b/endpoints/api/organization.py
index e53bba6b9..4df3b87db 100644
--- a/endpoints/api/organization.py
+++ b/endpoints/api/organization.py
@@ -8,15 +8,38 @@ from flask import request
import features
from active_migration import ActiveDataMigration, ERTMigrationFlags
-from app import (billing as stripe, avatar, all_queues, authentication, namespace_gc_queue,
- ip_resolver, app)
-from endpoints.api import (resource, nickname, ApiResource, validate_json_request, request_error,
- related_user_resource, internal_only, require_user_admin, log_action,
- show_if, path_param, require_scope, require_fresh_login)
+from app import (
+ billing as stripe,
+ avatar,
+ all_queues,
+ authentication,
+ namespace_gc_queue,
+ ip_resolver,
+ app,
+)
+from endpoints.api import (
+ resource,
+ nickname,
+ ApiResource,
+ validate_json_request,
+ request_error,
+ related_user_resource,
+ internal_only,
+ require_user_admin,
+ log_action,
+ show_if,
+ path_param,
+ require_scope,
+ require_fresh_login,
+)
from endpoints.exception import Unauthorized, NotFound
from endpoints.api.user import User, PrivateRepositories
-from auth.permissions import (AdministerOrganizationPermission, OrganizationMemberPermission,
- CreateRepositoryPermission, ViewTeamPermission)
+from auth.permissions import (
+ AdministerOrganizationPermission,
+ OrganizationMemberPermission,
+ CreateRepositoryPermission,
+ ViewTeamPermission,
+)
from auth.auth_context import get_authenticated_user
from auth import scopes
from data import model
@@ -29,712 +52,742 @@ logger = logging.getLogger(__name__)
def team_view(orgname, team):
- return {
- 'name': team.name,
- 'description': team.description,
- 'role': team.role_name,
- 'avatar': avatar.get_data_for_team(team),
- 'can_view': ViewTeamPermission(orgname, team.name).can(),
-
- 'repo_count': team.repo_count,
- 'member_count': team.member_count,
-
- 'is_synced': team.is_synced,
- }
+ return {
+ "name": team.name,
+ "description": team.description,
+ "role": team.role_name,
+ "avatar": avatar.get_data_for_team(team),
+ "can_view": ViewTeamPermission(orgname, team.name).can(),
+ "repo_count": team.repo_count,
+ "member_count": team.member_count,
+ "is_synced": team.is_synced,
+ }
def org_view(o, teams):
- is_admin = AdministerOrganizationPermission(o.username).can()
- is_member = OrganizationMemberPermission(o.username).can()
+ is_admin = AdministerOrganizationPermission(o.username).can()
+ is_member = OrganizationMemberPermission(o.username).can()
- view = {
- 'name': o.username,
- 'email': o.email if is_admin else '',
- 'avatar': avatar.get_data_for_user(o),
- 'is_admin': is_admin,
- 'is_member': is_member
- }
+ view = {
+ "name": o.username,
+ "email": o.email if is_admin else "",
+ "avatar": avatar.get_data_for_user(o),
+ "is_admin": is_admin,
+ "is_member": is_member,
+ }
- if teams is not None:
- teams = sorted(teams, key=lambda team: team.id)
- view['teams'] = {t.name : team_view(o.username, t) for t in teams}
- view['ordered_teams'] = [team.name for team in teams]
+ if teams is not None:
+ teams = sorted(teams, key=lambda team: team.id)
+ view["teams"] = {t.name: team_view(o.username, t) for t in teams}
+ view["ordered_teams"] = [team.name for team in teams]
- if is_admin:
- view['invoice_email'] = o.invoice_email
- view['invoice_email_address'] = o.invoice_email_address
- view['tag_expiration_s'] = o.removed_tag_expiration_s
- view['is_free_account'] = o.stripe_id is None
+ if is_admin:
+ view["invoice_email"] = o.invoice_email
+ view["invoice_email_address"] = o.invoice_email_address
+ view["tag_expiration_s"] = o.removed_tag_expiration_s
+ view["is_free_account"] = o.stripe_id is None
- return view
+ return view
-@resource('/v1/organization/')
+@resource("/v1/organization/")
class OrganizationList(ApiResource):
- """ Resource for creating organizations. """
- schemas = {
- 'NewOrg': {
- 'type': 'object',
- 'description': 'Description of a new organization.',
- 'required': [
- 'name',
- ],
- 'properties': {
- 'name': {
- 'type': 'string',
- 'description': 'Organization username',
- },
- 'email': {
- 'type': 'string',
- 'description': 'Organization contact email',
- },
- 'recaptcha_response': {
- 'type': 'string',
- 'description': 'The (may be disabled) recaptcha response code for verification',
- },
- },
- },
- }
+ """ Resource for creating organizations. """
- @require_user_admin
- @nickname('createOrganization')
- @validate_json_request('NewOrg')
- def post(self):
- """ Create a new organization. """
- user = get_authenticated_user()
- org_data = request.get_json()
- existing = None
+ schemas = {
+ "NewOrg": {
+ "type": "object",
+ "description": "Description of a new organization.",
+ "required": ["name"],
+ "properties": {
+ "name": {"type": "string", "description": "Organization username"},
+ "email": {
+ "type": "string",
+ "description": "Organization contact email",
+ },
+ "recaptcha_response": {
+ "type": "string",
+ "description": "The (may be disabled) recaptcha response code for verification",
+ },
+ },
+ }
+ }
- try:
- existing = model.organization.get_organization(org_data['name'])
- except model.InvalidOrganizationException:
- pass
+ @require_user_admin
+ @nickname("createOrganization")
+ @validate_json_request("NewOrg")
+ def post(self):
+ """ Create a new organization. """
+ user = get_authenticated_user()
+ org_data = request.get_json()
+ existing = None
- if not existing:
- existing = model.user.get_user(org_data['name'])
+ try:
+ existing = model.organization.get_organization(org_data["name"])
+ except model.InvalidOrganizationException:
+ pass
- if existing:
- msg = 'A user or organization with this name already exists'
- raise request_error(message=msg)
+ if not existing:
+ existing = model.user.get_user(org_data["name"])
- if features.MAILING and not org_data.get('email'):
- raise request_error(message='Email address is required')
+ if existing:
+ msg = "A user or organization with this name already exists"
+ raise request_error(message=msg)
- # If recaptcha is enabled, then verify the user is a human.
- if features.RECAPTCHA:
- recaptcha_response = org_data.get('recaptcha_response', '')
- result = recaptcha2.verify(app.config['RECAPTCHA_SECRET_KEY'],
- recaptcha_response,
- get_request_ip())
+ if features.MAILING and not org_data.get("email"):
+ raise request_error(message="Email address is required")
- if not result['success']:
- return {
- 'message': 'Are you a bot? If not, please revalidate the captcha.'
- }, 400
+ # If recaptcha is enabled, then verify the user is a human.
+ if features.RECAPTCHA:
+ recaptcha_response = org_data.get("recaptcha_response", "")
+ result = recaptcha2.verify(
+ app.config["RECAPTCHA_SECRET_KEY"], recaptcha_response, get_request_ip()
+ )
- is_possible_abuser = ip_resolver.is_ip_possible_threat(get_request_ip())
- try:
- model.organization.create_organization(org_data['name'], org_data.get('email'), user,
- email_required=features.MAILING,
- is_possible_abuser=is_possible_abuser)
- return 'Created', 201
- except model.DataModelException as ex:
- raise request_error(exception=ex)
+ if not result["success"]:
+ return (
+ {
+ "message": "Are you a bot? If not, please revalidate the captcha."
+ },
+ 400,
+ )
+
+ is_possible_abuser = ip_resolver.is_ip_possible_threat(get_request_ip())
+ try:
+ model.organization.create_organization(
+ org_data["name"],
+ org_data.get("email"),
+ user,
+ email_required=features.MAILING,
+ is_possible_abuser=is_possible_abuser,
+ )
+ return "Created", 201
+ except model.DataModelException as ex:
+ raise request_error(exception=ex)
-@resource('/v1/organization/<orgname>')
-@path_param('orgname', 'The name of the organization')
+@resource("/v1/organization/")
+@path_param("orgname", "The name of the organization")
@related_user_resource(User)
class Organization(ApiResource):
- """ Resource for managing organizations. """
- schemas = {
- 'UpdateOrg': {
- 'type': 'object',
- 'description': 'Description of updates for an existing organization',
- 'properties': {
- 'email': {
- 'type': 'string',
- 'description': 'Organization contact email',
- },
- 'invoice_email': {
- 'type': 'boolean',
- 'description': 'Whether the organization desires to receive emails for invoices',
- },
- 'invoice_email_address': {
- 'type': ['string', 'null'],
- 'description': 'The email address at which to receive invoices',
- },
- 'tag_expiration_s': {
- 'type': 'integer',
- 'minimum': 0,
- 'description': 'The number of seconds for tag expiration',
- },
- },
- },
- }
+ """ Resource for managing organizations. """
- @nickname('getOrganization')
- def get(self, orgname):
- """ Get the details for the specified organization """
- try:
- org = model.organization.get_organization(orgname)
- except model.InvalidOrganizationException:
- raise NotFound()
+ schemas = {
+ "UpdateOrg": {
+ "type": "object",
+ "description": "Description of updates for an existing organization",
+ "properties": {
+ "email": {
+ "type": "string",
+ "description": "Organization contact email",
+ },
+ "invoice_email": {
+ "type": "boolean",
+ "description": "Whether the organization desires to receive emails for invoices",
+ },
+ "invoice_email_address": {
+ "type": ["string", "null"],
+ "description": "The email address at which to receive invoices",
+ },
+ "tag_expiration_s": {
+ "type": "integer",
+ "minimum": 0,
+ "description": "The number of seconds for tag expiration",
+ },
+ },
+ }
+ }
- teams = None
- if OrganizationMemberPermission(orgname).can():
- has_syncing = features.TEAM_SYNCING and bool(authentication.federated_service)
- teams = model.team.get_teams_within_org(org, has_syncing)
+ @nickname("getOrganization")
+ def get(self, orgname):
+ """ Get the details for the specified organization """
+ try:
+ org = model.organization.get_organization(orgname)
+ except model.InvalidOrganizationException:
+ raise NotFound()
- return org_view(org, teams)
+ teams = None
+ if OrganizationMemberPermission(orgname).can():
+ has_syncing = features.TEAM_SYNCING and bool(
+ authentication.federated_service
+ )
+ teams = model.team.get_teams_within_org(org, has_syncing)
+
+ return org_view(org, teams)
+
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("changeOrganizationDetails")
+ @validate_json_request("UpdateOrg")
+ def put(self, orgname):
+ """ Change the details for the specified organization. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ try:
+ org = model.organization.get_organization(orgname)
+ except model.InvalidOrganizationException:
+ raise NotFound()
+
+ org_data = request.get_json()
+ if "invoice_email" in org_data:
+ logger.debug(
+ "Changing invoice_email for organization: %s", org.username
+ )
+ model.user.change_send_invoice_email(org, org_data["invoice_email"])
+
+ if (
+ "invoice_email_address" in org_data
+ and org_data["invoice_email_address"] != org.invoice_email_address
+ ):
+ new_email = org_data["invoice_email_address"]
+ logger.debug(
+ "Changing invoice email address for organization: %s", org.username
+ )
+ model.user.change_invoice_email_address(org, new_email)
+
+ if "email" in org_data and org_data["email"] != org.email:
+ new_email = org_data["email"]
+ if model.user.find_user_by_email(new_email):
+ raise request_error(message="E-mail address already used")
+
+ logger.debug(
+ "Changing email address for organization: %s", org.username
+ )
+ model.user.update_email(org, new_email)
+
+ if features.CHANGE_TAG_EXPIRATION and "tag_expiration_s" in org_data:
+ logger.debug(
+ "Changing organization tag expiration to: %ss",
+ org_data["tag_expiration_s"],
+ )
+ model.user.change_user_tag_expiration(org, org_data["tag_expiration_s"])
+
+ teams = model.team.get_teams_within_org(org)
+ return org_view(org, teams)
+ raise Unauthorized()
+
+ @require_scope(scopes.ORG_ADMIN)
+ @require_fresh_login
+ @nickname("deleteAdminedOrganization")
+ def delete(self, orgname):
+ """ Deletes the specified organization. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ try:
+ org = model.organization.get_organization(orgname)
+ except model.InvalidOrganizationException:
+ raise NotFound()
+
+ model.user.mark_namespace_for_deletion(org, all_queues, namespace_gc_queue)
+ return "", 204
+
+ raise Unauthorized()
- @require_scope(scopes.ORG_ADMIN)
- @nickname('changeOrganizationDetails')
- @validate_json_request('UpdateOrg')
- def put(self, orgname):
- """ Change the details for the specified organization. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- try:
- org = model.organization.get_organization(orgname)
- except model.InvalidOrganizationException:
- raise NotFound()
-
- org_data = request.get_json()
- if 'invoice_email' in org_data:
- logger.debug('Changing invoice_email for organization: %s', org.username)
- model.user.change_send_invoice_email(org, org_data['invoice_email'])
-
- if ('invoice_email_address' in org_data and
- org_data['invoice_email_address'] != org.invoice_email_address):
- new_email = org_data['invoice_email_address']
- logger.debug('Changing invoice email address for organization: %s', org.username)
- model.user.change_invoice_email_address(org, new_email)
-
- if 'email' in org_data and org_data['email'] != org.email:
- new_email = org_data['email']
- if model.user.find_user_by_email(new_email):
- raise request_error(message='E-mail address already used')
-
- logger.debug('Changing email address for organization: %s', org.username)
- model.user.update_email(org, new_email)
-
- if features.CHANGE_TAG_EXPIRATION and 'tag_expiration_s' in org_data:
- logger.debug('Changing organization tag expiration to: %ss', org_data['tag_expiration_s'])
- model.user.change_user_tag_expiration(org, org_data['tag_expiration_s'])
-
- teams = model.team.get_teams_within_org(org)
- return org_view(org, teams)
- raise Unauthorized()
-
-
- @require_scope(scopes.ORG_ADMIN)
- @require_fresh_login
- @nickname('deleteAdminedOrganization')
- def delete(self, orgname):
- """ Deletes the specified organization. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- try:
- org = model.organization.get_organization(orgname)
- except model.InvalidOrganizationException:
- raise NotFound()
-
- model.user.mark_namespace_for_deletion(org, all_queues, namespace_gc_queue)
- return '', 204
-
- raise Unauthorized()
-
-
-@resource('/v1/organization/<orgname>/private')
-@path_param('orgname', 'The name of the organization')
+@resource("/v1/organization//private")
+@path_param("orgname", "The name of the organization")
@internal_only
@related_user_resource(PrivateRepositories)
@show_if(features.BILLING)
class OrgPrivateRepositories(ApiResource):
- """ Custom verb to compute whether additional private repositories are available. """
+ """ Custom verb to compute whether additional private repositories are available. """
- @require_scope(scopes.ORG_ADMIN)
- @nickname('getOrganizationPrivateAllowed')
- def get(self, orgname):
- """ Return whether or not this org is allowed to create new private repositories. """
- permission = CreateRepositoryPermission(orgname)
- if permission.can():
- organization = model.organization.get_organization(orgname)
- private_repos = model.user.get_private_repo_count(organization.username)
- data = {
- 'privateAllowed': False
- }
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("getOrganizationPrivateAllowed")
+ def get(self, orgname):
+ """ Return whether or not this org is allowed to create new private repositories. """
+ permission = CreateRepositoryPermission(orgname)
+ if permission.can():
+ organization = model.organization.get_organization(orgname)
+ private_repos = model.user.get_private_repo_count(organization.username)
+ data = {"privateAllowed": False}
- if organization.stripe_id:
- cus = stripe.Customer.retrieve(organization.stripe_id)
- if cus.subscription:
- repos_allowed = 0
- plan = get_plan(cus.subscription.plan.id)
- if plan:
- repos_allowed = plan['privateRepos']
+ if organization.stripe_id:
+ cus = stripe.Customer.retrieve(organization.stripe_id)
+ if cus.subscription:
+ repos_allowed = 0
+ plan = get_plan(cus.subscription.plan.id)
+ if plan:
+ repos_allowed = plan["privateRepos"]
- data['privateAllowed'] = (private_repos < repos_allowed)
+ data["privateAllowed"] = private_repos < repos_allowed
+
+ if AdministerOrganizationPermission(orgname).can():
+ data["privateCount"] = private_repos
+
+ return data
+
+ raise Unauthorized()
- if AdministerOrganizationPermission(orgname).can():
- data['privateCount'] = private_repos
-
- return data
-
- raise Unauthorized()
-
-
-@resource('/v1/organization/<orgname>/collaborators')
-@path_param('orgname', 'The name of the organization')
+@resource("/v1/organization//collaborators")
+@path_param("orgname", "The name of the organization")
class OrganizationCollaboratorList(ApiResource):
- """ Resource for listing outside collaborators of an organization.
+ """ Resource for listing outside collaborators of an organization.
Collaborators are users that do not belong to any team in the
organization, but who have direct permissions on one or more
repositories belonging to the organization.
"""
- @require_scope(scopes.ORG_ADMIN)
- @nickname('getOrganizationCollaborators')
- def get(self, orgname):
- """ List outside collaborators of the specified organization. """
- permission = AdministerOrganizationPermission(orgname)
- if not permission.can():
- raise Unauthorized()
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("getOrganizationCollaborators")
+ def get(self, orgname):
+ """ List outside collaborators of the specified organization. """
+ permission = AdministerOrganizationPermission(orgname)
+ if not permission.can():
+ raise Unauthorized()
- try:
- org = model.organization.get_organization(orgname)
- except model.InvalidOrganizationException:
- raise NotFound()
+ try:
+ org = model.organization.get_organization(orgname)
+ except model.InvalidOrganizationException:
+ raise NotFound()
- all_perms = model.permission.list_organization_member_permissions(org)
- membership = model.team.list_organization_members_by_teams(org)
+ all_perms = model.permission.list_organization_member_permissions(org)
+ membership = model.team.list_organization_members_by_teams(org)
- org_members = set(m.user.username for m in membership)
+ org_members = set(m.user.username for m in membership)
- collaborators = {}
- for perm in all_perms:
- username = perm.user.username
+ collaborators = {}
+ for perm in all_perms:
+ username = perm.user.username
- # Only interested in non-member permissions.
- if username in org_members:
- continue
+ # Only interested in non-member permissions.
+ if username in org_members:
+ continue
- if username not in collaborators:
- collaborators[username] = {
- 'kind': 'user',
- 'name': username,
- 'avatar': avatar.get_data_for_user(perm.user),
- 'repositories': [],
- }
+ if username not in collaborators:
+ collaborators[username] = {
+ "kind": "user",
+ "name": username,
+ "avatar": avatar.get_data_for_user(perm.user),
+ "repositories": [],
+ }
- collaborators[username]['repositories'].append(perm.repository.name)
+ collaborators[username]["repositories"].append(perm.repository.name)
- return {'collaborators': collaborators.values()}
+ return {"collaborators": collaborators.values()}
-@resource('/v1/organization/<orgname>/members')
-@path_param('orgname', 'The name of the organization')
+@resource("/v1/organization//members")
+@path_param("orgname", "The name of the organization")
class OrganizationMemberList(ApiResource):
- """ Resource for listing the members of an organization. """
+ """ Resource for listing the members of an organization. """
- @require_scope(scopes.ORG_ADMIN)
- @nickname('getOrganizationMembers')
- def get(self, orgname):
- """ List the human members of the specified organization. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- try:
- org = model.organization.get_organization(orgname)
- except model.InvalidOrganizationException:
- raise NotFound()
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("getOrganizationMembers")
+ def get(self, orgname):
+ """ List the human members of the specified organization. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ try:
+ org = model.organization.get_organization(orgname)
+ except model.InvalidOrganizationException:
+ raise NotFound()
- # Loop to create the members dictionary. Note that the members collection
- # will return an entry for *every team* a member is on, so we will have
- # duplicate keys (which is why we pre-build the dictionary).
- members_dict = {}
- members = model.team.list_organization_members_by_teams(org)
- for member in members:
- if member.user.robot:
- continue
+ # Loop to create the members dictionary. Note that the members collection
+ # will return an entry for *every team* a member is on, so we will have
+ # duplicate keys (which is why we pre-build the dictionary).
+ members_dict = {}
+ members = model.team.list_organization_members_by_teams(org)
+ for member in members:
+ if member.user.robot:
+ continue
- if not member.user.username in members_dict:
- member_data = {
- 'name': member.user.username,
- 'kind': 'user',
- 'avatar': avatar.get_data_for_user(member.user),
- 'teams': [],
- 'repositories': []
- }
+ if member.user.username not in members_dict:
+ member_data = {
+ "name": member.user.username,
+ "kind": "user",
+ "avatar": avatar.get_data_for_user(member.user),
+ "teams": [],
+ "repositories": [],
+ }
- members_dict[member.user.username] = member_data
+ members_dict[member.user.username] = member_data
- members_dict[member.user.username]['teams'].append({
- 'name': member.team.name,
- 'avatar': avatar.get_data_for_team(member.team),
- })
+ members_dict[member.user.username]["teams"].append(
+ {
+ "name": member.team.name,
+ "avatar": avatar.get_data_for_team(member.team),
+ }
+ )
- # Loop to add direct repository permissions.
- for permission in model.permission.list_organization_member_permissions(org):
- username = permission.user.username
- if not username in members_dict:
- continue
+ # Loop to add direct repository permissions.
+ for permission in model.permission.list_organization_member_permissions(
+ org
+ ):
+ username = permission.user.username
+ if username not in members_dict:
+ continue
- members_dict[username]['repositories'].append(permission.repository.name)
+ members_dict[username]["repositories"].append(
+ permission.repository.name
+ )
- return {'members': members_dict.values()}
+ return {"members": members_dict.values()}
- raise Unauthorized()
+ raise Unauthorized()
-
-@resource('/v1/organization/<orgname>/members/<membername>')
-@path_param('orgname', 'The name of the organization')
-@path_param('membername', 'The username of the organization member')
+@resource("/v1/organization//members/")
+@path_param("orgname", "The name of the organization")
+@path_param("membername", "The username of the organization member")
class OrganizationMember(ApiResource):
- """ Resource for managing individual organization members. """
+ """ Resource for managing individual organization members. """
- @require_scope(scopes.ORG_ADMIN)
- @nickname('getOrganizationMember')
- def get(self, orgname, membername):
- """ Retrieves the details of a member of the organization.
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("getOrganizationMember")
+ def get(self, orgname, membername):
+ """ Retrieves the details of a member of the organization.
"""
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- # Lookup the user.
- member = model.user.get_user(membername)
- if not member:
- raise NotFound()
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ # Lookup the user.
+ member = model.user.get_user(membername)
+ if not member:
+ raise NotFound()
- organization = model.user.get_user_or_org(orgname)
- if not organization:
- raise NotFound()
+ organization = model.user.get_user_or_org(orgname)
+ if not organization:
+ raise NotFound()
- # Lookup the user's information in the organization.
- teams = list(model.team.get_user_teams_within_org(membername, organization))
- if not teams:
- # 404 if the user is not a robot under the organization, as that means the referenced
- # user or robot is not a member of this organization.
- if not member.robot:
- raise NotFound()
+ # Lookup the user's information in the organization.
+ teams = list(model.team.get_user_teams_within_org(membername, organization))
+ if not teams:
+ # 404 if the user is not a robot under the organization, as that means the referenced
+ # user or robot is not a member of this organization.
+ if not member.robot:
+ raise NotFound()
- namespace, _ = parse_robot_username(member.username)
- if namespace != orgname:
- raise NotFound()
+ namespace, _ = parse_robot_username(member.username)
+ if namespace != orgname:
+ raise NotFound()
- repo_permissions = model.permission.list_organization_member_permissions(organization, member)
+ repo_permissions = model.permission.list_organization_member_permissions(
+ organization, member
+ )
- def local_team_view(team):
- return {
- 'name': team.name,
- 'avatar': avatar.get_data_for_team(team),
- }
+ def local_team_view(team):
+ return {"name": team.name, "avatar": avatar.get_data_for_team(team)}
- return {
- 'name': member.username,
- 'kind': 'robot' if member.robot else 'user',
- 'avatar': avatar.get_data_for_user(member),
- 'teams': [local_team_view(team) for team in teams],
- 'repositories': [permission.repository.name for permission in repo_permissions]
- }
+ return {
+ "name": member.username,
+ "kind": "robot" if member.robot else "user",
+ "avatar": avatar.get_data_for_user(member),
+ "teams": [local_team_view(team) for team in teams],
+ "repositories": [
+ permission.repository.name for permission in repo_permissions
+ ],
+ }
- raise Unauthorized()
+ raise Unauthorized()
-
- @require_scope(scopes.ORG_ADMIN)
- @nickname('removeOrganizationMember')
- def delete(self, orgname, membername):
- """ Removes a member from an organization, revoking all its repository
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("removeOrganizationMember")
+ def delete(self, orgname, membername):
+ """ Removes a member from an organization, revoking all its repository
privileges and removing it from all teams in the organization.
"""
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- # Lookup the user.
- user = model.user.get_nonrobot_user(membername)
- if not user:
- raise NotFound()
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ # Lookup the user.
+ user = model.user.get_nonrobot_user(membername)
+ if not user:
+ raise NotFound()
- try:
- org = model.organization.get_organization(orgname)
- except model.InvalidOrganizationException:
- raise NotFound()
+ try:
+ org = model.organization.get_organization(orgname)
+ except model.InvalidOrganizationException:
+ raise NotFound()
- # Remove the user from the organization.
- model.organization.remove_organization_member(org, user)
- return '', 204
+ # Remove the user from the organization.
+ model.organization.remove_organization_member(org, user)
+ return "", 204
- raise Unauthorized()
+ raise Unauthorized()
-@resource('/v1/app/<client_id>')
-@path_param('client_id', 'The OAuth client ID')
+@resource("/v1/app/")
+@path_param("client_id", "The OAuth client ID")
class ApplicationInformation(ApiResource):
- """ Resource that returns public information about a registered application. """
+ """ Resource that returns public information about a registered application. """
- @nickname('getApplicationInformation')
- def get(self, client_id):
- """ Get information on the specified application. """
- application = model.oauth.get_application_for_client_id(client_id)
- if not application:
- raise NotFound()
+ @nickname("getApplicationInformation")
+ def get(self, client_id):
+ """ Get information on the specified application. """
+ application = model.oauth.get_application_for_client_id(client_id)
+ if not application:
+ raise NotFound()
- app_email = application.avatar_email or application.organization.email
- app_data = avatar.get_data(application.name, app_email, 'app')
+ app_email = application.avatar_email or application.organization.email
+ app_data = avatar.get_data(application.name, app_email, "app")
- return {
- 'name': application.name,
- 'description': application.description,
- 'uri': application.application_uri,
- 'avatar': app_data,
- 'organization': org_view(application.organization, [])
- }
+ return {
+ "name": application.name,
+ "description": application.description,
+ "uri": application.application_uri,
+ "avatar": app_data,
+ "organization": org_view(application.organization, []),
+ }
def app_view(application):
- is_admin = AdministerOrganizationPermission(application.organization.username).can()
- client_secret = None
- if is_admin:
- # TODO(remove-unenc): Remove legacy lookup.
+ is_admin = AdministerOrganizationPermission(application.organization.username).can()
client_secret = None
- if application.secure_client_secret is not None:
- client_secret = application.secure_client_secret.decrypt()
+ if is_admin:
+ # TODO(remove-unenc): Remove legacy lookup.
+ client_secret = None
+ if application.secure_client_secret is not None:
+ client_secret = application.secure_client_secret.decrypt()
- if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS) and client_secret is None:
- client_secret = application.client_secret
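+ # Fall back to the legacy unencrypted client_secret field while the active
+ # data migration still allows reading the old fields.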
+ if (
+ ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS)
+ and client_secret is None
+ ):
+ client_secret = application.client_secret
- assert (client_secret is not None) == is_admin
- return {
- 'name': application.name,
- 'description': application.description,
- 'application_uri': application.application_uri,
- 'client_id': application.client_id,
- 'client_secret': client_secret,
- 'redirect_uri': application.redirect_uri if is_admin else None,
- 'avatar_email': application.avatar_email if is_admin else None,
- }
+ assert (client_secret is not None) == is_admin
+ return {
+ "name": application.name,
+ "description": application.description,
+ "application_uri": application.application_uri,
+ "client_id": application.client_id,
+ "client_secret": client_secret,
+ "redirect_uri": application.redirect_uri if is_admin else None,
+ "avatar_email": application.avatar_email if is_admin else None,
+ }
-@resource('/v1/organization/<orgname>/applications')
-@path_param('orgname', 'The name of the organization')
+@resource("/v1/organization//applications")
+@path_param("orgname", "The name of the organization")
class OrganizationApplications(ApiResource):
- """ Resource for managing applications defined by an organization. """
- schemas = {
- 'NewApp': {
- 'type': 'object',
- 'description': 'Description of a new organization application.',
- 'required': [
- 'name',
- ],
- 'properties': {
- 'name': {
- 'type': 'string',
- 'description': 'The name of the application',
- },
- 'redirect_uri': {
- 'type': 'string',
- 'description': 'The URI for the application\'s OAuth redirect',
- },
- 'application_uri': {
- 'type': 'string',
- 'description': 'The URI for the application\'s homepage',
- },
- 'description': {
- 'type': 'string',
- 'description': 'The human-readable description for the application',
- },
- 'avatar_email': {
- 'type': 'string',
- 'description': 'The e-mail address of the avatar to use for the application',
+ """ Resource for managing applications defined by an organization. """
+
+ schemas = {
+ "NewApp": {
+ "type": "object",
+ "description": "Description of a new organization application.",
+ "required": ["name"],
+ "properties": {
+ "name": {
+ "type": "string",
+ "description": "The name of the application",
+ },
+ "redirect_uri": {
+ "type": "string",
+ "description": "The URI for the application's OAuth redirect",
+ },
+ "application_uri": {
+ "type": "string",
+ "description": "The URI for the application's homepage",
+ },
+ "description": {
+ "type": "string",
+ "description": "The human-readable description for the application",
+ },
+ "avatar_email": {
+ "type": "string",
+ "description": "The e-mail address of the avatar to use for the application",
+ },
+ },
}
- },
- },
- }
+ }
- @require_scope(scopes.ORG_ADMIN)
- @nickname('getOrganizationApplications')
- def get(self, orgname):
- """ List the applications for the specified organization """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- try:
- org = model.organization.get_organization(orgname)
- except model.InvalidOrganizationException:
- raise NotFound()
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("getOrganizationApplications")
+ def get(self, orgname):
+ """ List the applications for the specified organization """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ try:
+ org = model.organization.get_organization(orgname)
+ except model.InvalidOrganizationException:
+ raise NotFound()
- applications = model.oauth.list_applications_for_org(org)
- return {'applications': [app_view(application) for application in applications]}
+ applications = model.oauth.list_applications_for_org(org)
+ return {
+ "applications": [app_view(application) for application in applications]
+ }
- raise Unauthorized()
+ raise Unauthorized()
- @require_scope(scopes.ORG_ADMIN)
- @nickname('createOrganizationApplication')
- @validate_json_request('NewApp')
- def post(self, orgname):
- """ Creates a new application under this organization. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- try:
- org = model.organization.get_organization(orgname)
- except model.InvalidOrganizationException:
- raise NotFound()
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("createOrganizationApplication")
+ @validate_json_request("NewApp")
+ def post(self, orgname):
+ """ Creates a new application under this organization. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ try:
+ org = model.organization.get_organization(orgname)
+ except model.InvalidOrganizationException:
+ raise NotFound()
- app_data = request.get_json()
- application = model.oauth.create_application(org, app_data['name'],
- app_data.get('application_uri', ''),
- app_data.get('redirect_uri', ''),
- description=app_data.get('description', ''),
- avatar_email=app_data.get('avatar_email', None))
+ app_data = request.get_json()
+ application = model.oauth.create_application(
+ org,
+ app_data["name"],
+ app_data.get("application_uri", ""),
+ app_data.get("redirect_uri", ""),
+ description=app_data.get("description", ""),
+ avatar_email=app_data.get("avatar_email", None),
+ )
- app_data.update({
- 'application_name': application.name,
- 'client_id': application.client_id
- })
+ app_data.update(
+ {
+ "application_name": application.name,
+ "client_id": application.client_id,
+ }
+ )
- log_action('create_application', orgname, app_data)
+ log_action("create_application", orgname, app_data)
- return app_view(application)
- raise Unauthorized()
+ return app_view(application)
+ raise Unauthorized()
-@resource('/v1/organization/<orgname>/applications/<client_id>')
-@path_param('orgname', 'The name of the organization')
-@path_param('client_id', 'The OAuth client ID')
+@resource("/v1/organization//applications/")
+@path_param("orgname", "The name of the organization")
+@path_param("client_id", "The OAuth client ID")
class OrganizationApplicationResource(ApiResource):
- """ Resource for managing an application defined by an organizations. """
- schemas = {
- 'UpdateApp': {
- 'type': 'object',
- 'description': 'Description of an updated application.',
- 'required': [
- 'name',
- 'redirect_uri',
- 'application_uri'
- ],
- 'properties': {
- 'name': {
- 'type': 'string',
- 'description': 'The name of the application',
- },
- 'redirect_uri': {
- 'type': 'string',
- 'description': 'The URI for the application\'s OAuth redirect',
- },
- 'application_uri': {
- 'type': 'string',
- 'description': 'The URI for the application\'s homepage',
- },
- 'description': {
- 'type': 'string',
- 'description': 'The human-readable description for the application',
- },
- 'avatar_email': {
- 'type': 'string',
- 'description': 'The e-mail address of the avatar to use for the application',
+ """ Resource for managing an application defined by an organizations. """
+
+ schemas = {
+ "UpdateApp": {
+ "type": "object",
+ "description": "Description of an updated application.",
+ "required": ["name", "redirect_uri", "application_uri"],
+ "properties": {
+ "name": {
+ "type": "string",
+ "description": "The name of the application",
+ },
+ "redirect_uri": {
+ "type": "string",
+ "description": "The URI for the application's OAuth redirect",
+ },
+ "application_uri": {
+ "type": "string",
+ "description": "The URI for the application's homepage",
+ },
+ "description": {
+ "type": "string",
+ "description": "The human-readable description for the application",
+ },
+ "avatar_email": {
+ "type": "string",
+ "description": "The e-mail address of the avatar to use for the application",
+ },
+ },
}
- },
- },
- }
+ }
- @require_scope(scopes.ORG_ADMIN)
- @nickname('getOrganizationApplication')
- def get(self, orgname, client_id):
- """ Retrieves the application with the specified client_id under the specified organization """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- try:
- org = model.organization.get_organization(orgname)
- except model.InvalidOrganizationException:
- raise NotFound()
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("getOrganizationApplication")
+ def get(self, orgname, client_id):
+ """ Retrieves the application with the specified client_id under the specified organization """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ try:
+ org = model.organization.get_organization(orgname)
+ except model.InvalidOrganizationException:
+ raise NotFound()
- application = model.oauth.lookup_application(org, client_id)
- if not application:
- raise NotFound()
+ application = model.oauth.lookup_application(org, client_id)
+ if not application:
+ raise NotFound()
- return app_view(application)
+ return app_view(application)
- raise Unauthorized()
+ raise Unauthorized()
- @require_scope(scopes.ORG_ADMIN)
- @nickname('updateOrganizationApplication')
- @validate_json_request('UpdateApp')
- def put(self, orgname, client_id):
- """ Updates an application under this organization. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- try:
- org = model.organization.get_organization(orgname)
- except model.InvalidOrganizationException:
- raise NotFound()
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("updateOrganizationApplication")
+ @validate_json_request("UpdateApp")
+ def put(self, orgname, client_id):
+ """ Updates an application under this organization. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ try:
+ org = model.organization.get_organization(orgname)
+ except model.InvalidOrganizationException:
+ raise NotFound()
- application = model.oauth.lookup_application(org, client_id)
- if not application:
- raise NotFound()
+ application = model.oauth.lookup_application(org, client_id)
+ if not application:
+ raise NotFound()
- app_data = request.get_json()
- application.name = app_data['name']
- application.application_uri = app_data['application_uri']
- application.redirect_uri = app_data['redirect_uri']
- application.description = app_data.get('description', '')
- application.avatar_email = app_data.get('avatar_email', None)
- application.save()
+ app_data = request.get_json()
+ application.name = app_data["name"]
+ application.application_uri = app_data["application_uri"]
+ application.redirect_uri = app_data["redirect_uri"]
+ application.description = app_data.get("description", "")
+ application.avatar_email = app_data.get("avatar_email", None)
+ application.save()
- app_data.update({
- 'application_name': application.name,
- 'client_id': application.client_id
- })
+ app_data.update(
+ {
+ "application_name": application.name,
+ "client_id": application.client_id,
+ }
+ )
- log_action('update_application', orgname, app_data)
+ log_action("update_application", orgname, app_data)
- return app_view(application)
- raise Unauthorized()
+ return app_view(application)
+ raise Unauthorized()
- @require_scope(scopes.ORG_ADMIN)
- @nickname('deleteOrganizationApplication')
- def delete(self, orgname, client_id):
- """ Deletes the application under this organization. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- try:
- org = model.organization.get_organization(orgname)
- except model.InvalidOrganizationException:
- raise NotFound()
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("deleteOrganizationApplication")
+ def delete(self, orgname, client_id):
+ """ Deletes the application under this organization. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ try:
+ org = model.organization.get_organization(orgname)
+ except model.InvalidOrganizationException:
+ raise NotFound()
- application = model.oauth.delete_application(org, client_id)
- if not application:
- raise NotFound()
+ application = model.oauth.delete_application(org, client_id)
+ if not application:
+ raise NotFound()
- log_action('delete_application', orgname,
- {'application_name': application.name, 'client_id': client_id})
+ log_action(
+ "delete_application",
+ orgname,
+ {"application_name": application.name, "client_id": client_id},
+ )
- return '', 204
- raise Unauthorized()
+ return "", 204
+ raise Unauthorized()
-@resource('/v1/organization/<orgname>/applications/<client_id>/resetclientsecret')
-@path_param('orgname', 'The name of the organization')
-@path_param('client_id', 'The OAuth client ID')
+@resource("/v1/organization//applications//resetclientsecret")
+@path_param("orgname", "The name of the organization")
+@path_param("client_id", "The OAuth client ID")
@internal_only
class OrganizationApplicationResetClientSecret(ApiResource):
- """ Custom verb for resetting the client secret of an application. """
- @nickname('resetOrganizationApplicationClientSecret')
- def post(self, orgname, client_id):
- """ Resets the client secret of the application. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- try:
- org = model.organization.get_organization(orgname)
- except model.InvalidOrganizationException:
- raise NotFound()
+ """ Custom verb for resetting the client secret of an application. """
- application = model.oauth.lookup_application(org, client_id)
- if not application:
- raise NotFound()
+ @nickname("resetOrganizationApplicationClientSecret")
+ def post(self, orgname, client_id):
+ """ Resets the client secret of the application. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ try:
+ org = model.organization.get_organization(orgname)
+ except model.InvalidOrganizationException:
+ raise NotFound()
- application = model.oauth.reset_client_secret(application)
- log_action('reset_application_client_secret', orgname,
- {'application_name': application.name, 'client_id': client_id})
+ application = model.oauth.lookup_application(org, client_id)
+ if not application:
+ raise NotFound()
- return app_view(application)
- raise Unauthorized()
+ application = model.oauth.reset_client_secret(application)
+ log_action(
+ "reset_application_client_secret",
+ orgname,
+ {"application_name": application.name, "client_id": client_id},
+ )
+
+ return app_view(application)
+ raise Unauthorized()
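
For orientation only (not part of this patch): a minimal client-side sketch of the organization-application endpoints reformatted above. The host, bearer token, organization name and client ID are hypothetical, and the /api/v1 prefix assumes the usual API blueprint mounting.

import requests

QUAY_URL = "https://quay.example.com/api/v1"   # hypothetical instance
HEADERS = {"Authorization": "Bearer <token>"}  # assumed org:admin-scoped token

# Update an application's metadata (the PUT handler above).
resp = requests.put(
    QUAY_URL + "/organization/myorg/applications/deadbeef",
    headers=HEADERS,
    json={
        "name": "ci-bot",
        "application_uri": "https://ci.example.com",
        "redirect_uri": "https://ci.example.com/oauth/callback",
        "description": "CI integration",
    },
)
resp.raise_for_status()

# Rotate the client secret (the POST handler above); returns the app view.
reset = requests.post(
    QUAY_URL + "/organization/myorg/applications/deadbeef/resetclientsecret",
    headers=HEADERS,
)
print(reset.json())
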
diff --git a/endpoints/api/permission.py b/endpoints/api/permission.py
index e85c6480e..365638d36 100644
--- a/endpoints/api/permission.py
+++ b/endpoints/api/permission.py
@@ -4,8 +4,16 @@ import logging
from flask import request
-from endpoints.api import (resource, nickname, require_repo_admin, RepositoryParamResource,
- log_action, request_error, validate_json_request, path_param)
+from endpoints.api import (
+ resource,
+ nickname,
+ require_repo_admin,
+ RepositoryParamResource,
+ log_action,
+ request_error,
+ validate_json_request,
+ path_param,
+)
from endpoints.exception import NotFound
from permission_models_pre_oci import pre_oci_model as model
from permission_models_interface import DeleteException, SaveException
@@ -13,197 +21,232 @@ from permission_models_interface import DeleteException, SaveException
logger = logging.getLogger(__name__)
-@resource('/v1/repository/<apirepopath:repository>/permissions/team/')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
+@resource("/v1/repository//permissions/team/")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
class RepositoryTeamPermissionList(RepositoryParamResource):
- """ Resource for repository team permissions. """
- @require_repo_admin
- @nickname('listRepoTeamPermissions')
- def get(self, namespace_name, repository_name):
- """ List all team permission. """
- repo_perms = model.get_repo_permissions_by_team(namespace_name, repository_name)
+ """ Resource for repository team permissions. """
- return {
- 'permissions': {repo_perm.team_name: repo_perm.to_dict()
- for repo_perm in repo_perms}
- }
+ @require_repo_admin
+ @nickname("listRepoTeamPermissions")
+ def get(self, namespace_name, repository_name):
+ """ List all team permission. """
+ repo_perms = model.get_repo_permissions_by_team(namespace_name, repository_name)
+
+ return {
+ "permissions": {
+ repo_perm.team_name: repo_perm.to_dict() for repo_perm in repo_perms
+ }
+ }
-@resource('/v1/repository/<apirepopath:repository>/permissions/user/')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
+@resource("/v1/repository//permissions/user/")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
class RepositoryUserPermissionList(RepositoryParamResource):
- """ Resource for repository user permissions. """
- @require_repo_admin
- @nickname('listRepoUserPermissions')
- def get(self, namespace_name, repository_name):
- """ List all user permissions. """
- perms = model.get_repo_permissions_by_user(namespace_name, repository_name)
- return {'permissions': {p.username: p.to_dict() for p in perms}}
+ """ Resource for repository user permissions. """
+
+ @require_repo_admin
+ @nickname("listRepoUserPermissions")
+ def get(self, namespace_name, repository_name):
+ """ List all user permissions. """
+ perms = model.get_repo_permissions_by_user(namespace_name, repository_name)
+ return {"permissions": {p.username: p.to_dict() for p in perms}}
-@resource('/v1/repository/<apirepopath:repository>/permissions/user/<username>/transitive')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
-@path_param('username', 'The username of the user to which the permissions apply')
+@resource(
+ "/v1/repository//permissions/user//transitive"
+)
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
+@path_param("username", "The username of the user to which the permissions apply")
class RepositoryUserTransitivePermission(RepositoryParamResource):
- """ Resource for retrieving whether a user has access to a repository, either directly
+ """ Resource for retrieving whether a user has access to a repository, either directly
or via a team. """
- @require_repo_admin
- @nickname('getUserTransitivePermission')
- def get(self, namespace_name, repository_name, username):
- """ Get the fetch the permission for the specified user. """
-
- roles = model.get_repo_roles(username, namespace_name, repository_name)
-
- if not roles:
- raise NotFound
-
- return {
- 'permissions': [r.to_dict() for r in roles]
+
+ @require_repo_admin
+ @nickname("getUserTransitivePermission")
+ def get(self, namespace_name, repository_name, username):
+ """ Get the fetch the permission for the specified user. """
+
+ roles = model.get_repo_roles(username, namespace_name, repository_name)
+
+ if not roles:
+ raise NotFound
+
+ return {"permissions": [r.to_dict() for r in roles]}
+
+
+@resource("/v1/repository//permissions/user/")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
+@path_param("username", "The username of the user to which the permission applies")
+class RepositoryUserPermission(RepositoryParamResource):
+ """ Resource for managing individual user permissions. """
+
+ schemas = {
+ "UserPermission": {
+ "type": "object",
+ "description": "Description of a user permission.",
+ "required": ["role"],
+ "properties": {
+ "role": {
+ "type": "string",
+ "description": "Role to use for the user",
+ "enum": ["read", "write", "admin"],
+ }
+ },
+ }
}
+ @require_repo_admin
+ @nickname("getUserPermissions")
+ def get(self, namespace_name, repository_name, username):
+ """ Get the permission for the specified user. """
+ logger.debug(
+ "Get repo: %s/%s permissions for user %s",
+ namespace_name,
+ repository_name,
+ username,
+ )
+ perm = model.get_repo_permission_for_user(
+ username, namespace_name, repository_name
+ )
+ return perm.to_dict()
-@resource('/v1/repository/<apirepopath:repository>/permissions/user/<username>')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
-@path_param('username', 'The username of the user to which the permission applies')
-class RepositoryUserPermission(RepositoryParamResource):
- """ Resource for managing individual user permissions. """
- schemas = {
- 'UserPermission': {
- 'type': 'object',
- 'description': 'Description of a user permission.',
- 'required': [
- 'role',
- ],
- 'properties': {
- 'role': {
- 'type': 'string',
- 'description': 'Role to use for the user',
- 'enum': [
- 'read',
- 'write',
- 'admin',
- ],
- },
- },
- },
- }
+ @require_repo_admin
+ @nickname("changeUserPermissions")
+ @validate_json_request("UserPermission")
+ def put(
+ self, namespace_name, repository_name, username
+ ): # Also needs to respond to post
+ """ Update the perimssions for an existing repository. """
+ new_permission = request.get_json()
- @require_repo_admin
- @nickname('getUserPermissions')
- def get(self, namespace_name, repository_name, username):
- """ Get the permission for the specified user. """
- logger.debug('Get repo: %s/%s permissions for user %s', namespace_name, repository_name, username)
- perm = model.get_repo_permission_for_user(username, namespace_name, repository_name)
- return perm.to_dict()
+ logger.debug(
+ "Setting permission to: %s for user %s", new_permission["role"], username
+ )
- @require_repo_admin
- @nickname('changeUserPermissions')
- @validate_json_request('UserPermission')
- def put(self, namespace_name, repository_name, username): # Also needs to respond to post
- """ Update the perimssions for an existing repository. """
- new_permission = request.get_json()
+ try:
+ perm = model.set_repo_permission_for_user(
+ username, namespace_name, repository_name, new_permission["role"]
+ )
+ resp = perm.to_dict()
+ except SaveException as ex:
+ raise request_error(exception=ex)
- logger.debug('Setting permission to: %s for user %s', new_permission['role'], username)
+ log_action(
+ "change_repo_permission",
+ namespace_name,
+ {
+ "username": username,
+ "repo": repository_name,
+ "namespace": namespace_name,
+ "role": new_permission["role"],
+ },
+ repo_name=repository_name,
+ )
- try:
- perm = model.set_repo_permission_for_user(username, namespace_name, repository_name,
- new_permission['role'])
- resp = perm.to_dict()
- except SaveException as ex:
- raise request_error(exception=ex)
+ return resp, 200
- log_action('change_repo_permission', namespace_name,
- {'username': username, 'repo': repository_name,
- 'namespace': namespace_name,
- 'role': new_permission['role']},
- repo_name=repository_name)
+ @require_repo_admin
+ @nickname("deleteUserPermissions")
+ def delete(self, namespace_name, repository_name, username):
+ """ Delete the permission for the user. """
+ try:
+ model.delete_repo_permission_for_user(
+ username, namespace_name, repository_name
+ )
+ except DeleteException as ex:
+ raise request_error(exception=ex)
- return resp, 200
+ log_action(
+ "delete_repo_permission",
+ namespace_name,
+ {
+ "username": username,
+ "repo": repository_name,
+ "namespace": namespace_name,
+ },
+ repo_name=repository_name,
+ )
- @require_repo_admin
- @nickname('deleteUserPermissions')
- def delete(self, namespace_name, repository_name, username):
- """ Delete the permission for the user. """
- try:
- model.delete_repo_permission_for_user(username, namespace_name, repository_name)
- except DeleteException as ex:
- raise request_error(exception=ex)
-
- log_action('delete_repo_permission', namespace_name,
- {'username': username, 'repo': repository_name, 'namespace': namespace_name},
- repo_name=repository_name)
-
- return '', 204
+ return "", 204
-@resource('/v1/repository/<apirepopath:repository>/permissions/team/<teamname>')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
-@path_param('teamname', 'The name of the team to which the permission applies')
+@resource("/v1/repository//permissions/team/")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
+@path_param("teamname", "The name of the team to which the permission applies")
class RepositoryTeamPermission(RepositoryParamResource):
- """ Resource for managing individual team permissions. """
- schemas = {
- 'TeamPermission': {
- 'type': 'object',
- 'description': 'Description of a team permission.',
- 'required': [
- 'role',
- ],
- 'properties': {
- 'role': {
- 'type': 'string',
- 'description': 'Role to use for the team',
- 'enum': [
- 'read',
- 'write',
- 'admin',
- ],
- },
- },
- },
- }
+ """ Resource for managing individual team permissions. """
- @require_repo_admin
- @nickname('getTeamPermissions')
- def get(self, namespace_name, repository_name, teamname):
- """ Fetch the permission for the specified team. """
- logger.debug('Get repo: %s/%s permissions for team %s', namespace_name, repository_name, teamname)
- role = model.get_repo_role_for_team(teamname, namespace_name, repository_name)
- return role.to_dict()
+ schemas = {
+ "TeamPermission": {
+ "type": "object",
+ "description": "Description of a team permission.",
+ "required": ["role"],
+ "properties": {
+ "role": {
+ "type": "string",
+ "description": "Role to use for the team",
+ "enum": ["read", "write", "admin"],
+ }
+ },
+ }
+ }
- @require_repo_admin
- @nickname('changeTeamPermissions')
- @validate_json_request('TeamPermission')
- def put(self, namespace_name, repository_name, teamname):
- """ Update the existing team permission. """
- new_permission = request.get_json()
+ @require_repo_admin
+ @nickname("getTeamPermissions")
+ def get(self, namespace_name, repository_name, teamname):
+ """ Fetch the permission for the specified team. """
+ logger.debug(
+ "Get repo: %s/%s permissions for team %s",
+ namespace_name,
+ repository_name,
+ teamname,
+ )
+ role = model.get_repo_role_for_team(teamname, namespace_name, repository_name)
+ return role.to_dict()
- logger.debug('Setting permission to: %s for team %s', new_permission['role'], teamname)
+ @require_repo_admin
+ @nickname("changeTeamPermissions")
+ @validate_json_request("TeamPermission")
+ def put(self, namespace_name, repository_name, teamname):
+ """ Update the existing team permission. """
+ new_permission = request.get_json()
- try:
- perm = model.set_repo_permission_for_team(teamname, namespace_name, repository_name,
- new_permission['role'])
- resp = perm.to_dict()
- except SaveException as ex:
- raise request_error(exception=ex)
-
+ logger.debug(
+ "Setting permission to: %s for team %s", new_permission["role"], teamname
+ )
- log_action('change_repo_permission', namespace_name,
- {'team': teamname, 'repo': repository_name,
- 'role': new_permission['role']},
- repo_name=repository_name)
- return resp, 200
+ try:
+ perm = model.set_repo_permission_for_team(
+ teamname, namespace_name, repository_name, new_permission["role"]
+ )
+ resp = perm.to_dict()
+ except SaveException as ex:
+ raise request_error(exception=ex)
- @require_repo_admin
- @nickname('deleteTeamPermissions')
- def delete(self, namespace_name, repository_name, teamname):
- """ Delete the permission for the specified team. """
- try:
- model.delete_repo_permission_for_team(teamname, namespace_name, repository_name)
- except DeleteException as ex:
- raise request_error(exception=ex)
-
- log_action('delete_repo_permission', namespace_name,
- {'team': teamname, 'repo': repository_name},
- repo_name=repository_name)
+ log_action(
+ "change_repo_permission",
+ namespace_name,
+ {"team": teamname, "repo": repository_name, "role": new_permission["role"]},
+ repo_name=repository_name,
+ )
+ return resp, 200
- return '', 204
+ @require_repo_admin
+ @nickname("deleteTeamPermissions")
+ def delete(self, namespace_name, repository_name, teamname):
+ """ Delete the permission for the specified team. """
+ try:
+ model.delete_repo_permission_for_team(
+ teamname, namespace_name, repository_name
+ )
+ except DeleteException as ex:
+ raise request_error(exception=ex)
+
+ log_action(
+ "delete_repo_permission",
+ namespace_name,
+ {"team": teamname, "repo": repository_name},
+ repo_name=repository_name,
+ )
+
+ return "", 204
diff --git a/endpoints/api/permission_models_interface.py b/endpoints/api/permission_models_interface.py
index 49c24744c..acb3e9e6e 100644
--- a/endpoints/api/permission_models_interface.py
+++ b/endpoints/api/permission_models_interface.py
@@ -6,81 +6,70 @@ from six import add_metaclass
class SaveException(Exception):
- def __init__(self, other):
- self.traceback = sys.exc_info()
- super(SaveException, self).__init__(str(other))
+ def __init__(self, other):
+ self.traceback = sys.exc_info()
+ super(SaveException, self).__init__(str(other))
+
class DeleteException(Exception):
- def __init__(self, other):
- self.traceback = sys.exc_info()
- super(DeleteException, self).__init__(str(other))
+ def __init__(self, other):
+ self.traceback = sys.exc_info()
+ super(DeleteException, self).__init__(str(other))
-class Role(namedtuple('Role', ['role_name'])):
- def to_dict(self):
- return {
- 'role': self.role_name,
- }
-
-class UserPermission(namedtuple('UserPermission', [
- 'role_name',
- 'username',
- 'is_robot',
- 'avatar',
- 'is_org_member',
- 'has_org',
- ])):
-
- def to_dict(self):
- perm_dict = {
- 'role': self.role_name,
- 'name': self.username,
- 'is_robot': self.is_robot,
- 'avatar': self.avatar,
- }
- if self.has_org:
- perm_dict['is_org_member'] = self.is_org_member
- return perm_dict
-
-
-class RobotPermission(namedtuple('RobotPermission', [
- 'role_name',
- 'username',
- 'is_robot',
- 'is_org_member',
-])):
-
- def to_dict(self, user=None, team=None, org_members=None):
- return {
- 'role': self.role_name,
- 'name': self.username,
- 'is_robot': True,
- 'is_org_member': self.is_org_member,
- }
+class Role(namedtuple("Role", ["role_name"])):
+ def to_dict(self):
+ return {"role": self.role_name}
-class TeamPermission(namedtuple('TeamPermission', [
- 'role_name',
- 'team_name',
- 'avatar',
-])):
+class UserPermission(
+ namedtuple(
+ "UserPermission",
+ ["role_name", "username", "is_robot", "avatar", "is_org_member", "has_org"],
+ )
+):
+ def to_dict(self):
+ perm_dict = {
+ "role": self.role_name,
+ "name": self.username,
+ "is_robot": self.is_robot,
+ "avatar": self.avatar,
+ }
+ if self.has_org:
+ perm_dict["is_org_member"] = self.is_org_member
+ return perm_dict
+
+
+class RobotPermission(
+ namedtuple(
+ "RobotPermission", ["role_name", "username", "is_robot", "is_org_member"]
+ )
+):
+ def to_dict(self, user=None, team=None, org_members=None):
+ return {
+ "role": self.role_name,
+ "name": self.username,
+ "is_robot": True,
+ "is_org_member": self.is_org_member,
+ }
+
+
+class TeamPermission(
+ namedtuple("TeamPermission", ["role_name", "team_name", "avatar"])
+):
+ def to_dict(self):
+ return {"role": self.role_name, "name": self.team_name, "avatar": self.avatar}
- def to_dict(self):
- return {
- 'role': self.role_name,
- 'name': self.team_name,
- 'avatar': self.avatar,
- }
@add_metaclass(ABCMeta)
class PermissionDataInterface(object):
- """
+ """
Data interface used by permissions API
"""
-
- @abstractmethod
- def get_repo_permissions_by_user(self, namespace_name, repository_name):
- """
+
+ @abstractmethod
+ def get_repo_permissions_by_user(self, namespace_name, repository_name):
+ """
Args:
namespace_name: string
@@ -90,9 +79,9 @@ class PermissionDataInterface(object):
list(UserPermission)
"""
- @abstractmethod
- def get_repo_roles(self, username, namespace_name, repository_name):
- """
+ @abstractmethod
+ def get_repo_roles(self, username, namespace_name, repository_name):
+ """
Args:
username: string
@@ -101,11 +90,11 @@ class PermissionDataInterface(object):
Returns:
list(Role) or None
- """
-
- @abstractmethod
- def get_repo_permission_for_user(self, username, namespace_name, repository_name):
"""
+
+ @abstractmethod
+ def get_repo_permission_for_user(self, username, namespace_name, repository_name):
+ """
Args:
username: string
@@ -115,10 +104,12 @@ class PermissionDataInterface(object):
Returns:
UserPermission
"""
-
- @abstractmethod
- def set_repo_permission_for_user(self, username, namespace_name, repository_name, role_name):
- """
+
+ @abstractmethod
+ def set_repo_permission_for_user(
+ self, username, namespace_name, repository_name, role_name
+ ):
+ """
Args:
username: string
@@ -133,9 +124,11 @@ class PermissionDataInterface(object):
SaveException
"""
- @abstractmethod
- def delete_repo_permission_for_user(self, username, namespace_name, repository_name):
- """
+ @abstractmethod
+ def delete_repo_permission_for_user(
+ self, username, namespace_name, repository_name
+ ):
+ """
Args:
username: string
@@ -149,9 +142,9 @@ class PermissionDataInterface(object):
DeleteException
"""
- @abstractmethod
- def get_repo_permissions_by_team(self, namespace_name, repository_name):
- """
+ @abstractmethod
+ def get_repo_permissions_by_team(self, namespace_name, repository_name):
+ """
Args:
namespace_name: string
@@ -161,9 +154,9 @@ class PermissionDataInterface(object):
list(TeamPermission)
"""
- @abstractmethod
- def get_repo_role_for_team(self, team_name, namespace_name, repository_name):
- """
+ @abstractmethod
+ def get_repo_role_for_team(self, team_name, namespace_name, repository_name):
+ """
Args:
team_name: string
@@ -174,9 +167,11 @@ class PermissionDataInterface(object):
Role
"""
- @abstractmethod
- def set_repo_permission_for_team(self, team_name, namespace_name, repository_name, permission):
- """
+ @abstractmethod
+ def set_repo_permission_for_team(
+ self, team_name, namespace_name, repository_name, permission
+ ):
+ """
Args:
team_name: string
@@ -191,9 +186,11 @@ class PermissionDataInterface(object):
SaveException
"""
- @abstractmethod
- def delete_repo_permission_for_team(self, team_name, namespace_name, repository_name):
- """
+ @abstractmethod
+ def delete_repo_permission_for_team(
+ self, team_name, namespace_name, repository_name
+ ):
+ """
Args:
team_name: string
@@ -205,4 +202,4 @@ class PermissionDataInterface(object):
Raises:
DeleteException
- """
\ No newline at end of file
+ """
diff --git a/endpoints/api/permission_models_pre_oci.py b/endpoints/api/permission_models_pre_oci.py
index 1f19cad10..63f57d5f6 100644
--- a/endpoints/api/permission_models_pre_oci.py
+++ b/endpoints/api/permission_models_pre_oci.py
@@ -1,115 +1,168 @@
from app import avatar
from data import model
-from permission_models_interface import PermissionDataInterface, UserPermission, TeamPermission, Role, SaveException, DeleteException
+from permission_models_interface import (
+ PermissionDataInterface,
+ UserPermission,
+ TeamPermission,
+ Role,
+ SaveException,
+ DeleteException,
+)
class PreOCIModel(PermissionDataInterface):
- """
+ """
PreOCIModel implements the data model for Permission using a database schema
before it was changed to support the OCI specification.
"""
- def get_repo_permissions_by_user(self, namespace_name, repository_name):
- org = None
- try:
- org = model.organization.get_organization(namespace_name) # Will raise an error if not org
- except model.InvalidOrganizationException:
- # This repository isn't under an org
- pass
+ def get_repo_permissions_by_user(self, namespace_name, repository_name):
+ org = None
+ try:
+ org = model.organization.get_organization(
+ namespace_name
+ ) # Will raise an error if not org
+ except model.InvalidOrganizationException:
+ # This repository isn't under an org
+ pass
- # Load the permissions.
- repo_perms = model.user.get_all_repo_users(namespace_name, repository_name)
-
- if org:
- users_filter = {perm.user for perm in repo_perms}
- org_members = model.organization.get_organization_member_set(org, users_filter=users_filter)
-
- def is_org_member(user):
- if not org:
- return False
+ # Load the permissions.
+ repo_perms = model.user.get_all_repo_users(namespace_name, repository_name)
- return user.robot or user.username in org_members
-
- return [self._user_permission(perm, org is not None, is_org_member(perm.user)) for perm in repo_perms]
-
- def get_repo_roles(self, username, namespace_name, repository_name):
- user = model.user.get_user(username)
- if not user:
- return None
+ if org:
+ users_filter = {perm.user for perm in repo_perms}
+ org_members = model.organization.get_organization_member_set(
+ org, users_filter=users_filter
+ )
- repo = model.repository.get_repository(namespace_name, repository_name)
- if not repo:
- return None
-
- return [self._role(r) for r in model.permission.get_user_repo_permissions(user, repo)]
+ def is_org_member(user):
+ if not org:
+ return False
- def get_repo_permission_for_user(self, username, namespace_name, repository_name):
- perm = model.permission.get_user_reponame_permission(username, namespace_name, repository_name)
- org = None
- try:
- org = model.organization.get_organization(namespace_name)
- org_members = model.organization.get_organization_member_set(org, users_filter={perm.user})
- is_org_member = perm.user.robot or perm.user.username in org_members
- except model.InvalidOrganizationException:
- # This repository is not part of an organization
- is_org_member = False
-
- return self._user_permission(perm, org is not None, is_org_member)
+ return user.robot or user.username in org_members
- def set_repo_permission_for_user(self, username, namespace_name, repository_name, role_name):
- try:
- perm = model.permission.set_user_repo_permission(username, namespace_name, repository_name, role_name)
- org = None
- try:
- org = model.organization.get_organization(namespace_name)
- org_members = model.organization.get_organization_member_set(org, users_filter={perm.user})
- is_org_member = perm.user.robot or perm.user.username in org_members
- except model.InvalidOrganizationException:
- # This repository is not part of an organization
- is_org_member = False
- return self._user_permission(perm, org is not None, is_org_member)
- except model.DataModelException as ex:
- raise SaveException(ex)
+ return [
+ self._user_permission(perm, org is not None, is_org_member(perm.user))
+ for perm in repo_perms
+ ]
- def delete_repo_permission_for_user(self, username, namespace_name, repository_name):
- try:
- model.permission.delete_user_permission(username, namespace_name, repository_name)
- except model.DataModelException as ex:
- raise DeleteException(ex)
+ def get_repo_roles(self, username, namespace_name, repository_name):
+ user = model.user.get_user(username)
+ if not user:
+ return None
- def get_repo_permissions_by_team(self, namespace_name, repository_name):
- repo_perms = model.permission.get_all_repo_teams(namespace_name, repository_name)
- return [self._team_permission(perm, perm.team.name) for perm in repo_perms]
+ repo = model.repository.get_repository(namespace_name, repository_name)
+ if not repo:
+ return None
- def get_repo_role_for_team(self, team_name, namespace_name, repository_name):
- return self._role(model.permission.get_team_reponame_permission(team_name, namespace_name, repository_name))
+ return [
+ self._role(r)
+ for r in model.permission.get_user_repo_permissions(user, repo)
+ ]
- def set_repo_permission_for_team(self, team_name, namespace_name, repository_name, role_name):
- try:
- return self._team_permission(model.permission.set_team_repo_permission(team_name, namespace_name, repository_name, role_name), team_name)
- except model.DataModelException as ex:
- raise SaveException(ex)
+ def get_repo_permission_for_user(self, username, namespace_name, repository_name):
+ perm = model.permission.get_user_reponame_permission(
+ username, namespace_name, repository_name
+ )
+ org = None
+ try:
+ org = model.organization.get_organization(namespace_name)
+ org_members = model.organization.get_organization_member_set(
+ org, users_filter={perm.user}
+ )
+ is_org_member = perm.user.robot or perm.user.username in org_members
+ except model.InvalidOrganizationException:
+ # This repository is not part of an organization
+ is_org_member = False
+
+ return self._user_permission(perm, org is not None, is_org_member)
+
+ def set_repo_permission_for_user(
+ self, username, namespace_name, repository_name, role_name
+ ):
+ try:
+ perm = model.permission.set_user_repo_permission(
+ username, namespace_name, repository_name, role_name
+ )
+ org = None
+ try:
+ org = model.organization.get_organization(namespace_name)
+ org_members = model.organization.get_organization_member_set(
+ org, users_filter={perm.user}
+ )
+ is_org_member = perm.user.robot or perm.user.username in org_members
+ except model.InvalidOrganizationException:
+ # This repository is not part of an organization
+ is_org_member = False
+ return self._user_permission(perm, org is not None, is_org_member)
+ except model.DataModelException as ex:
+ raise SaveException(ex)
+
+ def delete_repo_permission_for_user(
+ self, username, namespace_name, repository_name
+ ):
+ try:
+ model.permission.delete_user_permission(
+ username, namespace_name, repository_name
+ )
+ except model.DataModelException as ex:
+ raise DeleteException(ex)
+
+ def get_repo_permissions_by_team(self, namespace_name, repository_name):
+ repo_perms = model.permission.get_all_repo_teams(
+ namespace_name, repository_name
+ )
+ return [self._team_permission(perm, perm.team.name) for perm in repo_perms]
+
+ def get_repo_role_for_team(self, team_name, namespace_name, repository_name):
+ return self._role(
+ model.permission.get_team_reponame_permission(
+ team_name, namespace_name, repository_name
+ )
+ )
+
+ def set_repo_permission_for_team(
+ self, team_name, namespace_name, repository_name, role_name
+ ):
+ try:
+ return self._team_permission(
+ model.permission.set_team_repo_permission(
+ team_name, namespace_name, repository_name, role_name
+ ),
+ team_name,
+ )
+ except model.DataModelException as ex:
+ raise SaveException(ex)
+
+ def delete_repo_permission_for_team(
+ self, team_name, namespace_name, repository_name
+ ):
+ try:
+ model.permission.delete_team_permission(
+ team_name, namespace_name, repository_name
+ )
+ except model.DataModelException as ex:
+ raise DeleteException(ex)
+
+ def _role(self, permission_obj):
+ return Role(role_name=permission_obj.role.name)
+
+ def _user_permission(self, permission_obj, has_org, is_org_member):
+ return UserPermission(
+ role_name=permission_obj.role.name,
+ username=permission_obj.user.username,
+ is_robot=permission_obj.user.robot,
+ avatar=avatar.get_data_for_user(permission_obj.user),
+ is_org_member=is_org_member,
+ has_org=has_org,
+ )
+
+ def _team_permission(self, permission_obj, team_name):
+ return TeamPermission(
+ role_name=permission_obj.role.name,
+ team_name=permission_obj.team.name,
+ avatar=avatar.get_data_for_team(permission_obj.team),
+ )
- def delete_repo_permission_for_team(self, team_name, namespace_name, repository_name):
- try:
- model.permission.delete_team_permission(team_name, namespace_name, repository_name)
- except model.DataModelException as ex:
- raise DeleteException(ex)
-
- def _role(self, permission_obj):
- return Role(role_name=permission_obj.role.name)
-
- def _user_permission(self, permission_obj, has_org, is_org_member):
- return UserPermission(role_name=permission_obj.role.name,
- username=permission_obj.user.username,
- is_robot=permission_obj.user.robot,
- avatar=avatar.get_data_for_user(permission_obj.user),
- is_org_member=is_org_member,
- has_org=has_org)
-
- def _team_permission(self, permission_obj, team_name):
- return TeamPermission(role_name=permission_obj.role.name,
- team_name=permission_obj.team.name,
- avatar=avatar.get_data_for_team(permission_obj.team))
pre_oci_model = PreOCIModel()
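
For orientation only (not part of this patch): how the permission endpoints consume pre_oci_model. Repository, namespace and user names are hypothetical.

from endpoints.api.permission_models_interface import SaveException
from endpoints.api.permission_models_pre_oci import pre_oci_model

# Returns UserPermission namedtuples; the API layer calls to_dict() on each.
perms = pre_oci_model.get_repo_permissions_by_user("myorg", "myrepo")
print({p.username: p.to_dict() for p in perms})

# Data-layer failures surface as SaveException / DeleteException, which the
# endpoints translate into request_error() responses.
try:
    pre_oci_model.set_repo_permission_for_user("alice", "myorg", "myrepo", "write")
except SaveException as ex:
    print("could not save:", ex)
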
diff --git a/endpoints/api/prototype.py b/endpoints/api/prototype.py
index 2944aab60..512e0710c 100644
--- a/endpoints/api/prototype.py
+++ b/endpoints/api/prototype.py
@@ -2,8 +2,16 @@
from flask import request
-from endpoints.api import (resource, nickname, ApiResource, validate_json_request, request_error,
- log_action, path_param, require_scope)
+from endpoints.api import (
+ resource,
+ nickname,
+ ApiResource,
+ validate_json_request,
+ request_error,
+ log_action,
+ path_param,
+ require_scope,
+)
from endpoints.exception import Unauthorized, NotFound
from auth.permissions import AdministerOrganizationPermission
from auth.auth_context import get_authenticated_user
@@ -13,258 +21,270 @@ from app import avatar
def prototype_view(proto, org_members):
- def prototype_user_view(user):
+ def prototype_user_view(user):
+ return {
+ "name": user.username,
+ "is_robot": user.robot,
+ "kind": "user",
+ "is_org_member": user.robot or user.username in org_members,
+ "avatar": avatar.get_data_for_user(user),
+ }
+
+ if proto.delegate_user:
+ delegate_view = prototype_user_view(proto.delegate_user)
+ else:
+ delegate_view = {
+ "name": proto.delegate_team.name,
+ "kind": "team",
+ "avatar": avatar.get_data_for_team(proto.delegate_team),
+ }
+
return {
- 'name': user.username,
- 'is_robot': user.robot,
- 'kind': 'user',
- 'is_org_member': user.robot or user.username in org_members,
- 'avatar': avatar.get_data_for_user(user)
+ "activating_user": (
+ prototype_user_view(proto.activating_user)
+ if proto.activating_user
+ else None
+ ),
+ "delegate": delegate_view,
+ "role": proto.role.name,
+ "id": proto.uuid,
}
- if proto.delegate_user:
- delegate_view = prototype_user_view(proto.delegate_user)
- else:
- delegate_view = {
- 'name': proto.delegate_team.name,
- 'kind': 'team',
- 'avatar': avatar.get_data_for_team(proto.delegate_team)
- }
-
- return {
- 'activating_user': (prototype_user_view(proto.activating_user)
- if proto.activating_user else None),
- 'delegate': delegate_view,
- 'role': proto.role.name,
- 'id': proto.uuid,
- }
def log_prototype_action(action_kind, orgname, prototype, **kwargs):
- username = get_authenticated_user().username
- log_params = {
- 'prototypeid': prototype.uuid,
- 'username': username,
- 'activating_username': (prototype.activating_user.username
- if prototype.activating_user else None),
- 'role': prototype.role.name
- }
+ username = get_authenticated_user().username
+ log_params = {
+ "prototypeid": prototype.uuid,
+ "username": username,
+ "activating_username": (
+ prototype.activating_user.username if prototype.activating_user else None
+ ),
+ "role": prototype.role.name,
+ }
- for key, value in kwargs.items():
- log_params[key] = value
+ for key, value in kwargs.items():
+ log_params[key] = value
- if prototype.delegate_user:
- log_params['delegate_user'] = prototype.delegate_user.username
- elif prototype.delegate_team:
- log_params['delegate_team'] = prototype.delegate_team.name
+ if prototype.delegate_user:
+ log_params["delegate_user"] = prototype.delegate_user.username
+ elif prototype.delegate_team:
+ log_params["delegate_team"] = prototype.delegate_team.name
- log_action(action_kind, orgname, log_params)
+ log_action(action_kind, orgname, log_params)
-@resource('/v1/organization/<orgname>/prototypes')
-@path_param('orgname', 'The name of the organization')
+@resource("/v1/organization//prototypes")
+@path_param("orgname", "The name of the organization")
class PermissionPrototypeList(ApiResource):
- """ Resource for listing and creating permission prototypes. """
- schemas = {
- 'NewPrototype': {
- 'type': 'object',
- 'description': 'Description of a new prototype',
- 'required': [
- 'role',
- 'delegate',
- ],
- 'properties': {
- 'role': {
- 'type': 'string',
- 'description': 'Role that should be applied to the delegate',
- 'enum': [
- 'read',
- 'write',
- 'admin',
- ],
- },
- 'activating_user': {
- 'type': 'object',
- 'description': 'Repository creating user to whom the rule should apply',
- 'required': [
- 'name',
- ],
- 'properties': {
- 'name': {
- 'type': 'string',
- 'description': 'The username for the activating_user',
+ """ Resource for listing and creating permission prototypes. """
+
+ schemas = {
+ "NewPrototype": {
+ "type": "object",
+ "description": "Description of a new prototype",
+ "required": ["role", "delegate"],
+ "properties": {
+ "role": {
+ "type": "string",
+ "description": "Role that should be applied to the delegate",
+ "enum": ["read", "write", "admin"],
+ },
+ "activating_user": {
+ "type": "object",
+ "description": "Repository creating user to whom the rule should apply",
+ "required": ["name"],
+ "properties": {
+ "name": {
+ "type": "string",
+ "description": "The username for the activating_user",
+ }
+ },
+ },
+ "delegate": {
+ "type": "object",
+ "description": "Information about the user or team to which the rule grants access",
+ "required": ["name", "kind"],
+ "properties": {
+ "name": {
+ "type": "string",
+ "description": "The name for the delegate team or user",
+ },
+ "kind": {
+ "type": "string",
+ "description": "Whether the delegate is a user or a team",
+ "enum": ["user", "team"],
+ },
+ },
+ },
},
- },
- },
- 'delegate': {
- 'type': 'object',
- 'description': 'Information about the user or team to which the rule grants access',
- 'required': [
- 'name',
- 'kind',
- ],
- 'properties': {
- 'name': {
- 'type': 'string',
- 'description': 'The name for the delegate team or user',
- },
- 'kind': {
- 'type': 'string',
- 'description': 'Whether the delegate is a user or a team',
- 'enum': [
- 'user',
- 'team',
- ],
- },
- },
- },
- },
- },
- }
+ }
+ }
- @require_scope(scopes.ORG_ADMIN)
- @nickname('getOrganizationPrototypePermissions')
- def get(self, orgname):
- """ List the existing prototypes for this organization. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- try:
- org = model.organization.get_organization(orgname)
- except model.InvalidOrganizationException:
- raise NotFound()
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("getOrganizationPrototypePermissions")
+ def get(self, orgname):
+ """ List the existing prototypes for this organization. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ try:
+ org = model.organization.get_organization(orgname)
+ except model.InvalidOrganizationException:
+ raise NotFound()
- permissions = model.permission.get_prototype_permissions(org)
+ permissions = model.permission.get_prototype_permissions(org)
- users_filter = ({p.activating_user for p in permissions} |
- {p.delegate_user for p in permissions})
- org_members = model.organization.get_organization_member_set(org, users_filter=users_filter)
- return {'prototypes': [prototype_view(p, org_members) for p in permissions]}
+ users_filter = {p.activating_user for p in permissions} | {
+ p.delegate_user for p in permissions
+ }
+ org_members = model.organization.get_organization_member_set(
+ org, users_filter=users_filter
+ )
+ return {"prototypes": [prototype_view(p, org_members) for p in permissions]}
- raise Unauthorized()
+ raise Unauthorized()
- @require_scope(scopes.ORG_ADMIN)
- @nickname('createOrganizationPrototypePermission')
- @validate_json_request('NewPrototype')
- def post(self, orgname):
- """ Create a new permission prototype. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- try:
- org = model.organization.get_organization(orgname)
- except model.InvalidOrganizationException:
- raise NotFound()
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("createOrganizationPrototypePermission")
+ @validate_json_request("NewPrototype")
+ def post(self, orgname):
+ """ Create a new permission prototype. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ try:
+ org = model.organization.get_organization(orgname)
+ except model.InvalidOrganizationException:
+ raise NotFound()
- details = request.get_json()
- activating_username = None
+ details = request.get_json()
+ activating_username = None
- if ('activating_user' in details and details['activating_user'] and
- 'name' in details['activating_user']):
- activating_username = details['activating_user']['name']
+ if (
+ "activating_user" in details
+ and details["activating_user"]
+ and "name" in details["activating_user"]
+ ):
+ activating_username = details["activating_user"]["name"]
- delegate = details['delegate'] if 'delegate' in details else {}
- delegate_kind = delegate.get('kind', None)
- delegate_name = delegate.get('name', None)
+ delegate = details["delegate"] if "delegate" in details else {}
+ delegate_kind = delegate.get("kind", None)
+ delegate_name = delegate.get("name", None)
- delegate_username = delegate_name if delegate_kind == 'user' else None
- delegate_teamname = delegate_name if delegate_kind == 'team' else None
+ delegate_username = delegate_name if delegate_kind == "user" else None
+ delegate_teamname = delegate_name if delegate_kind == "team" else None
- activating_user = (model.user.get_user(activating_username) if activating_username else None)
- delegate_user = (model.user.get_user(delegate_username) if delegate_username else None)
- delegate_team = (model.team.get_organization_team(orgname, delegate_teamname)
- if delegate_teamname else None)
+ activating_user = (
+ model.user.get_user(activating_username)
+ if activating_username
+ else None
+ )
+ delegate_user = (
+ model.user.get_user(delegate_username) if delegate_username else None
+ )
+ delegate_team = (
+ model.team.get_organization_team(orgname, delegate_teamname)
+ if delegate_teamname
+ else None
+ )
- if activating_username and not activating_user:
- raise request_error(message='Unknown activating user')
+ if activating_username and not activating_user:
+ raise request_error(message="Unknown activating user")
- if not delegate_user and not delegate_team:
- raise request_error(message='Missing delegate user or team')
+ if not delegate_user and not delegate_team:
+ raise request_error(message="Missing delegate user or team")
- role_name = details['role']
+ role_name = details["role"]
- prototype = model.permission.add_prototype_permission(org, role_name, activating_user,
- delegate_user, delegate_team)
- log_prototype_action('create_prototype_permission', orgname, prototype)
+ prototype = model.permission.add_prototype_permission(
+ org, role_name, activating_user, delegate_user, delegate_team
+ )
+ log_prototype_action("create_prototype_permission", orgname, prototype)
- users_filter = {prototype.activating_user, prototype.delegate_user}
- org_members = model.organization.get_organization_member_set(org, users_filter=users_filter)
- return prototype_view(prototype, org_members)
+ users_filter = {prototype.activating_user, prototype.delegate_user}
+ org_members = model.organization.get_organization_member_set(
+ org, users_filter=users_filter
+ )
+ return prototype_view(prototype, org_members)
- raise Unauthorized()
+ raise Unauthorized()
-@resource('/v1/organization/<orgname>/prototypes/<prototypeid>')
-@path_param('orgname', 'The name of the organization')
-@path_param('prototypeid', 'The ID of the prototype')
+@resource("/v1/organization//prototypes/")
+@path_param("orgname", "The name of the organization")
+@path_param("prototypeid", "The ID of the prototype")
class PermissionPrototype(ApiResource):
- """ Resource for managingin individual permission prototypes. """
- schemas = {
- 'PrototypeUpdate': {
- 'type': 'object',
- 'description': 'Description of a the new prototype role',
- 'required': [
- 'role',
- ],
- 'properties': {
- 'role': {
- 'type': 'string',
- 'description': 'Role that should be applied to the permission',
- 'enum': [
- 'read',
- 'write',
- 'admin',
- ],
- },
- },
- },
- }
+ """ Resource for managingin individual permission prototypes. """
- @require_scope(scopes.ORG_ADMIN)
- @nickname('deleteOrganizationPrototypePermission')
- def delete(self, orgname, prototypeid):
- """ Delete an existing permission prototype. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- try:
- org = model.organization.get_organization(orgname)
- except model.InvalidOrganizationException:
- raise NotFound()
+ schemas = {
+ "PrototypeUpdate": {
+ "type": "object",
+ "description": "Description of a the new prototype role",
+ "required": ["role"],
+ "properties": {
+ "role": {
+ "type": "string",
+ "description": "Role that should be applied to the permission",
+ "enum": ["read", "write", "admin"],
+ }
+ },
+ }
+ }
- prototype = model.permission.delete_prototype_permission(org, prototypeid)
- if not prototype:
- raise NotFound()
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("deleteOrganizationPrototypePermission")
+ def delete(self, orgname, prototypeid):
+ """ Delete an existing permission prototype. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ try:
+ org = model.organization.get_organization(orgname)
+ except model.InvalidOrganizationException:
+ raise NotFound()
- log_prototype_action('delete_prototype_permission', orgname, prototype)
+ prototype = model.permission.delete_prototype_permission(org, prototypeid)
+ if not prototype:
+ raise NotFound()
- return '', 204
+ log_prototype_action("delete_prototype_permission", orgname, prototype)
- raise Unauthorized()
+ return "", 204
- @require_scope(scopes.ORG_ADMIN)
- @nickname('updateOrganizationPrototypePermission')
- @validate_json_request('PrototypeUpdate')
- def put(self, orgname, prototypeid):
- """ Update the role of an existing permission prototype. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- try:
- org = model.organization.get_organization(orgname)
- except model.InvalidOrganizationException:
- raise NotFound()
+ raise Unauthorized()
- existing = model.permission.get_prototype_permission(org, prototypeid)
- if not existing:
- raise NotFound()
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("updateOrganizationPrototypePermission")
+ @validate_json_request("PrototypeUpdate")
+ def put(self, orgname, prototypeid):
+ """ Update the role of an existing permission prototype. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ try:
+ org = model.organization.get_organization(orgname)
+ except model.InvalidOrganizationException:
+ raise NotFound()
- details = request.get_json()
- role_name = details['role']
- prototype = model.permission.update_prototype_permission(org, prototypeid, role_name)
- if not prototype:
- raise NotFound()
+ existing = model.permission.get_prototype_permission(org, prototypeid)
+ if not existing:
+ raise NotFound()
- log_prototype_action('modify_prototype_permission', orgname, prototype,
- original_role=existing.role.name)
+ details = request.get_json()
+ role_name = details["role"]
+ prototype = model.permission.update_prototype_permission(
+ org, prototypeid, role_name
+ )
+ if not prototype:
+ raise NotFound()
- users_filter = {prototype.activating_user, prototype.delegate_user}
- org_members = model.organization.get_organization_member_set(org, users_filter=users_filter)
- return prototype_view(prototype, org_members)
+ log_prototype_action(
+ "modify_prototype_permission",
+ orgname,
+ prototype,
+ original_role=existing.role.name,
+ )
- raise Unauthorized()
+ users_filter = {prototype.activating_user, prototype.delegate_user}
+ org_members = model.organization.get_organization_member_set(
+ org, users_filter=users_filter
+ )
+ return prototype_view(prototype, org_members)
+
+ raise Unauthorized()
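
For orientation only (not part of this patch): a payload that satisfies the "NewPrototype" schema above. Host, token, organization and names are hypothetical.

import requests

HEADERS = {"Authorization": "Bearer <token>"}  # assumed org:admin token
new_prototype = {
    "role": "write",                                   # read | write | admin
    "activating_user": {"name": "alice"},              # optional
    "delegate": {"kind": "team", "name": "builders"},  # user or team
}
resp = requests.post(
    "https://quay.example.com/api/v1/organization/myorg/prototypes",
    headers=HEADERS,
    json=new_prototype,
)
# prototype_view(): {"activating_user": ..., "delegate": ..., "role": ..., "id": ...}
print(resp.json())
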
diff --git a/endpoints/api/repoemail.py b/endpoints/api/repoemail.py
index 3edccb4cc..b41f7b30c 100644
--- a/endpoints/api/repoemail.py
+++ b/endpoints/api/repoemail.py
@@ -4,8 +4,17 @@ import logging
from flask import request, abort
-from endpoints.api import (resource, nickname, require_repo_admin, RepositoryParamResource,
- log_action, validate_json_request, internal_only, path_param, show_if)
+from endpoints.api import (
+ resource,
+ nickname,
+ require_repo_admin,
+ RepositoryParamResource,
+ log_action,
+ validate_json_request,
+ internal_only,
+ path_param,
+ show_if,
+)
from endpoints.api.repoemail_models_pre_oci import pre_oci_model as model
from endpoints.exception import NotFound
from app import tf
@@ -18,35 +27,37 @@ logger = logging.getLogger(__name__)
@internal_only
-@resource('/v1/repository/<apirepopath:repository>/authorizedemail/<email>')
+@resource("/v1/repository//authorizedemail/")
@show_if(features.MAILING)
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
-@path_param('email', 'The e-mail address')
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
+@path_param("email", "The e-mail address")
class RepositoryAuthorizedEmail(RepositoryParamResource):
- """ Resource for checking and authorizing e-mail addresses to receive repo notifications. """
+ """ Resource for checking and authorizing e-mail addresses to receive repo notifications. """
- @require_repo_admin
- @nickname('checkRepoEmailAuthorized')
- def get(self, namespace, repository, email):
- """ Checks to see if the given e-mail address is authorized on this repository. """
- record = model.get_email_authorized_for_repo(namespace, repository, email)
- if not record:
- abort(404)
+ @require_repo_admin
+ @nickname("checkRepoEmailAuthorized")
+ def get(self, namespace, repository, email):
+ """ Checks to see if the given e-mail address is authorized on this repository. """
+ record = model.get_email_authorized_for_repo(namespace, repository, email)
+ if not record:
+ abort(404)
- return record.to_dict()
-
- @require_repo_admin
- @nickname('sendAuthorizeRepoEmail')
- def post(self, namespace, repository, email):
- """ Starts the authorization process for an e-mail address on a repository. """
-
- with tf(db):
- record = model.get_email_authorized_for_repo(namespace, repository, email)
- if record and record.confirmed:
return record.to_dict()
- if not record:
- record = model.create_email_authorization_for_repo(namespace, repository, email)
+ @require_repo_admin
+ @nickname("sendAuthorizeRepoEmail")
+ def post(self, namespace, repository, email):
+ """ Starts the authorization process for an e-mail address on a repository. """
- send_repo_authorization_email(namespace, repository, email, record.code)
- return record.to_dict()
+ with tf(db):
+ record = model.get_email_authorized_for_repo(namespace, repository, email)
+ if record and record.confirmed:
+ return record.to_dict()
+
+ if not record:
+ record = model.create_email_authorization_for_repo(
+ namespace, repository, email
+ )
+
+ send_repo_authorization_email(namespace, repository, email, record.code)
+ return record.to_dict()
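
For orientation only (not part of this patch): the two-step authorize-e-mail flow implemented above, POST to start (or re-send) authorization and GET to check it. Host, token and names are hypothetical.

import requests

HEADERS = {"Authorization": "Bearer <token>"}  # assumed repo-admin token
URL = ("https://quay.example.com/api/v1"
       "/repository/myorg/myrepo/authorizedemail/dev@example.com")

# Start (or re-send) the authorization e-mail; returns the record as a dict.
record = requests.post(URL, headers=HEADERS).json()
print(record["code"], record["confirmed"])

# Later: check whether the address has been confirmed (404 if never requested).
status = requests.get(URL, headers=HEADERS)
if status.status_code == 200:
    print("confirmed:", status.json()["confirmed"])
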
diff --git a/endpoints/api/repoemail_models_interface.py b/endpoints/api/repoemail_models_interface.py
index 2aae7ab9c..62dcc8a14 100644
--- a/endpoints/api/repoemail_models_interface.py
+++ b/endpoints/api/repoemail_models_interface.py
@@ -5,14 +5,12 @@ from six import add_metaclass
class RepositoryAuthorizedEmail(
- namedtuple('RepositoryAuthorizedEmail', [
- 'email',
- 'repository_name',
- 'namespace_name',
- 'confirmed',
- 'code',
- ])):
- """
+ namedtuple(
+ "RepositoryAuthorizedEmail",
+ ["email", "repository_name", "namespace_name", "confirmed", "code"],
+ )
+):
+ """
RepositoryAuthorizedEmail represents an e-mail address authorized to receive notifications for a repository.
:type email: string
:type repository_name: string
@@ -21,30 +19,32 @@ class RepositoryAuthorizedEmail(
:type code: string
"""
- def to_dict(self):
- return {
- 'email': self.email,
- 'repository': self.repository_name,
- 'namespace': self.namespace_name,
- 'confirmed': self.confirmed,
- 'code': self.code
- }
+ def to_dict(self):
+ return {
+ "email": self.email,
+ "repository": self.repository_name,
+ "namespace": self.namespace_name,
+ "confirmed": self.confirmed,
+ "code": self.code,
+ }
@add_metaclass(ABCMeta)
class RepoEmailDataInterface(object):
- """
+ """
Interface that represents all data store interactions required by a Repo Email.
"""
- @abstractmethod
- def get_email_authorized_for_repo(self, namespace_name, repository_name, email):
- """
+ @abstractmethod
+ def get_email_authorized_for_repo(self, namespace_name, repository_name, email):
+ """
Returns a RepositoryAuthorizedEmail if available else None
"""
- @abstractmethod
- def create_email_authorization_for_repo(self, namespace_name, repository_name, email):
- """
+ @abstractmethod
+ def create_email_authorization_for_repo(
+ self, namespace_name, repository_name, email
+ ):
+ """
Returns the newly created repository authorized email.
"""
diff --git a/endpoints/api/repoemail_models_pre_oci.py b/endpoints/api/repoemail_models_pre_oci.py
index 80a65c995..5d091a9f1 100644
--- a/endpoints/api/repoemail_models_pre_oci.py
+++ b/endpoints/api/repoemail_models_pre_oci.py
@@ -1,28 +1,42 @@
from data import model
-from endpoints.api.repoemail_models_interface import RepoEmailDataInterface, RepositoryAuthorizedEmail
+from endpoints.api.repoemail_models_interface import (
+ RepoEmailDataInterface,
+ RepositoryAuthorizedEmail,
+)
def _return_none_or_data(func, namespace_name, repository_name, email):
- data = func(namespace_name, repository_name, email)
- if data is None:
- return data
- return RepositoryAuthorizedEmail(email, repository_name, namespace_name, data.confirmed,
- data.code)
+ data = func(namespace_name, repository_name, email)
+ if data is None:
+ return data
+ return RepositoryAuthorizedEmail(
+ email, repository_name, namespace_name, data.confirmed, data.code
+ )
class PreOCIModel(RepoEmailDataInterface):
- """
+ """
PreOCIModel implements the data model for the Repo Email using a database schema
before it was changed to support the OCI specification.
"""
- def get_email_authorized_for_repo(self, namespace_name, repository_name, email):
- return _return_none_or_data(model.repository.get_email_authorized_for_repo, namespace_name,
- repository_name, email)
+ def get_email_authorized_for_repo(self, namespace_name, repository_name, email):
+ return _return_none_or_data(
+ model.repository.get_email_authorized_for_repo,
+ namespace_name,
+ repository_name,
+ email,
+ )
- def create_email_authorization_for_repo(self, namespace_name, repository_name, email):
- return _return_none_or_data(model.repository.create_email_authorization_for_repo,
- namespace_name, repository_name, email)
+ def create_email_authorization_for_repo(
+ self, namespace_name, repository_name, email
+ ):
+ return _return_none_or_data(
+ model.repository.create_email_authorization_for_repo,
+ namespace_name,
+ repository_name,
+ email,
+ )
pre_oci_model = PreOCIModel()
diff --git a/endpoints/api/repository.py b/endpoints/api/repository.py
index d117f238d..5b9b0aaee 100644
--- a/endpoints/api/repository.py
+++ b/endpoints/api/repository.py
@@ -12,17 +12,42 @@ from flask import request, abort
from app import dockerfile_build_queue, tuf_metadata_api
from data.database import RepositoryState
from endpoints.api import (
- format_date, nickname, log_action, validate_json_request, require_repo_read, require_repo_write,
- require_repo_admin, RepositoryParamResource, resource, parse_args, ApiResource, request_error,
- require_scope, path_param, page_support, query_param, truthy_bool, show_if)
+ format_date,
+ nickname,
+ log_action,
+ validate_json_request,
+ require_repo_read,
+ require_repo_write,
+ require_repo_admin,
+ RepositoryParamResource,
+ resource,
+ parse_args,
+ ApiResource,
+ request_error,
+ require_scope,
+ path_param,
+ page_support,
+ query_param,
+ truthy_bool,
+ show_if,
+)
from endpoints.api.repository_models_pre_oci import pre_oci_model as model
from endpoints.exception import (
- Unauthorized, NotFound, InvalidRequest, ExceedsLicenseException, DownstreamIssue)
+ Unauthorized,
+ NotFound,
+ InvalidRequest,
+ ExceedsLicenseException,
+ DownstreamIssue,
+)
from endpoints.api.billing import lookup_allowed_private_repos, get_namespace_plan
from endpoints.api.subscribe import check_repository_usage
-from auth.permissions import (ModifyRepositoryPermission, AdministerRepositoryPermission,
- CreateRepositoryPermission, ReadRepositoryPermission)
+from auth.permissions import (
+ ModifyRepositoryPermission,
+ AdministerRepositoryPermission,
+ CreateRepositoryPermission,
+ ReadRepositoryPermission,
+)
from auth.auth_context import get_authenticated_user
from auth import scopes
from util.names import REPOSITORY_NAME_REGEX
@@ -34,371 +59,443 @@ MAX_DAYS_IN_3_MONTHS = 92
def check_allowed_private_repos(namespace):
- """ Checks to see if the given namespace has reached its private repository limit. If so,
+ """ Checks to see if the given namespace has reached its private repository limit. If so,
    raises an ExceedsLicenseException.
"""
- # Not enabled if billing is disabled.
- if not features.BILLING:
- return
+ # Not enabled if billing is disabled.
+ if not features.BILLING:
+ return
- if not lookup_allowed_private_repos(namespace):
- raise ExceedsLicenseException()
+ if not lookup_allowed_private_repos(namespace):
+ raise ExceedsLicenseException()
-@resource('/v1/repository')
+@resource("/v1/repository")
class RepositoryList(ApiResource):
- """Operations for creating and listing repositories."""
- schemas = {
- 'NewRepo': {
- 'type': 'object',
- 'description': 'Description of a new repository',
- 'required': [
- 'repository',
- 'visibility',
- 'description',
- ],
- 'properties': {
- 'repository': {
- 'type': 'string',
- 'description': 'Repository name',
- },
- 'visibility': {
- 'type': 'string',
- 'description': 'Visibility which the repository will start with',
- 'enum': [
- 'public',
- 'private',
- ],
- },
- 'namespace': {
- 'type':
- 'string',
- 'description': ('Namespace in which the repository should be created. If omitted, the '
- 'username of the caller is used'),
- },
- 'description': {
- 'type': 'string',
- 'description': 'Markdown encoded description for the repository',
- },
- 'repo_kind': {
- 'type': ['string', 'null'],
- 'description': 'The kind of repository',
- 'enum': ['image', 'application', None],
+ """Operations for creating and listing repositories."""
+
+ schemas = {
+ "NewRepo": {
+ "type": "object",
+ "description": "Description of a new repository",
+ "required": ["repository", "visibility", "description"],
+ "properties": {
+ "repository": {"type": "string", "description": "Repository name"},
+ "visibility": {
+ "type": "string",
+ "description": "Visibility which the repository will start with",
+ "enum": ["public", "private"],
+ },
+ "namespace": {
+ "type": "string",
+ "description": (
+ "Namespace in which the repository should be created. If omitted, the "
+ "username of the caller is used"
+ ),
+ },
+ "description": {
+ "type": "string",
+ "description": "Markdown encoded description for the repository",
+ },
+ "repo_kind": {
+ "type": ["string", "null"],
+ "description": "The kind of repository",
+ "enum": ["image", "application", None],
+ },
+ },
}
- },
- },
- }
+ }
- @require_scope(scopes.CREATE_REPO)
- @nickname('createRepo')
- @validate_json_request('NewRepo')
- def post(self):
- """Create a new repository."""
- owner = get_authenticated_user()
- req = request.get_json()
+ @require_scope(scopes.CREATE_REPO)
+ @nickname("createRepo")
+ @validate_json_request("NewRepo")
+ def post(self):
+ """Create a new repository."""
+ owner = get_authenticated_user()
+ req = request.get_json()
- if owner is None and 'namespace' not in 'req':
- raise InvalidRequest('Must provide a namespace or must be logged in.')
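+ # An anonymous caller must supply an explicit namespace in the request body.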
+ if owner is None and "namespace" not in "req":
+ raise InvalidRequest("Must provide a namespace or must be logged in.")
- namespace_name = req['namespace'] if 'namespace' in req else owner.username
+ namespace_name = req["namespace"] if "namespace" in req else owner.username
- permission = CreateRepositoryPermission(namespace_name)
- if permission.can():
- repository_name = req['repository']
- visibility = req['visibility']
+ permission = CreateRepositoryPermission(namespace_name)
+ if permission.can():
+ repository_name = req["repository"]
+ visibility = req["visibility"]
- if model.repo_exists(namespace_name, repository_name):
- raise request_error(message='Repository already exists')
+ if model.repo_exists(namespace_name, repository_name):
+ raise request_error(message="Repository already exists")
- visibility = req['visibility']
- if visibility == 'private':
- check_allowed_private_repos(namespace_name)
+ visibility = req["visibility"]
+ if visibility == "private":
+ check_allowed_private_repos(namespace_name)
- # Verify that the repository name is valid.
- if not REPOSITORY_NAME_REGEX.match(repository_name):
- raise InvalidRequest('Invalid repository name')
+ # Verify that the repository name is valid.
+ if not REPOSITORY_NAME_REGEX.match(repository_name):
+ raise InvalidRequest("Invalid repository name")
- kind = req.get('repo_kind', 'image') or 'image'
- model.create_repo(namespace_name, repository_name, owner, req['description'],
- visibility=visibility, repo_kind=kind)
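+ # A missing or explicitly null repo_kind falls back to the default "image" kind.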
+ kind = req.get("repo_kind", "image") or "image"
+ model.create_repo(
+ namespace_name,
+ repository_name,
+ owner,
+ req["description"],
+ visibility=visibility,
+ repo_kind=kind,
+ )
- log_action('create_repo', namespace_name,
- {'repo': repository_name,
- 'namespace': namespace_name}, repo_name=repository_name)
- return {
- 'namespace': namespace_name,
- 'name': repository_name,
- 'kind': kind,
- }, 201
+ log_action(
+ "create_repo",
+ namespace_name,
+ {"repo": repository_name, "namespace": namespace_name},
+ repo_name=repository_name,
+ )
+ return (
+ {"namespace": namespace_name, "name": repository_name, "kind": kind},
+ 201,
+ )
- raise Unauthorized()
+ raise Unauthorized()
- @require_scope(scopes.READ_REPO)
- @nickname('listRepos')
- @parse_args()
- @query_param('namespace', 'Filters the repositories returned to this namespace', type=str)
- @query_param('starred', 'Filters the repositories returned to those starred by the user',
- type=truthy_bool, default=False)
- @query_param('public', 'Adds any repositories visible to the user by virtue of being public',
- type=truthy_bool, default=False)
- @query_param('last_modified', 'Whether to include when the repository was last modified.',
- type=truthy_bool, default=False)
- @query_param('popularity', 'Whether to include the repository\'s popularity metric.',
- type=truthy_bool, default=False)
- @query_param('repo_kind', 'The kind of repositories to return', type=str, default='image')
- @page_support()
- def get(self, page_token, parsed_args):
- """ Fetch the list of repositories visible to the current user under a variety of situations.
+ @require_scope(scopes.READ_REPO)
+ @nickname("listRepos")
+ @parse_args()
+ @query_param(
+ "namespace", "Filters the repositories returned to this namespace", type=str
+ )
+ @query_param(
+ "starred",
+ "Filters the repositories returned to those starred by the user",
+ type=truthy_bool,
+ default=False,
+ )
+ @query_param(
+ "public",
+ "Adds any repositories visible to the user by virtue of being public",
+ type=truthy_bool,
+ default=False,
+ )
+ @query_param(
+ "last_modified",
+ "Whether to include when the repository was last modified.",
+ type=truthy_bool,
+ default=False,
+ )
+ @query_param(
+ "popularity",
+ "Whether to include the repository's popularity metric.",
+ type=truthy_bool,
+ default=False,
+ )
+ @query_param(
+ "repo_kind", "The kind of repositories to return", type=str, default="image"
+ )
+ @page_support()
+ def get(self, page_token, parsed_args):
+ """ Fetch the list of repositories visible to the current user under a variety of situations.
"""
- # Ensure that the user requests either filtered by a namespace, only starred repositories,
- # or public repositories. This ensures that the user is not requesting *all* visible repos,
- # which can cause a surge in DB CPU usage.
- if not parsed_args['namespace'] and not parsed_args['starred'] and not parsed_args['public']:
- raise InvalidRequest('namespace, starred or public are required for this API call')
+ # Ensure that the user requests either filtered by a namespace, only starred repositories,
+ # or public repositories. This ensures that the user is not requesting *all* visible repos,
+ # which can cause a surge in DB CPU usage.
+ if (
+ not parsed_args["namespace"]
+ and not parsed_args["starred"]
+ and not parsed_args["public"]
+ ):
+ raise InvalidRequest(
+ "namespace, starred or public are required for this API call"
+ )
- user = get_authenticated_user()
- username = user.username if user else None
- last_modified = parsed_args['last_modified']
- popularity = parsed_args['popularity']
+ user = get_authenticated_user()
+ username = user.username if user else None
+ last_modified = parsed_args["last_modified"]
+ popularity = parsed_args["popularity"]
- if parsed_args['starred'] and not username:
- # No repositories should be returned, as there is no user.
- abort(400)
+ if parsed_args["starred"] and not username:
+ # No repositories should be returned, as there is no user.
+ abort(400)
- repos, next_page_token = model.get_repo_list(
- parsed_args['starred'], user, parsed_args['repo_kind'], parsed_args['namespace'], username,
- parsed_args['public'], page_token, last_modified, popularity)
+ repos, next_page_token = model.get_repo_list(
+ parsed_args["starred"],
+ user,
+ parsed_args["repo_kind"],
+ parsed_args["namespace"],
+ username,
+ parsed_args["public"],
+ page_token,
+ last_modified,
+ popularity,
+ )
- return {'repositories': [repo.to_dict() for repo in repos]}, next_page_token
+ return {"repositories": [repo.to_dict() for repo in repos]}, next_page_token
-@resource('/v1/repository/<apirepopath:repository>')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
+@resource("/v1/repository/")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
class Repository(RepositoryParamResource):
- """Operations for managing a specific repository."""
- schemas = {
- 'RepoUpdate': {
- 'type': 'object',
- 'description': 'Fields which can be updated in a repository.',
- 'required': ['description',],
- 'properties': {
- 'description': {
- 'type': 'string',
- 'description': 'Markdown encoded description for the repository',
- },
- }
+ """Operations for managing a specific repository."""
+
+ schemas = {
+ "RepoUpdate": {
+ "type": "object",
+ "description": "Fields which can be updated in a repository.",
+ "required": ["description"],
+ "properties": {
+ "description": {
+ "type": "string",
+ "description": "Markdown encoded description for the repository",
+ }
+ },
+ }
}
- }
- @parse_args()
- @query_param('includeStats', 'Whether to include action statistics', type=truthy_bool,
- default=False)
- @query_param('includeTags', 'Whether to include repository tags', type=truthy_bool,
- default=True)
- @require_repo_read
- @nickname('getRepo')
- def get(self, namespace, repository, parsed_args):
- """Fetch the specified repository."""
- logger.debug('Get repo: %s/%s' % (namespace, repository))
- include_tags = parsed_args['includeTags']
- max_tags = 500
- repo = model.get_repo(namespace, repository, get_authenticated_user(), include_tags, max_tags)
- if repo is None:
- raise NotFound()
+ @parse_args()
+ @query_param(
+ "includeStats",
+ "Whether to include action statistics",
+ type=truthy_bool,
+ default=False,
+ )
+ @query_param(
+ "includeTags",
+ "Whether to include repository tags",
+ type=truthy_bool,
+ default=True,
+ )
+ @require_repo_read
+ @nickname("getRepo")
+ def get(self, namespace, repository, parsed_args):
+ """Fetch the specified repository."""
+ logger.debug("Get repo: %s/%s" % (namespace, repository))
+ include_tags = parsed_args["includeTags"]
+ max_tags = 500
+ repo = model.get_repo(
+ namespace, repository, get_authenticated_user(), include_tags, max_tags
+ )
+ if repo is None:
+ raise NotFound()
- has_write_permission = ModifyRepositoryPermission(namespace, repository).can()
- has_write_permission = has_write_permission and repo.state == RepositoryState.NORMAL
+ has_write_permission = ModifyRepositoryPermission(namespace, repository).can()
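+ # Report write access only when the repository is in the NORMAL state.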
+ has_write_permission = (
+ has_write_permission and repo.state == RepositoryState.NORMAL
+ )
- repo_data = repo.to_dict()
- repo_data['can_write'] = has_write_permission
- repo_data['can_admin'] = AdministerRepositoryPermission(namespace, repository).can()
+ repo_data = repo.to_dict()
+ repo_data["can_write"] = has_write_permission
+ repo_data["can_admin"] = AdministerRepositoryPermission(
+ namespace, repository
+ ).can()
- if parsed_args['includeStats'] and repo.repository_base_elements.kind_name != 'application':
- stats = []
- found_dates = {}
+ if (
+ parsed_args["includeStats"]
+ and repo.repository_base_elements.kind_name != "application"
+ ):
+ stats = []
+ found_dates = {}
- for count in repo.counts:
- stats.append(count.to_dict())
- found_dates['%s/%s' % (count.date.month, count.date.day)] = True
+ for count in repo.counts:
+ stats.append(count.to_dict())
+ found_dates["%s/%s" % (count.date.month, count.date.day)] = True
- # Fill in any missing stats with zeros.
- for day in range(1, MAX_DAYS_IN_3_MONTHS):
- day_date = datetime.now() - timedelta(days=day)
- key = '%s/%s' % (day_date.month, day_date.day)
- if key not in found_dates:
- stats.append({
- 'date': day_date.date().isoformat(),
- 'count': 0,
- })
+ # Fill in any missing stats with zeros.
+ for day in range(1, MAX_DAYS_IN_3_MONTHS):
+ day_date = datetime.now() - timedelta(days=day)
+ key = "%s/%s" % (day_date.month, day_date.day)
+ if key not in found_dates:
+ stats.append({"date": day_date.date().isoformat(), "count": 0})
- repo_data['stats'] = stats
- return repo_data
+ repo_data["stats"] = stats
+ return repo_data
- @require_repo_write
- @nickname('updateRepo')
- @validate_json_request('RepoUpdate')
- def put(self, namespace, repository):
- """ Update the description in the specified repository. """
- if not model.repo_exists(namespace, repository):
- raise NotFound()
+ @require_repo_write
+ @nickname("updateRepo")
+ @validate_json_request("RepoUpdate")
+ def put(self, namespace, repository):
+ """ Update the description in the specified repository. """
+ if not model.repo_exists(namespace, repository):
+ raise NotFound()
- values = request.get_json()
- model.set_description(namespace, repository, values['description'])
+ values = request.get_json()
+ model.set_description(namespace, repository, values["description"])
- log_action('set_repo_description', namespace,
- {'repo': repository,
- 'namespace': namespace,
- 'description': values['description']}, repo_name=repository)
- return {'success': True}
+ log_action(
+ "set_repo_description",
+ namespace,
+ {
+ "repo": repository,
+ "namespace": namespace,
+ "description": values["description"],
+ },
+ repo_name=repository,
+ )
+ return {"success": True}
- @require_repo_admin
- @nickname('deleteRepository')
- def delete(self, namespace, repository):
- """ Delete a repository. """
- username = model.purge_repository(namespace, repository)
+ @require_repo_admin
+ @nickname("deleteRepository")
+ def delete(self, namespace, repository):
+ """ Delete a repository. """
+ username = model.purge_repository(namespace, repository)
- if features.BILLING:
- plan = get_namespace_plan(namespace)
- model.check_repository_usage(username, plan)
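+ # Re-check the owner's private repository usage against their plan after the deletion.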
+ if features.BILLING:
+ plan = get_namespace_plan(namespace)
+ model.check_repository_usage(username, plan)
- # Remove any builds from the queue.
- dockerfile_build_queue.delete_namespaced_items(namespace, repository)
+ # Remove any builds from the queue.
+ dockerfile_build_queue.delete_namespaced_items(namespace, repository)
- log_action('delete_repo', namespace, {'repo': repository, 'namespace': namespace})
- return '', 204
+ log_action(
+ "delete_repo", namespace, {"repo": repository, "namespace": namespace}
+ )
+ return "", 204
-@resource('/v1/repository/<apirepopath:repository>/changevisibility')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
+@resource("/v1/repository//changevisibility")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
class RepositoryVisibility(RepositoryParamResource):
- """ Custom verb for changing the visibility of the repository. """
- schemas = {
- 'ChangeVisibility': {
- 'type': 'object',
- 'description': 'Change the visibility for the repository.',
- 'required': ['visibility',],
- 'properties': {
- 'visibility': {
- 'type': 'string',
- 'description': 'Visibility which the repository will start with',
- 'enum': [
- 'public',
- 'private',
- ],
- },
- }
+ """ Custom verb for changing the visibility of the repository. """
+
+ schemas = {
+ "ChangeVisibility": {
+ "type": "object",
+ "description": "Change the visibility for the repository.",
+ "required": ["visibility"],
+ "properties": {
+ "visibility": {
+ "type": "string",
+ "description": "Visibility which the repository will start with",
+ "enum": ["public", "private"],
+ }
+ },
+ }
}
- }
- @require_repo_admin
- @nickname('changeRepoVisibility')
- @validate_json_request('ChangeVisibility')
- def post(self, namespace, repository):
- """ Change the visibility of a repository. """
- if model.repo_exists(namespace, repository):
- values = request.get_json()
- visibility = values['visibility']
- if visibility == 'private':
- check_allowed_private_repos(namespace)
+ @require_repo_admin
+ @nickname("changeRepoVisibility")
+ @validate_json_request("ChangeVisibility")
+ def post(self, namespace, repository):
+ """ Change the visibility of a repository. """
+ if model.repo_exists(namespace, repository):
+ values = request.get_json()
+ visibility = values["visibility"]
+ if visibility == "private":
+ check_allowed_private_repos(namespace)
- model.set_repository_visibility(namespace, repository, visibility)
- log_action('change_repo_visibility', namespace,
- {'repo': repository,
- 'namespace': namespace,
- 'visibility': values['visibility']}, repo_name=repository)
- return {'success': True}
+ model.set_repository_visibility(namespace, repository, visibility)
+ log_action(
+ "change_repo_visibility",
+ namespace,
+ {
+ "repo": repository,
+ "namespace": namespace,
+ "visibility": values["visibility"],
+ },
+ repo_name=repository,
+ )
+ return {"success": True}
-@resource('/v1/repository/<apirepopath:repository>/changetrust')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
+@resource("/v1/repository//changetrust")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
class RepositoryTrust(RepositoryParamResource):
- """ Custom verb for changing the trust settings of the repository. """
- schemas = {
- 'ChangeRepoTrust': {
- 'type': 'object',
- 'description': 'Change the trust settings for the repository.',
- 'required': ['trust_enabled',],
- 'properties': {
- 'trust_enabled': {
- 'type': 'boolean',
- 'description': 'Whether or not signing is enabled for the repository.'
- },
- }
+ """ Custom verb for changing the trust settings of the repository. """
+
+ schemas = {
+ "ChangeRepoTrust": {
+ "type": "object",
+ "description": "Change the trust settings for the repository.",
+ "required": ["trust_enabled"],
+ "properties": {
+ "trust_enabled": {
+ "type": "boolean",
+ "description": "Whether or not signing is enabled for the repository.",
+ }
+ },
+ }
}
- }
- @show_if(features.SIGNING)
- @require_repo_admin
- @nickname('changeRepoTrust')
- @validate_json_request('ChangeRepoTrust')
- def post(self, namespace, repository):
- """ Change the visibility of a repository. """
- if not model.repo_exists(namespace, repository):
- raise NotFound()
+ @show_if(features.SIGNING)
+ @require_repo_admin
+ @nickname("changeRepoTrust")
+ @validate_json_request("ChangeRepoTrust")
+ def post(self, namespace, repository):
+ """ Change the visibility of a repository. """
+ if not model.repo_exists(namespace, repository):
+ raise NotFound()
- tags, _ = tuf_metadata_api.get_default_tags_with_expiration(namespace, repository)
- if tags and not tuf_metadata_api.delete_metadata(namespace, repository):
- raise DownstreamIssue('Unable to delete downstream trust metadata')
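+ # Clear any existing downstream signing metadata before changing the trust setting.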
+ tags, _ = tuf_metadata_api.get_default_tags_with_expiration(
+ namespace, repository
+ )
+ if tags and not tuf_metadata_api.delete_metadata(namespace, repository):
+ raise DownstreamIssue("Unable to delete downstream trust metadata")
- values = request.get_json()
- model.set_trust(namespace, repository, values['trust_enabled'])
+ values = request.get_json()
+ model.set_trust(namespace, repository, values["trust_enabled"])
- log_action(
- 'change_repo_trust', namespace,
- {'repo': repository,
- 'namespace': namespace,
- 'trust_enabled': values['trust_enabled']}, repo_name=repository)
+ log_action(
+ "change_repo_trust",
+ namespace,
+ {
+ "repo": repository,
+ "namespace": namespace,
+ "trust_enabled": values["trust_enabled"],
+ },
+ repo_name=repository,
+ )
- return {'success': True}
+ return {"success": True}
-@resource('/v1/repository/<apirepopath:repository>/changestate')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
+@resource("/v1/repository//changestate")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
@show_if(features.REPO_MIRROR)
class RepositoryStateResource(RepositoryParamResource):
- """ Custom verb for changing the state of the repository. """
- schemas = {
- 'ChangeRepoState': {
- 'type': 'object',
- 'description': 'Change the state of the repository.',
- 'required': ['state'],
- 'properties': {
- 'state': {
- 'type': 'string',
- 'description': 'Determines whether pushes are allowed.',
- 'enum': ['NORMAL', 'READ_ONLY', 'MIRROR'],
- },
- }
+ """ Custom verb for changing the state of the repository. """
+
+ schemas = {
+ "ChangeRepoState": {
+ "type": "object",
+ "description": "Change the state of the repository.",
+ "required": ["state"],
+ "properties": {
+ "state": {
+ "type": "string",
+ "description": "Determines whether pushes are allowed.",
+ "enum": ["NORMAL", "READ_ONLY", "MIRROR"],
+ }
+ },
+ }
}
- }
- @require_repo_admin
- @nickname('changeRepoState')
- @validate_json_request('ChangeRepoState')
- def put(self, namespace, repository):
- """ Change the state of a repository. """
- if not model.repo_exists(namespace, repository):
- raise NotFound()
+ @require_repo_admin
+ @nickname("changeRepoState")
+ @validate_json_request("ChangeRepoState")
+ def put(self, namespace, repository):
+ """ Change the state of a repository. """
+ if not model.repo_exists(namespace, repository):
+ raise NotFound()
- values = request.get_json()
- state_name = values['state']
+ values = request.get_json()
+ state_name = values["state"]
- try:
- state = RepositoryState[state_name]
- except KeyError:
- state = None
+ try:
+ state = RepositoryState[state_name]
+ except KeyError:
+ state = None
- if state == RepositoryState.MIRROR and not features.REPO_MIRROR:
- return {'detail': 'Unknown Repository State: %s' % state_name}, 400
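+ # The MIRROR state is only accepted when the repository mirroring feature is enabled.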
+ if state == RepositoryState.MIRROR and not features.REPO_MIRROR:
+ return {"detail": "Unknown Repository State: %s" % state_name}, 400
- if state is None:
- return {'detail': '%s is not a valid Repository state.' % state_name}, 400
+ if state is None:
+ return {"detail": "%s is not a valid Repository state." % state_name}, 400
- model.set_repository_state(namespace, repository, state)
+ model.set_repository_state(namespace, repository, state)
- log_action('change_repo_state', namespace,
- {'repo': repository,
- 'namespace': namespace,
- 'state_changed': state_name}, repo_name=repository)
+ log_action(
+ "change_repo_state",
+ namespace,
+ {"repo": repository, "namespace": namespace, "state_changed": state_name},
+ repo_name=repository,
+ )
- return {'success': True}
+ return {"success": True}
diff --git a/endpoints/api/repository_models_interface.py b/endpoints/api/repository_models_interface.py
index 3b5e06a2f..4b566e2c6 100644
--- a/endpoints/api/repository_models_interface.py
+++ b/endpoints/api/repository_models_interface.py
@@ -10,13 +10,28 @@ from endpoints.api import format_date
class RepositoryBaseElement(
- namedtuple('RepositoryBaseElement', [
- 'namespace_name', 'repository_name', 'is_starred', 'is_public', 'kind_name', 'description',
- 'namespace_user_organization', 'namespace_user_removed_tag_expiration_s', 'last_modified',
- 'action_count', 'should_last_modified', 'should_popularity', 'should_is_starred',
- 'is_free_account', 'state'
- ])):
- """
+ namedtuple(
+ "RepositoryBaseElement",
+ [
+ "namespace_name",
+ "repository_name",
+ "is_starred",
+ "is_public",
+ "kind_name",
+ "description",
+ "namespace_user_organization",
+ "namespace_user_removed_tag_expiration_s",
+ "last_modified",
+ "action_count",
+ "should_last_modified",
+ "should_popularity",
+ "should_is_starred",
+ "is_free_account",
+ "state",
+ ],
+ )
+):
+ """
Repository a single quay repository
:type namespace_name: string
:type repository_name: string
@@ -30,60 +45,73 @@ class RepositoryBaseElement(
:type should_is_starred: boolean
"""
- def to_dict(self):
- repo = {
- 'namespace': self.namespace_name,
- 'name': self.repository_name,
- 'description': self.description,
- 'is_public': self.is_public,
- 'kind': self.kind_name,
- 'state': self.state.name if self.state is not None else None,
- }
+ def to_dict(self):
+ repo = {
+ "namespace": self.namespace_name,
+ "name": self.repository_name,
+ "description": self.description,
+ "is_public": self.is_public,
+ "kind": self.kind_name,
+ "state": self.state.name if self.state is not None else None,
+ }
- if self.should_last_modified:
- repo['last_modified'] = self.last_modified
+ if self.should_last_modified:
+ repo["last_modified"] = self.last_modified
- if self.should_popularity:
- repo['popularity'] = float(self.action_count if self.action_count else 0)
+ if self.should_popularity:
+ repo["popularity"] = float(self.action_count if self.action_count else 0)
- if self.should_is_starred:
- repo['is_starred'] = self.is_starred
+ if self.should_is_starred:
+ repo["is_starred"] = self.is_starred
- return repo
+ return repo
class ApplicationRepository(
- namedtuple('ApplicationRepository', ['repository_base_elements', 'channels', 'releases', 'state'])):
- """
+ namedtuple(
+ "ApplicationRepository",
+ ["repository_base_elements", "channels", "releases", "state"],
+ )
+):
+ """
Repository a single quay repository
:type repository_base_elements: RepositoryBaseElement
:type channels: [Channel]
:type releases: [Release]
"""
- def to_dict(self):
- repo_data = {
- 'namespace': self.repository_base_elements.namespace_name,
- 'name': self.repository_base_elements.repository_name,
- 'kind': self.repository_base_elements.kind_name,
- 'description': self.repository_base_elements.description,
- 'is_public': self.repository_base_elements.is_public,
- 'is_organization': self.repository_base_elements.namespace_user_organization,
- 'is_starred': self.repository_base_elements.is_starred,
- 'channels': [chan.to_dict() for chan in self.channels],
- 'releases': [release.to_dict() for release in self.releases],
- 'state': self.state.name if self.state is not None else None,
- 'is_free_account': self.repository_base_elements.is_free_account,
- }
+ def to_dict(self):
+ repo_data = {
+ "namespace": self.repository_base_elements.namespace_name,
+ "name": self.repository_base_elements.repository_name,
+ "kind": self.repository_base_elements.kind_name,
+ "description": self.repository_base_elements.description,
+ "is_public": self.repository_base_elements.is_public,
+ "is_organization": self.repository_base_elements.namespace_user_organization,
+ "is_starred": self.repository_base_elements.is_starred,
+ "channels": [chan.to_dict() for chan in self.channels],
+ "releases": [release.to_dict() for release in self.releases],
+ "state": self.state.name if self.state is not None else None,
+ "is_free_account": self.repository_base_elements.is_free_account,
+ }
- return repo_data
+ return repo_data
class ImageRepositoryRepository(
- namedtuple('NonApplicationRepository',
- ['repository_base_elements', 'tags', 'counts', 'badge_token', 'trust_enabled',
- 'state'])):
- """
+ namedtuple(
+ "NonApplicationRepository",
+ [
+ "repository_base_elements",
+ "tags",
+ "counts",
+ "badge_token",
+ "trust_enabled",
+ "state",
+ ],
+ )
+):
+ """
Repository a single quay repository
:type repository_base_elements: RepositoryBaseElement
:type tags: [Tag]
@@ -92,81 +120,95 @@ class ImageRepositoryRepository(
:type trust_enabled: boolean
"""
- def to_dict(self):
- img_repo = {
- 'namespace': self.repository_base_elements.namespace_name,
- 'name': self.repository_base_elements.repository_name,
- 'kind': self.repository_base_elements.kind_name,
- 'description': self.repository_base_elements.description,
- 'is_public': self.repository_base_elements.is_public,
- 'is_organization': self.repository_base_elements.namespace_user_organization,
- 'is_starred': self.repository_base_elements.is_starred,
- 'status_token': self.badge_token if not self.repository_base_elements.is_public else '',
- 'trust_enabled': bool(features.SIGNING) and self.trust_enabled,
- 'tag_expiration_s': self.repository_base_elements.namespace_user_removed_tag_expiration_s,
- 'is_free_account': self.repository_base_elements.is_free_account,
- 'state': self.state.name if self.state is not None else None
- }
+ def to_dict(self):
+ img_repo = {
+ "namespace": self.repository_base_elements.namespace_name,
+ "name": self.repository_base_elements.repository_name,
+ "kind": self.repository_base_elements.kind_name,
+ "description": self.repository_base_elements.description,
+ "is_public": self.repository_base_elements.is_public,
+ "is_organization": self.repository_base_elements.namespace_user_organization,
+ "is_starred": self.repository_base_elements.is_starred,
+ "status_token": self.badge_token
+ if not self.repository_base_elements.is_public
+ else "",
+ "trust_enabled": bool(features.SIGNING) and self.trust_enabled,
+ "tag_expiration_s": self.repository_base_elements.namespace_user_removed_tag_expiration_s,
+ "is_free_account": self.repository_base_elements.is_free_account,
+ "state": self.state.name if self.state is not None else None,
+ }
- if self.tags is not None:
- img_repo['tags'] = {tag.name: tag.to_dict() for tag in self.tags}
+ if self.tags is not None:
+ img_repo["tags"] = {tag.name: tag.to_dict() for tag in self.tags}
- if self.repository_base_elements.state:
- img_repo['state'] = self.repository_base_elements.state.name
+ if self.repository_base_elements.state:
+ img_repo["state"] = self.repository_base_elements.state.name
- return img_repo
+ return img_repo
-class Repository(namedtuple('Repository', [
- 'namespace_name',
- 'repository_name',
-])):
- """
+class Repository(namedtuple("Repository", ["namespace_name", "repository_name"])):
+ """
Repository a single quay repository
:type namespace_name: string
:type repository_name: string
"""
-class Channel(namedtuple('Channel', ['name', 'linked_tag_name', 'linked_tag_lifetime_start'])):
- """
+class Channel(
+ namedtuple("Channel", ["name", "linked_tag_name", "linked_tag_lifetime_start"])
+):
+ """
Repository a single quay repository
:type name: string
:type linked_tag_name: string
:type linked_tag_lifetime_start: string
"""
- def to_dict(self):
- return {
- 'name': self.name,
- 'release': self.linked_tag_name,
- 'last_modified': format_date(datetime.fromtimestamp(self.linked_tag_lifetime_start / 1000)),
- }
+ def to_dict(self):
+ return {
+ "name": self.name,
+ "release": self.linked_tag_name,
+ "last_modified": format_date(
+ datetime.fromtimestamp(self.linked_tag_lifetime_start / 1000)
+ ),
+ }
class Release(
- namedtuple('Channel', ['name', 'lifetime_start', 'releases_channels_map'])):
- """
+ namedtuple("Channel", ["name", "lifetime_start", "releases_channels_map"])
+):
+ """
Repository a single quay repository
:type name: string
:type last_modified: string
:type releases_channels_map: {string -> string}
"""
- def to_dict(self):
- return {
- 'name': self.name,
- 'last_modified': format_date(datetime.fromtimestamp(self.lifetime_start / 1000)),
- 'channels': self.releases_channels_map[self.name],
- }
+ def to_dict(self):
+ return {
+ "name": self.name,
+ "last_modified": format_date(
+ datetime.fromtimestamp(self.lifetime_start / 1000)
+ ),
+ "channels": self.releases_channels_map[self.name],
+ }
class Tag(
- namedtuple('Tag', [
- 'name', 'image_docker_image_id', 'image_aggregate_size', 'lifetime_start_ts',
- 'tag_manifest_digest', 'lifetime_end_ts',
- ])):
- """
+ namedtuple(
+ "Tag",
+ [
+ "name",
+ "image_docker_image_id",
+ "image_aggregate_size",
+ "lifetime_start_ts",
+ "tag_manifest_digest",
+ "lifetime_end_ts",
+ ],
+ )
+):
+ """
:type name: string
:type image_docker_image_id: string
:type image_aggregate_size: int
@@ -176,104 +218,120 @@ class Tag(
"""
- def to_dict(self):
- tag_info = {
- 'name': self.name,
- 'image_id': self.image_docker_image_id,
- 'size': self.image_aggregate_size
- }
+ def to_dict(self):
+ tag_info = {
+ "name": self.name,
+ "image_id": self.image_docker_image_id,
+ "size": self.image_aggregate_size,
+ }
- if self.lifetime_start_ts > 0:
- last_modified = format_date(datetime.fromtimestamp(self.lifetime_start_ts))
- tag_info['last_modified'] = last_modified
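+ # Only expose last_modified when the tag has a recorded start timestamp.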
+ if self.lifetime_start_ts > 0:
+ last_modified = format_date(datetime.fromtimestamp(self.lifetime_start_ts))
+ tag_info["last_modified"] = last_modified
- if self.lifetime_end_ts:
- expiration = format_date(datetime.fromtimestamp(self.lifetime_end_ts))
- tag_info['expiration'] = expiration
+ if self.lifetime_end_ts:
+ expiration = format_date(datetime.fromtimestamp(self.lifetime_end_ts))
+ tag_info["expiration"] = expiration
- if self.tag_manifest_digest is not None:
- tag_info['manifest_digest'] = self.tag_manifest_digest
+ if self.tag_manifest_digest is not None:
+ tag_info["manifest_digest"] = self.tag_manifest_digest
- return tag_info
+ return tag_info
-class Count(namedtuple('Count', ['date', 'count'])):
- """
+class Count(namedtuple("Count", ["date", "count"])):
+ """
date: DateTime
count: int
"""
- def to_dict(self):
- return {
- 'date': self.date.isoformat(),
- 'count': self.count,
- }
+ def to_dict(self):
+ return {"date": self.date.isoformat(), "count": self.count}
@add_metaclass(ABCMeta)
class RepositoryDataInterface(object):
- """
+ """
Interface that represents all data store interactions required by a Repository.
"""
- @abstractmethod
- def get_repo(self, namespace_name, repository_name, user, include_tags=True, max_tags=500):
- """
+ @abstractmethod
+ def get_repo(
+ self, namespace_name, repository_name, user, include_tags=True, max_tags=500
+ ):
+ """
Returns a repository
"""
- @abstractmethod
- def repo_exists(self, namespace_name, repository_name):
- """
+ @abstractmethod
+ def repo_exists(self, namespace_name, repository_name):
+ """
Returns true if a repo exists and false if not
"""
- @abstractmethod
- def create_repo(self, namespace, name, creating_user, description, visibility='private',
- repo_kind='image'):
- """
+ @abstractmethod
+ def create_repo(
+ self,
+ namespace,
+ name,
+ creating_user,
+ description,
+ visibility="private",
+ repo_kind="image",
+ ):
+ """
Creates a new repo
"""
- @abstractmethod
- def get_repo_list(self, starred, user, repo_kind, namespace, username, public, page_token,
- last_modified, popularity):
- """
+ @abstractmethod
+ def get_repo_list(
+ self,
+ starred,
+ user,
+ repo_kind,
+ namespace,
+ username,
+ public,
+ page_token,
+ last_modified,
+ popularity,
+ ):
+ """
Returns a RepositoryBaseElement
"""
- @abstractmethod
- def set_repository_visibility(self, namespace_name, repository_name, visibility):
- """
+ @abstractmethod
+ def set_repository_visibility(self, namespace_name, repository_name, visibility):
+ """
Sets a repository's visibility if it is found
"""
- @abstractmethod
- def set_trust(self, namespace_name, repository_name, trust):
- """
+ @abstractmethod
+ def set_trust(self, namespace_name, repository_name, trust):
+ """
Sets a repository's trust_enabled field if it is found
"""
- @abstractmethod
- def set_description(self, namespace_name, repository_name, description):
- """
+ @abstractmethod
+ def set_description(self, namespace_name, repository_name, description):
+ """
Sets a repository's description if it is found.
"""
- @abstractmethod
- def purge_repository(self, namespace_name, repository_name):
- """
+ @abstractmethod
+ def purge_repository(self, namespace_name, repository_name):
+ """
Removes a repository
"""
- @abstractmethod
- def check_repository_usage(self, user_name, plan_found):
- """
+ @abstractmethod
+ def check_repository_usage(self, user_name, plan_found):
+ """
Creates a notification for a user if they are over or under on their repository usage
"""
- @abstractmethod
- def set_repository_state(self, namespace_name, repository_name, state):
- """
+ @abstractmethod
+ def set_repository_state(self, namespace_name, repository_name, state):
+ """
Set the State of the Repository.
"""
diff --git a/endpoints/api/repository_models_pre_oci.py b/endpoints/api/repository_models_pre_oci.py
index 328c5443e..23cb3c030 100644
--- a/endpoints/api/repository_models_pre_oci.py
+++ b/endpoints/api/repository_models_pre_oci.py
@@ -9,182 +9,285 @@ from data.appr_model import channel as channel_model, release as release_model
from data.registry_model import registry_model
from data.registry_model.datatypes import RepositoryReference
from endpoints.appr.models_cnr import model as appr_model
-from endpoints.api.repository_models_interface import RepositoryDataInterface, RepositoryBaseElement, Repository, \
- ApplicationRepository, ImageRepositoryRepository, Tag, Channel, Release, Count
+from endpoints.api.repository_models_interface import (
+ RepositoryDataInterface,
+ RepositoryBaseElement,
+ Repository,
+ ApplicationRepository,
+ ImageRepositoryRepository,
+ Tag,
+ Channel,
+ Release,
+ Count,
+)
MAX_DAYS_IN_3_MONTHS = 92
REPOS_PER_PAGE = 100
def _create_channel(channel, releases_channels_map):
- releases_channels_map[channel.linked_tag.name].append(channel.name)
- return Channel(channel.name, channel.linked_tag.name, channel.linked_tag.lifetime_start)
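+ # Record which channels point at each linked tag so releases can report their channels.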
+ releases_channels_map[channel.linked_tag.name].append(channel.name)
+ return Channel(
+ channel.name, channel.linked_tag.name, channel.linked_tag.lifetime_start
+ )
class PreOCIModel(RepositoryDataInterface):
- """
+ """
PreOCIModel implements the data model for the Repo Email using a database schema
before it was changed to support the OCI specification.
"""
- def check_repository_usage(self, username, plan_found):
- private_repos = model.user.get_private_repo_count(username)
- if plan_found is None:
- repos_allowed = 0
- else:
- repos_allowed = plan_found['privateRepos']
+ def check_repository_usage(self, username, plan_found):
+ private_repos = model.user.get_private_repo_count(username)
+ if plan_found is None:
+ repos_allowed = 0
+ else:
+ repos_allowed = plan_found["privateRepos"]
- user_or_org = model.user.get_namespace_user(username)
- if private_repos > repos_allowed:
- model.notification.create_unique_notification('over_private_usage', user_or_org,
- {'namespace': username})
- else:
- model.notification.delete_notifications_by_kind(user_or_org, 'over_private_usage')
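+ # Flag namespaces that exceed their plan's private repository allowance; otherwise clear any existing flag.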
+ user_or_org = model.user.get_namespace_user(username)
+ if private_repos > repos_allowed:
+ model.notification.create_unique_notification(
+ "over_private_usage", user_or_org, {"namespace": username}
+ )
+ else:
+ model.notification.delete_notifications_by_kind(
+ user_or_org, "over_private_usage"
+ )
- def purge_repository(self, namespace_name, repository_name):
- model.gc.purge_repository(namespace_name, repository_name)
- user = model.user.get_namespace_user(namespace_name)
- return user.username
+ def purge_repository(self, namespace_name, repository_name):
+ model.gc.purge_repository(namespace_name, repository_name)
+ user = model.user.get_namespace_user(namespace_name)
+ return user.username
- def set_description(self, namespace_name, repository_name, description):
- repo = model.repository.get_repository(namespace_name, repository_name)
- model.repository.set_description(repo, description)
+ def set_description(self, namespace_name, repository_name, description):
+ repo = model.repository.get_repository(namespace_name, repository_name)
+ model.repository.set_description(repo, description)
- def set_trust(self, namespace_name, repository_name, trust):
- repo = model.repository.get_repository(namespace_name, repository_name)
- model.repository.set_trust(repo, trust)
+ def set_trust(self, namespace_name, repository_name, trust):
+ repo = model.repository.get_repository(namespace_name, repository_name)
+ model.repository.set_trust(repo, trust)
- def set_repository_visibility(self, namespace_name, repository_name, visibility):
- repo = model.repository.get_repository(namespace_name, repository_name)
- model.repository.set_repository_visibility(repo, visibility)
+ def set_repository_visibility(self, namespace_name, repository_name, visibility):
+ repo = model.repository.get_repository(namespace_name, repository_name)
+ model.repository.set_repository_visibility(repo, visibility)
- def set_repository_state(self, namespace_name, repository_name, state):
- repo = model.repository.get_repository(namespace_name, repository_name)
- model.repository.set_repository_state(repo, state)
+ def set_repository_state(self, namespace_name, repository_name, state):
+ repo = model.repository.get_repository(namespace_name, repository_name)
+ model.repository.set_repository_state(repo, state)
- def get_repo_list(self, starred, user, repo_kind, namespace, username, public, page_token,
- last_modified, popularity):
- next_page_token = None
- # Lookup the requested repositories (either starred or non-starred.)
- if starred:
- # Return the full list of repos starred by the current user that are still visible to them.
- def can_view_repo(repo):
- can_view = ReadRepositoryPermission(repo.namespace_user.username, repo.name).can()
- return can_view or model.repository.is_repository_public(repo)
+ def get_repo_list(
+ self,
+ starred,
+ user,
+ repo_kind,
+ namespace,
+ username,
+ public,
+ page_token,
+ last_modified,
+ popularity,
+ ):
+ next_page_token = None
+ # Look up the requested repositories (either starred or non-starred).
+ if starred:
+ # Return the full list of repos starred by the current user that are still visible to them.
+ def can_view_repo(repo):
+ can_view = ReadRepositoryPermission(
+ repo.namespace_user.username, repo.name
+ ).can()
+ return can_view or model.repository.is_repository_public(repo)
- unfiltered_repos = model.repository.get_user_starred_repositories(user,
- kind_filter=repo_kind)
- repos = [repo for repo in unfiltered_repos if can_view_repo(repo)]
- elif namespace:
- # Repositories filtered by namespace do not need pagination (their results are fairly small),
- # so we just do the lookup directly.
- repos = list(
- model.repository.get_visible_repositories(username=username, include_public=public,
- namespace=namespace, kind_filter=repo_kind))
- else:
- # Determine the starting offset for pagination. Note that we don't use the normal
- # model.modelutil.paginate method here, as that does not operate over UNION queries, which
- # get_visible_repositories will return if there is a logged-in user (for performance reasons).
- #
- # Also note the +1 on the limit, as paginate_query uses the extra result to determine whether
- # there is a next page.
- start_id = model.modelutil.pagination_start(page_token)
- repo_query = model.repository.get_visible_repositories(
- username=username, include_public=public, start_id=start_id, limit=REPOS_PER_PAGE + 1,
- kind_filter=repo_kind)
+ unfiltered_repos = model.repository.get_user_starred_repositories(
+ user, kind_filter=repo_kind
+ )
+ repos = [repo for repo in unfiltered_repos if can_view_repo(repo)]
+ elif namespace:
+ # Repositories filtered by namespace do not need pagination (their results are fairly small),
+ # so we just do the lookup directly.
+ repos = list(
+ model.repository.get_visible_repositories(
+ username=username,
+ include_public=public,
+ namespace=namespace,
+ kind_filter=repo_kind,
+ )
+ )
+ else:
+ # Determine the starting offset for pagination. Note that we don't use the normal
+ # model.modelutil.paginate method here, as that does not operate over UNION queries, which
+ # get_visible_repositories will return if there is a logged-in user (for performance reasons).
+ #
+ # Also note the +1 on the limit, as paginate_query uses the extra result to determine whether
+ # there is a next page.
+ start_id = model.modelutil.pagination_start(page_token)
+ repo_query = model.repository.get_visible_repositories(
+ username=username,
+ include_public=public,
+ start_id=start_id,
+ limit=REPOS_PER_PAGE + 1,
+ kind_filter=repo_kind,
+ )
- repos, next_page_token = model.modelutil.paginate_query(repo_query, limit=REPOS_PER_PAGE,
- sort_field_name='rid')
+ repos, next_page_token = model.modelutil.paginate_query(
+ repo_query, limit=REPOS_PER_PAGE, sort_field_name="rid"
+ )
- # Collect the IDs of the repositories found for subequent lookup of popularity
- # and/or last modified.
- last_modified_map = {}
- action_sum_map = {}
- if last_modified or popularity:
- repository_refs = [RepositoryReference.for_id(repo.rid) for repo in repos]
- repository_ids = [repo.rid for repo in repos]
+ # Collect the IDs of the repositories found for subsequent lookup of popularity
+ # and/or last modified.
+ last_modified_map = {}
+ action_sum_map = {}
+ if last_modified or popularity:
+ repository_refs = [RepositoryReference.for_id(repo.rid) for repo in repos]
+ repository_ids = [repo.rid for repo in repos]
- if last_modified:
- last_modified_map = registry_model.get_most_recent_tag_lifetime_start(repository_refs)
+ if last_modified:
+ last_modified_map = registry_model.get_most_recent_tag_lifetime_start(
+ repository_refs
+ )
- if popularity:
- action_sum_map = model.log.get_repositories_action_sums(repository_ids)
+ if popularity:
+ action_sum_map = model.log.get_repositories_action_sums(repository_ids)
- # Collect the IDs of the repositories that are starred for the user, so we can mark them
- # in the returned results.
- star_set = set()
- if username:
- starred_repos = model.repository.get_user_starred_repositories(user)
- star_set = {starred.id for starred in starred_repos}
+ # Collect the IDs of the repositories that are starred for the user, so we can mark them
+ # in the returned results.
+ star_set = set()
+ if username:
+ starred_repos = model.repository.get_user_starred_repositories(user)
+ star_set = {starred.id for starred in starred_repos}
- return [
- RepositoryBaseElement(repo.namespace_user.username, repo.name, repo.id in star_set,
- repo.visibility_id == model.repository.get_public_repo_visibility().id,
- repo_kind, repo.description, repo.namespace_user.organization,
- repo.namespace_user.removed_tag_expiration_s,
- last_modified_map.get(repo.rid),
- action_sum_map.get(repo.rid), last_modified, popularity, username,
- None, repo.state)
- for repo in repos
- ], next_page_token
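+ # Build the API-facing elements; last_modified and popularity are filled in only when requested above.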
+ return (
+ [
+ RepositoryBaseElement(
+ repo.namespace_user.username,
+ repo.name,
+ repo.id in star_set,
+ repo.visibility_id
+ == model.repository.get_public_repo_visibility().id,
+ repo_kind,
+ repo.description,
+ repo.namespace_user.organization,
+ repo.namespace_user.removed_tag_expiration_s,
+ last_modified_map.get(repo.rid),
+ action_sum_map.get(repo.rid),
+ last_modified,
+ popularity,
+ username,
+ None,
+ repo.state,
+ )
+ for repo in repos
+ ],
+ next_page_token,
+ )
- def repo_exists(self, namespace_name, repository_name):
- repo = model.repository.get_repository(namespace_name, repository_name)
- if repo is None:
- return False
+ def repo_exists(self, namespace_name, repository_name):
+ repo = model.repository.get_repository(namespace_name, repository_name)
+ if repo is None:
+ return False
- return True
+ return True
- def create_repo(self, namespace_name, repository_name, owner, description, visibility='private',
- repo_kind='image'):
- repo = model.repository.create_repository(namespace_name, repository_name, owner, visibility,
- repo_kind=repo_kind, description=description)
- return Repository(namespace_name, repository_name)
+ def create_repo(
+ self,
+ namespace_name,
+ repository_name,
+ owner,
+ description,
+ visibility="private",
+ repo_kind="image",
+ ):
+ repo = model.repository.create_repository(
+ namespace_name,
+ repository_name,
+ owner,
+ visibility,
+ repo_kind=repo_kind,
+ description=description,
+ )
+ return Repository(namespace_name, repository_name)
- def get_repo(self, namespace_name, repository_name, user, include_tags=True, max_tags=500):
- repo = model.repository.get_repository(namespace_name, repository_name)
- if repo is None:
- return None
+ def get_repo(
+ self, namespace_name, repository_name, user, include_tags=True, max_tags=500
+ ):
+ repo = model.repository.get_repository(namespace_name, repository_name)
+ if repo is None:
+ return None
- is_starred = model.repository.repository_is_starred(user, repo) if user else False
- is_public = model.repository.is_repository_public(repo)
- kind_name = RepositoryTable.kind.get_name(repo.kind_id)
- base = RepositoryBaseElement(
- namespace_name, repository_name, is_starred, is_public, kind_name, repo.description,
- repo.namespace_user.organization, repo.namespace_user.removed_tag_expiration_s, None, None,
- False, False, False, repo.namespace_user.stripe_id is None, repo.state)
+ is_starred = (
+ model.repository.repository_is_starred(user, repo) if user else False
+ )
+ is_public = model.repository.is_repository_public(repo)
+ kind_name = RepositoryTable.kind.get_name(repo.kind_id)
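+ # last_modified/action_count are omitted for a single-repo lookup; a missing stripe_id marks a free account.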
+ base = RepositoryBaseElement(
+ namespace_name,
+ repository_name,
+ is_starred,
+ is_public,
+ kind_name,
+ repo.description,
+ repo.namespace_user.organization,
+ repo.namespace_user.removed_tag_expiration_s,
+ None,
+ None,
+ False,
+ False,
+ False,
+ repo.namespace_user.stripe_id is None,
+ repo.state,
+ )
- if base.kind_name == 'application':
- channels = channel_model.get_repo_channels(repo, appr_model.models_ref)
- releases = release_model.get_release_objs(repo, appr_model.models_ref)
- releases_channels_map = defaultdict(list)
- return ApplicationRepository(
- base, [_create_channel(channel, releases_channels_map) for channel in channels], [
- Release(release.name, release.lifetime_start, releases_channels_map)
- for release in releases
- ], repo.state)
+ if base.kind_name == "application":
+ channels = channel_model.get_repo_channels(repo, appr_model.models_ref)
+ releases = release_model.get_release_objs(repo, appr_model.models_ref)
+ releases_channels_map = defaultdict(list)
+ return ApplicationRepository(
+ base,
+ [
+ _create_channel(channel, releases_channels_map)
+ for channel in channels
+ ],
+ [
+ Release(release.name, release.lifetime_start, releases_channels_map)
+ for release in releases
+ ],
+ repo.state,
+ )
- tags = None
- repo_ref = RepositoryReference.for_repo_obj(repo)
- if include_tags:
- tags, _ = registry_model.list_repository_tag_history(repo_ref, page=1, size=max_tags,
- active_tags_only=True)
- tags = [
- Tag(tag.name,
- tag.legacy_image.docker_image_id if tag.legacy_image_if_present else None,
- tag.legacy_image.aggregate_size if tag.legacy_image_if_present else None,
- tag.lifetime_start_ts,
- tag.manifest_digest,
- tag.lifetime_end_ts) for tag in tags
- ]
+ tags = None
+ repo_ref = RepositoryReference.for_repo_obj(repo)
+ if include_tags:
+ tags, _ = registry_model.list_repository_tag_history(
+ repo_ref, page=1, size=max_tags, active_tags_only=True
+ )
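+ # Legacy image details are only available when the tag still has a backing pre-OCI image.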
+ tags = [
+ Tag(
+ tag.name,
+ tag.legacy_image.docker_image_id
+ if tag.legacy_image_if_present
+ else None,
+ tag.legacy_image.aggregate_size
+ if tag.legacy_image_if_present
+ else None,
+ tag.lifetime_start_ts,
+ tag.manifest_digest,
+ tag.lifetime_end_ts,
+ )
+ for tag in tags
+ ]
- start_date = datetime.now() - timedelta(days=MAX_DAYS_IN_3_MONTHS)
- counts = model.log.get_repository_action_counts(repo, start_date)
+ start_date = datetime.now() - timedelta(days=MAX_DAYS_IN_3_MONTHS)
+ counts = model.log.get_repository_action_counts(repo, start_date)
- assert repo.state is not None
- return ImageRepositoryRepository(base, tags,
- [Count(count.date, count.count) for count in counts],
- repo.badge_token, repo.trust_enabled, repo.state)
+ assert repo.state is not None
+ return ImageRepositoryRepository(
+ base,
+ tags,
+ [Count(count.date, count.count) for count in counts],
+ repo.badge_token,
+ repo.trust_enabled,
+ repo.state,
+ )
pre_oci_model = PreOCIModel()
diff --git a/endpoints/api/repositorynotification.py b/endpoints/api/repositorynotification.py
index c34cbc553..89bac8d72 100644
--- a/endpoints/api/repositorynotification.py
+++ b/endpoints/api/repositorynotification.py
@@ -4,161 +4,199 @@ import logging
from flask import request
from endpoints.api import (
- RepositoryParamResource, nickname, resource, require_repo_admin, log_action,
- validate_json_request, request_error, path_param, disallow_for_app_repositories, InvalidRequest)
+ RepositoryParamResource,
+ nickname,
+ resource,
+ require_repo_admin,
+ log_action,
+ validate_json_request,
+ request_error,
+ path_param,
+ disallow_for_app_repositories,
+ InvalidRequest,
+)
from endpoints.exception import NotFound
from notifications.models_interface import Repository
from notifications.notificationevent import NotificationEvent
from notifications.notificationmethod import (
- NotificationMethod, CannotValidateNotificationMethodException)
+ NotificationMethod,
+ CannotValidateNotificationMethodException,
+)
from endpoints.api.repositorynotification_models_pre_oci import pre_oci_model as model
logger = logging.getLogger(__name__)
-@resource('/v1/repository/<apirepopath:repository>/notification/')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
+@resource("/v1/repository//notification/")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
class RepositoryNotificationList(RepositoryParamResource):
- """ Resource for dealing with listing and creating notifications on a repository. """
- schemas = {
- 'NotificationCreateRequest': {
- 'type': 'object',
- 'description': 'Information for creating a notification on a repository',
- 'required': [
- 'event',
- 'method',
- 'config',
- 'eventConfig',
- ],
- 'properties': {
- 'event': {
- 'type': 'string',
- 'description': 'The event on which the notification will respond',
- },
- 'method': {
- 'type': 'string',
- 'description': 'The method of notification (such as email or web callback)',
- },
- 'config': {
- 'type': 'object',
- 'description': 'JSON config information for the specific method of notification'
- },
- 'eventConfig': {
- 'type': 'object',
- 'description': 'JSON config information for the specific event of notification',
- },
- 'title': {
- 'type': 'string',
- 'description': 'The human-readable title of the notification',
- },
- }
- },
- }
+ """ Resource for dealing with listing and creating notifications on a repository. """
- @require_repo_admin
- @nickname('createRepoNotification')
- @disallow_for_app_repositories
- @validate_json_request('NotificationCreateRequest')
- def post(self, namespace_name, repository_name):
- parsed = request.get_json()
+ schemas = {
+ "NotificationCreateRequest": {
+ "type": "object",
+ "description": "Information for creating a notification on a repository",
+ "required": ["event", "method", "config", "eventConfig"],
+ "properties": {
+ "event": {
+ "type": "string",
+ "description": "The event on which the notification will respond",
+ },
+ "method": {
+ "type": "string",
+ "description": "The method of notification (such as email or web callback)",
+ },
+ "config": {
+ "type": "object",
+ "description": "JSON config information for the specific method of notification",
+ },
+ "eventConfig": {
+ "type": "object",
+ "description": "JSON config information for the specific event of notification",
+ },
+ "title": {
+ "type": "string",
+ "description": "The human-readable title of the notification",
+ },
+ },
+ }
+ }
- method_handler = NotificationMethod.get_method(parsed['method'])
- try:
- method_handler.validate(namespace_name, repository_name, parsed['config'])
- except CannotValidateNotificationMethodException as ex:
- raise request_error(message=ex.message)
+ @require_repo_admin
+ @nickname("createRepoNotification")
+ @disallow_for_app_repositories
+ @validate_json_request("NotificationCreateRequest")
+ def post(self, namespace_name, repository_name):
+ parsed = request.get_json()
- new_notification = model.create_repo_notification(namespace_name, repository_name,
- parsed['event'], parsed['method'],
- parsed['config'], parsed['eventConfig'],
- parsed.get('title'))
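+ # Validate the method-specific configuration before persisting the notification.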
+ method_handler = NotificationMethod.get_method(parsed["method"])
+ try:
+ method_handler.validate(namespace_name, repository_name, parsed["config"])
+ except CannotValidateNotificationMethodException as ex:
+ raise request_error(message=ex.message)
- log_action('add_repo_notification', namespace_name, {
- 'repo': repository_name,
- 'namespace': namespace_name,
- 'notification_id': new_notification.uuid,
- 'event': new_notification.event_name,
- 'method': new_notification.method_name}, repo_name=repository_name)
- return new_notification.to_dict(), 201
+ new_notification = model.create_repo_notification(
+ namespace_name,
+ repository_name,
+ parsed["event"],
+ parsed["method"],
+ parsed["config"],
+ parsed["eventConfig"],
+ parsed.get("title"),
+ )
- @require_repo_admin
- @nickname('listRepoNotifications')
- @disallow_for_app_repositories
- def get(self, namespace_name, repository_name):
- """ List the notifications for the specified repository. """
- notifications = model.list_repo_notifications(namespace_name, repository_name)
- return {'notifications': [n.to_dict() for n in notifications]}
+ log_action(
+ "add_repo_notification",
+ namespace_name,
+ {
+ "repo": repository_name,
+ "namespace": namespace_name,
+ "notification_id": new_notification.uuid,
+ "event": new_notification.event_name,
+ "method": new_notification.method_name,
+ },
+ repo_name=repository_name,
+ )
+ return new_notification.to_dict(), 201
+
+ @require_repo_admin
+ @nickname("listRepoNotifications")
+ @disallow_for_app_repositories
+ def get(self, namespace_name, repository_name):
+ """ List the notifications for the specified repository. """
+ notifications = model.list_repo_notifications(namespace_name, repository_name)
+ return {"notifications": [n.to_dict() for n in notifications]}
-@resource('/v1/repository/<apirepopath:repository>/notification/<uuid>')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
-@path_param('uuid', 'The UUID of the notification')
+@resource("/v1/repository//notification/")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
+@path_param("uuid", "The UUID of the notification")
class RepositoryNotification(RepositoryParamResource):
- """ Resource for dealing with specific notifications. """
+ """ Resource for dealing with specific notifications. """
- @require_repo_admin
- @nickname('getRepoNotification')
- @disallow_for_app_repositories
- def get(self, namespace_name, repository_name, uuid):
- """ Get information for the specified notification. """
- found = model.get_repo_notification(uuid)
- if not found:
- raise NotFound()
- return found.to_dict()
+ @require_repo_admin
+ @nickname("getRepoNotification")
+ @disallow_for_app_repositories
+ def get(self, namespace_name, repository_name, uuid):
+ """ Get information for the specified notification. """
+ found = model.get_repo_notification(uuid)
+ if not found:
+ raise NotFound()
+ return found.to_dict()
- @require_repo_admin
- @nickname('deleteRepoNotification')
- @disallow_for_app_repositories
- def delete(self, namespace_name, repository_name, uuid):
- """ Deletes the specified notification. """
- deleted = model.delete_repo_notification(namespace_name, repository_name, uuid)
- if not deleted:
- raise InvalidRequest("No repository notification found for: %s, %s, %s" %
- (namespace_name, repository_name, uuid))
+ @require_repo_admin
+ @nickname("deleteRepoNotification")
+ @disallow_for_app_repositories
+ def delete(self, namespace_name, repository_name, uuid):
+ """ Deletes the specified notification. """
+ deleted = model.delete_repo_notification(namespace_name, repository_name, uuid)
+ if not deleted:
+ raise InvalidRequest(
+ "No repository notification found for: %s, %s, %s"
+ % (namespace_name, repository_name, uuid)
+ )
- log_action('delete_repo_notification', namespace_name, {
- 'repo': repository_name,
- 'namespace': namespace_name,
- 'notification_id': uuid,
- 'event': deleted.event_name,
- 'method': deleted.method_name}, repo_name=repository_name)
+ log_action(
+ "delete_repo_notification",
+ namespace_name,
+ {
+ "repo": repository_name,
+ "namespace": namespace_name,
+ "notification_id": uuid,
+ "event": deleted.event_name,
+ "method": deleted.method_name,
+ },
+ repo_name=repository_name,
+ )
- return 'No Content', 204
+ return "No Content", 204
- @require_repo_admin
- @nickname('resetRepositoryNotificationFailures')
- @disallow_for_app_repositories
- def post(self, namespace_name, repository_name, uuid):
- """ Resets repository notification to 0 failures. """
- reset = model.reset_notification_number_of_failures(namespace_name, repository_name, uuid)
- if not reset:
- raise InvalidRequest("No repository notification found for: %s, %s, %s" %
- (namespace_name, repository_name, uuid))
+ @require_repo_admin
+ @nickname("resetRepositoryNotificationFailures")
+ @disallow_for_app_repositories
+ def post(self, namespace_name, repository_name, uuid):
+ """ Resets repository notification to 0 failures. """
+ reset = model.reset_notification_number_of_failures(
+ namespace_name, repository_name, uuid
+ )
+ if not reset:
+ raise InvalidRequest(
+ "No repository notification found for: %s, %s, %s"
+ % (namespace_name, repository_name, uuid)
+ )
- log_action('reset_repo_notification', namespace_name, {
- 'repo': repository_name,
- 'namespace': namespace_name,
- 'notification_id': uuid,
- 'event': reset.event_name,
- 'method': reset.method_name}, repo_name=repository_name)
+ log_action(
+ "reset_repo_notification",
+ namespace_name,
+ {
+ "repo": repository_name,
+ "namespace": namespace_name,
+ "notification_id": uuid,
+ "event": reset.event_name,
+ "method": reset.method_name,
+ },
+ repo_name=repository_name,
+ )
- return 'No Content', 204
+ return "No Content", 204
-@resource('/v1/repository/<apirepopath:repository>/notification/<uuid>/test')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
-@path_param('uuid', 'The UUID of the notification')
+@resource("/v1/repository/<apirepopath:repository>/notification/<uuid>/test")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
+@path_param("uuid", "The UUID of the notification")
class TestRepositoryNotification(RepositoryParamResource):
- """ Resource for queuing a test of a notification. """
+ """ Resource for queuing a test of a notification. """
- @require_repo_admin
- @nickname('testRepoNotification')
- @disallow_for_app_repositories
- def post(self, namespace_name, repository_name, uuid):
- """ Queues a test notification for this repository. """
- test_note = model.queue_test_notification(uuid)
- if not test_note:
- raise InvalidRequest("No repository notification found for: %s, %s, %s" %
- (namespace_name, repository_name, uuid))
+ @require_repo_admin
+ @nickname("testRepoNotification")
+ @disallow_for_app_repositories
+ def post(self, namespace_name, repository_name, uuid):
+ """ Queues a test notification for this repository. """
+ test_note = model.queue_test_notification(uuid)
+ if not test_note:
+ raise InvalidRequest(
+ "No repository notification found for: %s, %s, %s"
+ % (namespace_name, repository_name, uuid)
+ )
- return {}, 200
+ return {}, 200
diff --git a/endpoints/api/repositorynotification_models_interface.py b/endpoints/api/repositorynotification_models_interface.py
index ed0ebd2f7..a8789ddf8 100644
--- a/endpoints/api/repositorynotification_models_interface.py
+++ b/endpoints/api/repositorynotification_models_interface.py
@@ -7,16 +7,20 @@ from six import add_metaclass
class RepositoryNotification(
- namedtuple('RepositoryNotification', [
- 'uuid',
- 'title',
- 'event_name',
- 'method_name',
- 'config_json',
- 'event_config_json',
- 'number_of_failures',
- ])):
- """
+ namedtuple(
+ "RepositoryNotification",
+ [
+ "uuid",
+ "title",
+ "event_name",
+ "method_name",
+ "config_json",
+ "event_config_json",
+ "number_of_failures",
+ ],
+ )
+):
+ """
RepositoryNotification represents a notification for a repository.
:type uuid: string
:type event: string
@@ -27,38 +31,46 @@ class RepositoryNotification(
:type number_of_failures: int
"""
- def to_dict(self):
- try:
- config = json.loads(self.config_json)
- except ValueError:
- config = {}
+ def to_dict(self):
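+        # Tolerate malformed JSON in either stored config column by falling back to an empty dict.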
+ try:
+ config = json.loads(self.config_json)
+ except ValueError:
+ config = {}
- try:
- event_config = json.loads(self.event_config_json)
- except ValueError:
- event_config = {}
+ try:
+ event_config = json.loads(self.event_config_json)
+ except ValueError:
+ event_config = {}
- return {
- 'uuid': self.uuid,
- 'title': self.title,
- 'event': self.event_name,
- 'method': self.method_name,
- 'config': config,
- 'event_config': event_config,
- 'number_of_failures': self.number_of_failures,
- }
+ return {
+ "uuid": self.uuid,
+ "title": self.title,
+ "event": self.event_name,
+ "method": self.method_name,
+ "config": config,
+ "event_config": event_config,
+ "number_of_failures": self.number_of_failures,
+ }
@add_metaclass(ABCMeta)
class RepoNotificationInterface(object):
- """
+ """
Interface that represents all data store interactions required by the RepositoryNotification API
"""
- @abstractmethod
- def create_repo_notification(self, namespace_name, repository_name, event_name, method_name,
- method_config, event_config, title=None):
- """
+ @abstractmethod
+ def create_repo_notification(
+ self,
+ namespace_name,
+ repository_name,
+ event_name,
+ method_name,
+ method_config,
+ event_config,
+ title=None,
+ ):
+ """
Args:
namespace_name: namespace of repository
@@ -73,11 +85,11 @@ class RepoNotificationInterface(object):
RepositoryNotification object
"""
- pass
+ pass
- @abstractmethod
- def list_repo_notifications(self, namespace_name, repository_name, event_name=None):
- """
+ @abstractmethod
+ def list_repo_notifications(self, namespace_name, repository_name, event_name=None):
+ """
Args:
namespace_name: namespace of repository
@@ -87,11 +99,11 @@ class RepoNotificationInterface(object):
Returns:
list(RepositoryNotification)
"""
- pass
+ pass
- @abstractmethod
- def get_repo_notification(self, uuid):
- """
+ @abstractmethod
+ def get_repo_notification(self, uuid):
+ """
Args:
uuid: uuid of notification
@@ -100,11 +112,11 @@ class RepoNotificationInterface(object):
RepositoryNotification or None
"""
- pass
+ pass
- @abstractmethod
- def delete_repo_notification(self, namespace_name, repository_name, uuid):
- """
+ @abstractmethod
+ def delete_repo_notification(self, namespace_name, repository_name, uuid):
+ """
Args:
namespace_name: namespace of repository
@@ -115,11 +127,13 @@ class RepoNotificationInterface(object):
RepositoryNotification or None
"""
- pass
+ pass
- @abstractmethod
- def reset_notification_number_of_failures(self, namespace_name, repository_name, uuid):
- """
+ @abstractmethod
+ def reset_notification_number_of_failures(
+ self, namespace_name, repository_name, uuid
+ ):
+ """
Args:
namespace_name: namespace of repository
@@ -130,11 +144,11 @@ class RepoNotificationInterface(object):
RepositoryNotification
"""
- pass
+ pass
- @abstractmethod
- def queue_test_notification(self, uuid):
- """
+ @abstractmethod
+ def queue_test_notification(self, uuid):
+ """
Args:
uuid: uuid of notification
@@ -143,4 +157,4 @@ class RepoNotificationInterface(object):
RepositoryNotification or None
"""
- pass
+ pass
diff --git a/endpoints/api/repositorynotification_models_pre_oci.py b/endpoints/api/repositorynotification_models_pre_oci.py
index b3edf43ae..1b55143bd 100644
--- a/endpoints/api/repositorynotification_models_pre_oci.py
+++ b/endpoints/api/repositorynotification_models_pre_oci.py
@@ -3,70 +3,102 @@ import json
from app import notification_queue
from data import model
from data.model import InvalidNotificationException
-from endpoints.api.repositorynotification_models_interface import (RepoNotificationInterface,
- RepositoryNotification)
+from endpoints.api.repositorynotification_models_interface import (
+ RepoNotificationInterface,
+ RepositoryNotification,
+)
from notifications import build_notification_data
from notifications.notificationevent import NotificationEvent
class RepoNotificationPreOCIModel(RepoNotificationInterface):
- def create_repo_notification(self, namespace_name, repository_name, event_name, method_name,
- method_config, event_config, title=None):
- repository = model.repository.get_repository(namespace_name, repository_name)
- return self._notification(
- model.notification.create_repo_notification(repository, event_name, method_name,
- method_config, event_config, title))
+ def create_repo_notification(
+ self,
+ namespace_name,
+ repository_name,
+ event_name,
+ method_name,
+ method_config,
+ event_config,
+ title=None,
+ ):
+ repository = model.repository.get_repository(namespace_name, repository_name)
+ return self._notification(
+ model.notification.create_repo_notification(
+ repository, event_name, method_name, method_config, event_config, title
+ )
+ )
- def list_repo_notifications(self, namespace_name, repository_name, event_name=None):
- return [
- self._notification(n)
- for n in model.notification.list_repo_notifications(namespace_name, repository_name,
- event_name)]
+ def list_repo_notifications(self, namespace_name, repository_name, event_name=None):
+ return [
+ self._notification(n)
+ for n in model.notification.list_repo_notifications(
+ namespace_name, repository_name, event_name
+ )
+ ]
- def get_repo_notification(self, uuid):
- try:
- found = model.notification.get_repo_notification(uuid)
- except InvalidNotificationException:
- return None
- return self._notification(found)
+ def get_repo_notification(self, uuid):
+ try:
+ found = model.notification.get_repo_notification(uuid)
+ except InvalidNotificationException:
+ return None
+ return self._notification(found)
- def delete_repo_notification(self, namespace_name, repository_name, uuid):
- try:
- found = model.notification.delete_repo_notification(namespace_name, repository_name, uuid)
- except InvalidNotificationException:
- return None
- return self._notification(found)
+ def delete_repo_notification(self, namespace_name, repository_name, uuid):
+ try:
+ found = model.notification.delete_repo_notification(
+ namespace_name, repository_name, uuid
+ )
+ except InvalidNotificationException:
+ return None
+ return self._notification(found)
- def reset_notification_number_of_failures(self, namespace_name, repository_name, uuid):
- return self._notification(
- model.notification.reset_notification_number_of_failures(namespace_name, repository_name,
- uuid))
+ def reset_notification_number_of_failures(
+ self, namespace_name, repository_name, uuid
+ ):
+ return self._notification(
+ model.notification.reset_notification_number_of_failures(
+ namespace_name, repository_name, uuid
+ )
+ )
- def queue_test_notification(self, uuid):
- try:
- notification = model.notification.get_repo_notification(uuid)
- except InvalidNotificationException:
- return None
+ def queue_test_notification(self, uuid):
+ try:
+ notification = model.notification.get_repo_notification(uuid)
+ except InvalidNotificationException:
+ return None
- event_config = json.loads(notification.event_config_json or '{}')
- event_info = NotificationEvent.get_event(notification.event.name)
- sample_data = event_info.get_sample_data(notification.repository.namespace_user.username,
- notification.repository.name, event_config)
- notification_data = build_notification_data(notification, sample_data)
- notification_queue.put([
- notification.repository.namespace_user.username, notification.uuid, notification.event.name],
- json.dumps(notification_data))
- return self._notification(notification)
+ event_config = json.loads(notification.event_config_json or "{}")
+ event_info = NotificationEvent.get_event(notification.event.name)
+ sample_data = event_info.get_sample_data(
+ notification.repository.namespace_user.username,
+ notification.repository.name,
+ event_config,
+ )
+ notification_data = build_notification_data(notification, sample_data)
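+        # Enqueue the sample payload; the queue item is keyed by (namespace, notification uuid, event name).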
+ notification_queue.put(
+ [
+ notification.repository.namespace_user.username,
+ notification.uuid,
+ notification.event.name,
+ ],
+ json.dumps(notification_data),
+ )
+ return self._notification(notification)
- def _notification(self, notification):
- if not notification:
- return None
+ def _notification(self, notification):
+ if not notification:
+ return None
- return RepositoryNotification(
- uuid=notification.uuid, title=notification.title, event_name=notification.event.name,
- method_name=notification.method.name, config_json=notification.config_json,
- event_config_json=notification.event_config_json,
- number_of_failures=notification.number_of_failures)
+ return RepositoryNotification(
+ uuid=notification.uuid,
+ title=notification.title,
+ event_name=notification.event.name,
+ method_name=notification.method.name,
+ config_json=notification.config_json,
+ event_config_json=notification.event_config_json,
+ number_of_failures=notification.number_of_failures,
+ )
pre_oci_model = RepoNotificationPreOCIModel()
diff --git a/endpoints/api/repotoken.py b/endpoints/api/repotoken.py
index efa25a2fb..93f0f05eb 100644
--- a/endpoints/api/repotoken.py
+++ b/endpoints/api/repotoken.py
@@ -2,99 +2,87 @@
import logging
-from endpoints.api import (resource, nickname, require_repo_admin, RepositoryParamResource,
- validate_json_request, path_param)
+from endpoints.api import (
+ resource,
+ nickname,
+ require_repo_admin,
+ RepositoryParamResource,
+ validate_json_request,
+ path_param,
+)
logger = logging.getLogger(__name__)
-@resource('/v1/repository/<apirepopath:repository>/tokens/')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
+
+@resource("/v1/repository/<apirepopath:repository>/tokens/")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
class RepositoryTokenList(RepositoryParamResource):
- """ Resource for creating and listing repository tokens. """
- schemas = {
- 'NewToken': {
- 'type': 'object',
- 'description': 'Description of a new token.',
- 'required':[
- 'friendlyName',
- ],
- 'properties': {
- 'friendlyName': {
- 'type': 'string',
- 'description': 'Friendly name to help identify the token',
- },
- },
- },
- }
+ """ Resource for creating and listing repository tokens. """
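+    # Access tokens have been removed; every handler on this resource now returns 410 Gone.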
- @require_repo_admin
- @nickname('listRepoTokens')
- def get(self, namespace_name, repo_name):
- """ List the tokens for the specified repository. """
- return {
- 'message': 'Handling of access tokens is no longer supported',
- }, 410
+ schemas = {
+ "NewToken": {
+ "type": "object",
+ "description": "Description of a new token.",
+ "required": ["friendlyName"],
+ "properties": {
+ "friendlyName": {
+ "type": "string",
+ "description": "Friendly name to help identify the token",
+ }
+ },
+ }
+ }
+
+ @require_repo_admin
+ @nickname("listRepoTokens")
+ def get(self, namespace_name, repo_name):
+ """ List the tokens for the specified repository. """
+ return {"message": "Handling of access tokens is no longer supported"}, 410
+
+ @require_repo_admin
+ @nickname("createToken")
+ @validate_json_request("NewToken")
+ def post(self, namespace_name, repo_name):
+ """ Create a new repository token. """
+ return {"message": "Creation of access tokens is no longer supported"}, 410
- @require_repo_admin
- @nickname('createToken')
- @validate_json_request('NewToken')
- def post(self, namespace_name, repo_name):
- """ Create a new repository token. """
- return {
- 'message': 'Creation of access tokens is no longer supported',
- }, 410
-
-
-@resource('/v1/repository/<apirepopath:repository>/tokens/<code>')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
-@path_param('code', 'The token code')
+@resource("/v1/repository/<apirepopath:repository>/tokens/<code>")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
+@path_param("code", "The token code")
class RepositoryToken(RepositoryParamResource):
- """ Resource for managing individual tokens. """
- schemas = {
- 'TokenPermission': {
- 'type': 'object',
- 'description': 'Description of a token permission',
- 'required': [
- 'role',
- ],
- 'properties': {
- 'role': {
- 'type': 'string',
- 'description': 'Role to use for the token',
- 'enum': [
- 'read',
- 'write',
- 'admin',
- ],
- },
- },
- },
- }
+ """ Resource for managing individual tokens. """
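+    # As with RepositoryTokenList, these handlers are kept only for API compatibility and always return 410 Gone.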
- @require_repo_admin
- @nickname('getTokens')
- def get(self, namespace_name, repo_name, code):
- """ Fetch the specified repository token information. """
- return {
- 'message': 'Handling of access tokens is no longer supported',
- }, 410
+ schemas = {
+ "TokenPermission": {
+ "type": "object",
+ "description": "Description of a token permission",
+ "required": ["role"],
+ "properties": {
+ "role": {
+ "type": "string",
+ "description": "Role to use for the token",
+ "enum": ["read", "write", "admin"],
+ }
+ },
+ }
+ }
+ @require_repo_admin
+ @nickname("getTokens")
+ def get(self, namespace_name, repo_name, code):
+ """ Fetch the specified repository token information. """
+ return {"message": "Handling of access tokens is no longer supported"}, 410
- @require_repo_admin
- @nickname('changeToken')
- @validate_json_request('TokenPermission')
- def put(self, namespace_name, repo_name, code):
- """ Update the permissions for the specified repository token. """
- return {
- 'message': 'Handling of access tokens is no longer supported',
- }, 410
+ @require_repo_admin
+ @nickname("changeToken")
+ @validate_json_request("TokenPermission")
+ def put(self, namespace_name, repo_name, code):
+ """ Update the permissions for the specified repository token. """
+ return {"message": "Handling of access tokens is no longer supported"}, 410
-
- @require_repo_admin
- @nickname('deleteToken')
- def delete(self, namespace_name, repo_name, code):
- """ Delete the repository token. """
- return {
- 'message': 'Handling of access tokens is no longer supported',
- }, 410
+ @require_repo_admin
+ @nickname("deleteToken")
+ def delete(self, namespace_name, repo_name, code):
+ """ Delete the repository token. """
+ return {"message": "Handling of access tokens is no longer supported"}, 410
diff --git a/endpoints/api/robot.py b/endpoints/api/robot.py
index 867329323..ea22269b3 100644
--- a/endpoints/api/robot.py
+++ b/endpoints/api/robot.py
@@ -1,11 +1,26 @@
""" Manage user and organization robot accounts. """
-from endpoints.api import (resource, nickname, ApiResource, log_action, related_user_resource,
- require_user_admin, require_scope, path_param, parse_args,
- truthy_bool, query_param, validate_json_request, max_json_size)
+from endpoints.api import (
+ resource,
+ nickname,
+ ApiResource,
+ log_action,
+ related_user_resource,
+ require_user_admin,
+ require_scope,
+ path_param,
+ parse_args,
+ truthy_bool,
+ query_param,
+ validate_json_request,
+ max_json_size,
+)
from endpoints.api.robot_models_pre_oci import pre_oci_model as model
from endpoints.exception import Unauthorized
-from auth.permissions import AdministerOrganizationPermission, OrganizationMemberPermission
+from auth.permissions import (
+ AdministerOrganizationPermission,
+ OrganizationMemberPermission,
+)
from auth.auth_context import get_authenticated_user
from auth import scopes
from util.names import format_robot_username
@@ -13,262 +28,309 @@ from flask import abort, request
CREATE_ROBOT_SCHEMA = {
- 'type': 'object',
- 'description': 'Optional data for creating a robot',
- 'properties': {
- 'description': {
- 'type': 'string',
- 'description': 'Optional text description for the robot',
- 'maxLength': 255,
+ "type": "object",
+ "description": "Optional data for creating a robot",
+ "properties": {
+ "description": {
+ "type": "string",
+ "description": "Optional text description for the robot",
+ "maxLength": 255,
+ },
+ "unstructured_metadata": {
+ "type": "object",
+ "description": "Optional unstructured metadata for the robot",
+ },
},
- 'unstructured_metadata': {
- 'type': 'object',
- 'description': 'Optional unstructured metadata for the robot',
- },
- },
}
-ROBOT_MAX_SIZE = 1024 * 1024 # 1 KB.
+ROBOT_MAX_SIZE = 1024 * 1024  # 1 MB.
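+# Enforced by @max_json_size on the robot creation endpoints below.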
def robots_list(prefix, include_permissions=False, include_token=False, limit=None):
- robots = model.list_entity_robot_permission_teams(prefix, limit=limit,
- include_token=include_token,
- include_permissions=include_permissions)
- return {'robots': [robot.to_dict(include_token=include_token) for robot in robots]}
+ robots = model.list_entity_robot_permission_teams(
+ prefix,
+ limit=limit,
+ include_token=include_token,
+ include_permissions=include_permissions,
+ )
+ return {"robots": [robot.to_dict(include_token=include_token) for robot in robots]}
-@resource('/v1/user/robots')
+@resource("/v1/user/robots")
class UserRobotList(ApiResource):
- """ Resource for listing user robots. """
+ """ Resource for listing user robots. """
- @require_user_admin
- @nickname('getUserRobots')
- @parse_args()
- @query_param('permissions',
- 'Whether to include repositories and teams in which the robots have permission.',
- type=truthy_bool, default=False)
- @query_param('token',
- 'If false, the robot\'s token is not returned.',
- type=truthy_bool, default=True)
- @query_param('limit',
- 'If specified, the number of robots to return.',
- type=int, default=None)
- def get(self, parsed_args):
- """ List the available robots for the user. """
- user = get_authenticated_user()
- return robots_list(user.username, include_token=parsed_args.get('token', True),
- include_permissions=parsed_args.get('permissions', False),
- limit=parsed_args.get('limit'))
+ @require_user_admin
+ @nickname("getUserRobots")
+ @parse_args()
+ @query_param(
+ "permissions",
+ "Whether to include repositories and teams in which the robots have permission.",
+ type=truthy_bool,
+ default=False,
+ )
+ @query_param(
+ "token",
+ "If false, the robot's token is not returned.",
+ type=truthy_bool,
+ default=True,
+ )
+ @query_param(
+ "limit", "If specified, the number of robots to return.", type=int, default=None
+ )
+ def get(self, parsed_args):
+ """ List the available robots for the user. """
+ user = get_authenticated_user()
+ return robots_list(
+ user.username,
+ include_token=parsed_args.get("token", True),
+ include_permissions=parsed_args.get("permissions", False),
+ limit=parsed_args.get("limit"),
+ )
-@resource('/v1/user/robots/<robot_shortname>')
-@path_param('robot_shortname',
- 'The short name for the robot, without any user or organization prefix')
+@resource("/v1/user/robots/<robot_shortname>")
+@path_param(
+ "robot_shortname",
+ "The short name for the robot, without any user or organization prefix",
+)
class UserRobot(ApiResource):
- """ Resource for managing a user's robots. """
- schemas = {
- 'CreateRobot': CREATE_ROBOT_SCHEMA,
- }
+ """ Resource for managing a user's robots. """
- @require_user_admin
- @nickname('getUserRobot')
- def get(self, robot_shortname):
- """ Returns the user's robot with the specified name. """
- parent = get_authenticated_user()
- robot = model.get_user_robot(robot_shortname, parent)
- return robot.to_dict(include_metadata=True, include_token=True)
+ schemas = {"CreateRobot": CREATE_ROBOT_SCHEMA}
- @require_user_admin
- @nickname('createUserRobot')
- @max_json_size(ROBOT_MAX_SIZE)
- @validate_json_request('CreateRobot', optional=True)
- def put(self, robot_shortname):
- """ Create a new user robot with the specified name. """
- parent = get_authenticated_user()
- create_data = request.get_json() or {}
- robot = model.create_user_robot(robot_shortname, parent, create_data.get('description'),
- create_data.get('unstructured_metadata'))
- log_action('create_robot', parent.username, {
- 'robot': robot_shortname,
- 'description': create_data.get('description'),
- 'unstructured_metadata': create_data.get('unstructured_metadata'),
- })
- return robot.to_dict(include_metadata=True, include_token=True), 201
+ @require_user_admin
+ @nickname("getUserRobot")
+ def get(self, robot_shortname):
+ """ Returns the user's robot with the specified name. """
+ parent = get_authenticated_user()
+ robot = model.get_user_robot(robot_shortname, parent)
+ return robot.to_dict(include_metadata=True, include_token=True)
- @require_user_admin
- @nickname('deleteUserRobot')
- def delete(self, robot_shortname):
- """ Delete an existing robot. """
- parent = get_authenticated_user()
- model.delete_robot(format_robot_username(parent.username, robot_shortname))
- log_action('delete_robot', parent.username, {'robot': robot_shortname})
- return '', 204
+ @require_user_admin
+ @nickname("createUserRobot")
+ @max_json_size(ROBOT_MAX_SIZE)
+ @validate_json_request("CreateRobot", optional=True)
+ def put(self, robot_shortname):
+ """ Create a new user robot with the specified name. """
+ parent = get_authenticated_user()
+ create_data = request.get_json() or {}
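+        # The request body is optional, so description and unstructured_metadata may be absent.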
+ robot = model.create_user_robot(
+ robot_shortname,
+ parent,
+ create_data.get("description"),
+ create_data.get("unstructured_metadata"),
+ )
+ log_action(
+ "create_robot",
+ parent.username,
+ {
+ "robot": robot_shortname,
+ "description": create_data.get("description"),
+ "unstructured_metadata": create_data.get("unstructured_metadata"),
+ },
+ )
+ return robot.to_dict(include_metadata=True, include_token=True), 201
+
+ @require_user_admin
+ @nickname("deleteUserRobot")
+ def delete(self, robot_shortname):
+ """ Delete an existing robot. """
+ parent = get_authenticated_user()
+ model.delete_robot(format_robot_username(parent.username, robot_shortname))
+ log_action("delete_robot", parent.username, {"robot": robot_shortname})
+ return "", 204
-@resource('/v1/organization/<orgname>/robots')
-@path_param('orgname', 'The name of the organization')
+@resource("/v1/organization/<orgname>/robots")
+@path_param("orgname", "The name of the organization")
@related_user_resource(UserRobotList)
class OrgRobotList(ApiResource):
- """ Resource for listing an organization's robots. """
+ """ Resource for listing an organization's robots. """
- @require_scope(scopes.ORG_ADMIN)
- @nickname('getOrgRobots')
- @parse_args()
- @query_param('permissions',
- 'Whether to include repostories and teams in which the robots have permission.',
- type=truthy_bool, default=False)
- @query_param('token',
- 'If false, the robot\'s token is not returned.',
- type=truthy_bool, default=True)
- @query_param('limit',
- 'If specified, the number of robots to return.',
- type=int, default=None)
- def get(self, orgname, parsed_args):
- """ List the organization's robots. """
- permission = OrganizationMemberPermission(orgname)
- if permission.can():
- include_token = (AdministerOrganizationPermission(orgname).can() and
- parsed_args.get('token', True))
- include_permissions = (AdministerOrganizationPermission(orgname).can() and
- parsed_args.get('permissions', False))
- return robots_list(orgname, include_permissions=include_permissions,
- include_token=include_token,
- limit=parsed_args.get('limit'))
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("getOrgRobots")
+ @parse_args()
+ @query_param(
+ "permissions",
+        "Whether to include repositories and teams in which the robots have permission.",
+ type=truthy_bool,
+ default=False,
+ )
+ @query_param(
+ "token",
+ "If false, the robot's token is not returned.",
+ type=truthy_bool,
+ default=True,
+ )
+ @query_param(
+ "limit", "If specified, the number of robots to return.", type=int, default=None
+ )
+ def get(self, orgname, parsed_args):
+ """ List the organization's robots. """
+ permission = OrganizationMemberPermission(orgname)
+ if permission.can():
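+            # Any organization member may list robots, but tokens and permissions are only exposed to org admins.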
+ include_token = AdministerOrganizationPermission(
+ orgname
+ ).can() and parsed_args.get("token", True)
+ include_permissions = AdministerOrganizationPermission(
+ orgname
+ ).can() and parsed_args.get("permissions", False)
+ return robots_list(
+ orgname,
+ include_permissions=include_permissions,
+ include_token=include_token,
+ limit=parsed_args.get("limit"),
+ )
- raise Unauthorized()
+ raise Unauthorized()
-@resource('/v1/organization/<orgname>/robots/<robot_shortname>')
-@path_param('orgname', 'The name of the organization')
-@path_param('robot_shortname',
- 'The short name for the robot, without any user or organization prefix')
+@resource("/v1/organization/<orgname>/robots/<robot_shortname>")
+@path_param("orgname", "The name of the organization")
+@path_param(
+ "robot_shortname",
+ "The short name for the robot, without any user or organization prefix",
+)
@related_user_resource(UserRobot)
class OrgRobot(ApiResource):
- """ Resource for managing an organization's robots. """
- schemas = {
- 'CreateRobot': CREATE_ROBOT_SCHEMA,
- }
+ """ Resource for managing an organization's robots. """
- @require_scope(scopes.ORG_ADMIN)
- @nickname('getOrgRobot')
- def get(self, orgname, robot_shortname):
- """ Returns the organization's robot with the specified name. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- robot = model.get_org_robot(robot_shortname, orgname)
- return robot.to_dict(include_metadata=True, include_token=True)
+ schemas = {"CreateRobot": CREATE_ROBOT_SCHEMA}
- raise Unauthorized()
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("getOrgRobot")
+ def get(self, orgname, robot_shortname):
+ """ Returns the organization's robot with the specified name. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ robot = model.get_org_robot(robot_shortname, orgname)
+ return robot.to_dict(include_metadata=True, include_token=True)
- @require_scope(scopes.ORG_ADMIN)
- @nickname('createOrgRobot')
- @max_json_size(ROBOT_MAX_SIZE)
- @validate_json_request('CreateRobot', optional=True)
- def put(self, orgname, robot_shortname):
- """ Create a new robot in the organization. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- create_data = request.get_json() or {}
- robot = model.create_org_robot(robot_shortname, orgname, create_data.get('description'),
- create_data.get('unstructured_metadata'))
- log_action('create_robot', orgname, {
- 'robot': robot_shortname,
- 'description': create_data.get('description'),
- 'unstructured_metadata': create_data.get('unstructured_metadata'),
- })
- return robot.to_dict(include_metadata=True, include_token=True), 201
+ raise Unauthorized()
- raise Unauthorized()
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("createOrgRobot")
+ @max_json_size(ROBOT_MAX_SIZE)
+ @validate_json_request("CreateRobot", optional=True)
+ def put(self, orgname, robot_shortname):
+ """ Create a new robot in the organization. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ create_data = request.get_json() or {}
+ robot = model.create_org_robot(
+ robot_shortname,
+ orgname,
+ create_data.get("description"),
+ create_data.get("unstructured_metadata"),
+ )
+ log_action(
+ "create_robot",
+ orgname,
+ {
+ "robot": robot_shortname,
+ "description": create_data.get("description"),
+ "unstructured_metadata": create_data.get("unstructured_metadata"),
+ },
+ )
+ return robot.to_dict(include_metadata=True, include_token=True), 201
- @require_scope(scopes.ORG_ADMIN)
- @nickname('deleteOrgRobot')
- def delete(self, orgname, robot_shortname):
- """ Delete an existing organization robot. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- model.delete_robot(format_robot_username(orgname, robot_shortname))
- log_action('delete_robot', orgname, {'robot': robot_shortname})
- return '', 204
+ raise Unauthorized()
- raise Unauthorized()
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("deleteOrgRobot")
+ def delete(self, orgname, robot_shortname):
+ """ Delete an existing organization robot. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ model.delete_robot(format_robot_username(orgname, robot_shortname))
+ log_action("delete_robot", orgname, {"robot": robot_shortname})
+ return "", 204
+
+ raise Unauthorized()
-@resource('/v1/user/robots/<robot_shortname>/permissions')
-@path_param('robot_shortname',
- 'The short name for the robot, without any user or organization prefix')
+@resource("/v1/user/robots/<robot_shortname>/permissions")
+@path_param(
+ "robot_shortname",
+ "The short name for the robot, without any user or organization prefix",
+)
class UserRobotPermissions(ApiResource):
- """ Resource for listing the permissions a user's robot has in the system. """
+ """ Resource for listing the permissions a user's robot has in the system. """
- @require_user_admin
- @nickname('getUserRobotPermissions')
- def get(self, robot_shortname):
- """ Returns the list of repository permissions for the user's robot. """
- parent = get_authenticated_user()
- robot = model.get_user_robot(robot_shortname, parent)
- permissions = model.list_robot_permissions(robot.name)
+ @require_user_admin
+ @nickname("getUserRobotPermissions")
+ def get(self, robot_shortname):
+ """ Returns the list of repository permissions for the user's robot. """
+ parent = get_authenticated_user()
+ robot = model.get_user_robot(robot_shortname, parent)
+ permissions = model.list_robot_permissions(robot.name)
- return {
- 'permissions': [permission.to_dict() for permission in permissions]
- }
+ return {"permissions": [permission.to_dict() for permission in permissions]}
-@resource('/v1/organization/<orgname>/robots/<robot_shortname>/permissions')
-@path_param('orgname', 'The name of the organization')
-@path_param('robot_shortname',
- 'The short name for the robot, without any user or organization prefix')
+@resource("/v1/organization/<orgname>/robots/<robot_shortname>/permissions")
+@path_param("orgname", "The name of the organization")
+@path_param(
+ "robot_shortname",
+ "The short name for the robot, without any user or organization prefix",
+)
@related_user_resource(UserRobotPermissions)
class OrgRobotPermissions(ApiResource):
- """ Resource for listing the permissions an org's robot has in the system. """
+ """ Resource for listing the permissions an org's robot has in the system. """
- @require_user_admin
- @nickname('getOrgRobotPermissions')
- def get(self, orgname, robot_shortname):
- """ Returns the list of repository permissions for the org's robot. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- robot = model.get_org_robot(robot_shortname, orgname)
- permissions = model.list_robot_permissions(robot.name)
+ @require_user_admin
+ @nickname("getOrgRobotPermissions")
+ def get(self, orgname, robot_shortname):
+ """ Returns the list of repository permissions for the org's robot. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ robot = model.get_org_robot(robot_shortname, orgname)
+ permissions = model.list_robot_permissions(robot.name)
- return {
- 'permissions': [permission.to_dict() for permission in permissions]
- }
+ return {"permissions": [permission.to_dict() for permission in permissions]}
- abort(403)
+ abort(403)
-@resource('/v1/user/robots/<robot_shortname>/regenerate')
-@path_param('robot_shortname',
- 'The short name for the robot, without any user or organization prefix')
+@resource("/v1/user/robots/<robot_shortname>/regenerate")
+@path_param(
+ "robot_shortname",
+ "The short name for the robot, without any user or organization prefix",
+)
class RegenerateUserRobot(ApiResource):
- """ Resource for regenerate an organization's robot's token. """
+ """ Resource for regenerate an organization's robot's token. """
- @require_user_admin
- @nickname('regenerateUserRobotToken')
- def post(self, robot_shortname):
- """ Regenerates the token for a user's robot. """
- parent = get_authenticated_user()
- robot = model.regenerate_user_robot_token(robot_shortname, parent)
- log_action('regenerate_robot_token', parent.username, {'robot': robot_shortname})
- return robot.to_dict(include_token=True)
+ @require_user_admin
+ @nickname("regenerateUserRobotToken")
+ def post(self, robot_shortname):
+ """ Regenerates the token for a user's robot. """
+ parent = get_authenticated_user()
+ robot = model.regenerate_user_robot_token(robot_shortname, parent)
+ log_action(
+ "regenerate_robot_token", parent.username, {"robot": robot_shortname}
+ )
+ return robot.to_dict(include_token=True)
-@resource('/v1/organization/<orgname>/robots/<robot_shortname>/regenerate')
-@path_param('orgname', 'The name of the organization')
-@path_param('robot_shortname',
- 'The short name for the robot, without any user or organization prefix')
+@resource("/v1/organization/<orgname>/robots/<robot_shortname>/regenerate")
+@path_param("orgname", "The name of the organization")
+@path_param(
+ "robot_shortname",
+ "The short name for the robot, without any user or organization prefix",
+)
@related_user_resource(RegenerateUserRobot)
class RegenerateOrgRobot(ApiResource):
- """ Resource for regenerate an organization's robot's token. """
+ """ Resource for regenerate an organization's robot's token. """
- @require_scope(scopes.ORG_ADMIN)
- @nickname('regenerateOrgRobotToken')
- def post(self, orgname, robot_shortname):
- """ Regenerates the token for an organization robot. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- robot = model.regenerate_org_robot_token(robot_shortname, orgname)
- log_action('regenerate_robot_token', orgname, {'robot': robot_shortname})
- return robot.to_dict(include_token=True)
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("regenerateOrgRobotToken")
+ def post(self, orgname, robot_shortname):
+ """ Regenerates the token for an organization robot. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ robot = model.regenerate_org_robot_token(robot_shortname, orgname)
+ log_action("regenerate_robot_token", orgname, {"robot": robot_shortname})
+ return robot.to_dict(include_token=True)
- raise Unauthorized()
+ raise Unauthorized()
diff --git a/endpoints/api/robot_models_interface.py b/endpoints/api/robot_models_interface.py
index c4a07d304..57be885a9 100644
--- a/endpoints/api/robot_models_interface.py
+++ b/endpoints/api/robot_models_interface.py
@@ -6,45 +6,51 @@ from six import add_metaclass
from endpoints.api import format_date
-class Permission(namedtuple('Permission', ['repository_name', 'repository_visibility_name', 'role_name'])):
- """
+class Permission(
+ namedtuple(
+ "Permission", ["repository_name", "repository_visibility_name", "role_name"]
+ )
+):
+ """
Permission represents the relationship between a robot and a repository and whether that robot can see the repo.
"""
- def to_dict(self):
- return {
- 'repository': {
- 'name': self.repository_name,
- 'is_public': self.repository_visibility_name == 'public'
- },
- 'role': self.role_name
- }
+ def to_dict(self):
+ return {
+ "repository": {
+ "name": self.repository_name,
+ "is_public": self.repository_visibility_name == "public",
+ },
+ "role": self.role_name,
+ }
-class Team(namedtuple('Team', ['name', 'avatar'])):
- """
+class Team(namedtuple("Team", ["name", "avatar"])):
+ """
Team represents a team entry for a robot list entry.
:type name: string
:type avatar: {string -> string}
"""
- def to_dict(self):
- return {
- 'name': self.name,
- 'avatar': self.avatar,
- }
+
+ def to_dict(self):
+ return {"name": self.name, "avatar": self.avatar}
class RobotWithPermissions(
- namedtuple('RobotWithPermissions', [
- 'name',
- 'password',
- 'created',
- 'last_accessed',
- 'teams',
- 'repository_names',
- 'description',
- ])):
- """
+ namedtuple(
+ "RobotWithPermissions",
+ [
+ "name",
+ "password",
+ "created",
+ "last_accessed",
+ "teams",
+ "repository_names",
+ "description",
+ ],
+ )
+):
+ """
RobotWithPermissions represents a robot list entry along with its team and repository permissions.
:type name: string
:type password: string
@@ -55,32 +61,38 @@ class RobotWithPermissions(
:type description: string
"""
- def to_dict(self, include_token=False):
- data = {
- 'name': self.name,
- 'created': format_date(self.created) if self.created is not None else None,
- 'last_accessed': format_date(self.last_accessed) if self.last_accessed is not None else None,
- 'teams': [team.to_dict() for team in self.teams],
- 'repositories': self.repository_names,
- 'description': self.description,
- }
+ def to_dict(self, include_token=False):
+ data = {
+ "name": self.name,
+ "created": format_date(self.created) if self.created is not None else None,
+ "last_accessed": format_date(self.last_accessed)
+ if self.last_accessed is not None
+ else None,
+ "teams": [team.to_dict() for team in self.teams],
+ "repositories": self.repository_names,
+ "description": self.description,
+ }
- if include_token:
- data['token'] = self.password
+ if include_token:
+ data["token"] = self.password
- return data
+ return data
class Robot(
- namedtuple('Robot', [
- 'name',
- 'password',
- 'created',
- 'last_accessed',
- 'description',
- 'unstructured_metadata',
- ])):
- """
+ namedtuple(
+ "Robot",
+ [
+ "name",
+ "password",
+ "created",
+ "last_accessed",
+ "description",
+ "unstructured_metadata",
+ ],
+ )
+):
+ """
Robot represents a robot entity.
:type name: string
:type password: string
@@ -90,105 +102,108 @@ class Robot(
:type unstructured_metadata: dict
"""
- def to_dict(self, include_metadata=False, include_token=False):
- data = {
- 'name': self.name,
- 'created': format_date(self.created) if self.created is not None else None,
- 'last_accessed': format_date(self.last_accessed) if self.last_accessed is not None else None,
- 'description': self.description,
- }
+ def to_dict(self, include_metadata=False, include_token=False):
+ data = {
+ "name": self.name,
+ "created": format_date(self.created) if self.created is not None else None,
+ "last_accessed": format_date(self.last_accessed)
+ if self.last_accessed is not None
+ else None,
+ "description": self.description,
+ }
- if include_token:
- data['token'] = self.password
+ if include_token:
+ data["token"] = self.password
- if include_metadata:
- data['unstructured_metadata'] = self.unstructured_metadata
-
- return data
+ if include_metadata:
+ data["unstructured_metadata"] = self.unstructured_metadata
+
+ return data
@add_metaclass(ABCMeta)
class RobotInterface(object):
- """
+ """
Interface that represents all data store interactions required by the Robot API
"""
- @abstractmethod
- def get_org_robot(self, robot_shortname, orgname):
- """
+ @abstractmethod
+ def get_org_robot(self, robot_shortname, orgname):
+ """
Returns:
Robot object
"""
- @abstractmethod
- def get_user_robot(self, robot_shortname, owning_user):
- """
+ @abstractmethod
+ def get_user_robot(self, robot_shortname, owning_user):
+ """
Returns:
Robot object
"""
- @abstractmethod
- def create_user_robot(self, robot_shortname, owning_user):
- """
+ @abstractmethod
+ def create_user_robot(self, robot_shortname, owning_user):
+ """
Returns:
Robot object
"""
- @abstractmethod
- def create_org_robot(self, robot_shortname, orgname):
- """
+ @abstractmethod
+ def create_org_robot(self, robot_shortname, orgname):
+ """
Returns:
Robot object
"""
- @abstractmethod
- def delete_robot(self, robot_username):
- """
+ @abstractmethod
+ def delete_robot(self, robot_username):
+ """
Returns:
Robot object
"""
- @abstractmethod
- def regenerate_user_robot_token(self, robot_shortname, owning_user):
- """
+ @abstractmethod
+ def regenerate_user_robot_token(self, robot_shortname, owning_user):
+ """
Returns:
Robot object
"""
- @abstractmethod
- def regenerate_org_robot_token(self, robot_shortname, orgname):
- """
+ @abstractmethod
+ def regenerate_org_robot_token(self, robot_shortname, orgname):
+ """
Returns:
Robot object
"""
- @abstractmethod
- def list_entity_robot_permission_teams(self, prefix, include_permissions=False,
- include_token=False, limit=None):
- """
+ @abstractmethod
+ def list_entity_robot_permission_teams(
+ self, prefix, include_permissions=False, include_token=False, limit=None
+ ):
+ """
Returns:
list of RobotWithPermissions objects
"""
- @abstractmethod
- def list_robot_permissions(self, username):
- """
+ @abstractmethod
+ def list_robot_permissions(self, username):
+ """
Returns:
list of Robot objects
diff --git a/endpoints/api/robot_models_pre_oci.py b/endpoints/api/robot_models_pre_oci.py
index ad83decdf..29afbdac3 100644
--- a/endpoints/api/robot_models_pre_oci.py
+++ b/endpoints/api/robot_models_pre_oci.py
@@ -3,121 +3,206 @@ import features
from app import avatar
from data import model
from active_migration import ActiveDataMigration, ERTMigrationFlags
-from data.database import (User, FederatedLogin, RobotAccountToken, Team as TeamTable, Repository,
- RobotAccountMetadata)
-from endpoints.api.robot_models_interface import (RobotInterface, Robot, RobotWithPermissions, Team,
- Permission)
+from data.database import (
+ User,
+ FederatedLogin,
+ RobotAccountToken,
+ Team as TeamTable,
+ Repository,
+ RobotAccountMetadata,
+)
+from endpoints.api.robot_models_interface import (
+ RobotInterface,
+ Robot,
+ RobotWithPermissions,
+ Team,
+ Permission,
+)
class RobotPreOCIModel(RobotInterface):
- def list_robot_permissions(self, username):
- permissions = model.permission.list_robot_permissions(username)
- return [Permission(permission.repository.name, permission.repository.visibility.name, permission.role.name) for
- permission in permissions]
+ def list_robot_permissions(self, username):
+ permissions = model.permission.list_robot_permissions(username)
+ return [
+ Permission(
+ permission.repository.name,
+ permission.repository.visibility.name,
+ permission.role.name,
+ )
+ for permission in permissions
+ ]
- def list_entity_robot_permission_teams(self, prefix, include_token=False,
- include_permissions=False, limit=None):
- tuples = model.user.list_entity_robot_permission_teams(prefix, limit=limit,
- include_permissions=include_permissions)
- robots = {}
- robot_teams = set()
+ def list_entity_robot_permission_teams(
+ self, prefix, include_token=False, include_permissions=False, limit=None
+ ):
+ tuples = model.user.list_entity_robot_permission_teams(
+ prefix, limit=limit, include_permissions=include_permissions
+ )
+ robots = {}
+ robot_teams = set()
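+        # The joined rows repeat a robot once per team/repository permission, so dedupe while accumulating.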
- for robot_tuple in tuples:
- robot_name = robot_tuple.get(User.username)
- if robot_name not in robots:
- token = None
- if include_token:
- # TODO(remove-unenc): Remove branches once migrated.
- if robot_tuple.get(RobotAccountToken.token):
- token = robot_tuple.get(RobotAccountToken.token).decrypt()
+ for robot_tuple in tuples:
+ robot_name = robot_tuple.get(User.username)
+ if robot_name not in robots:
+ token = None
+ if include_token:
+ # TODO(remove-unenc): Remove branches once migrated.
+ if robot_tuple.get(RobotAccountToken.token):
+ token = robot_tuple.get(RobotAccountToken.token).decrypt()
- if token is None and ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
- token = robot_tuple.get(FederatedLogin.service_ident)
- assert not token.startswith('robot:')
+ if token is None and ActiveDataMigration.has_flag(
+ ERTMigrationFlags.READ_OLD_FIELDS
+ ):
+ token = robot_tuple.get(FederatedLogin.service_ident)
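+                        # The legacy service_ident field is expected to hold the raw token here, never a "robot:"-prefixed value.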
+ assert not token.startswith("robot:")
- robot_dict = {
- 'name': robot_name,
- 'token': token,
- 'created': robot_tuple.get(User.creation_date),
- 'last_accessed': (robot_tuple.get(User.last_accessed)
- if features.USER_LAST_ACCESSED else None),
- 'description': robot_tuple.get(RobotAccountMetadata.description),
- 'unstructured_metadata': robot_tuple.get(RobotAccountMetadata.unstructured_json),
- }
+ robot_dict = {
+ "name": robot_name,
+ "token": token,
+ "created": robot_tuple.get(User.creation_date),
+ "last_accessed": (
+ robot_tuple.get(User.last_accessed)
+ if features.USER_LAST_ACCESSED
+ else None
+ ),
+ "description": robot_tuple.get(RobotAccountMetadata.description),
+ "unstructured_metadata": robot_tuple.get(
+ RobotAccountMetadata.unstructured_json
+ ),
+ }
- if include_permissions:
- robot_dict.update({
- 'teams': [],
- 'repositories': [],
- })
+ if include_permissions:
+ robot_dict.update({"teams": [], "repositories": []})
- robots[robot_name] = Robot(robot_dict['name'], robot_dict['token'], robot_dict['created'],
- robot_dict['last_accessed'], robot_dict['description'],
- robot_dict['unstructured_metadata'])
- if include_permissions:
- team_name = robot_tuple.get(TeamTable.name)
- repository_name = robot_tuple.get(Repository.name)
+ robots[robot_name] = Robot(
+ robot_dict["name"],
+ robot_dict["token"],
+ robot_dict["created"],
+ robot_dict["last_accessed"],
+ robot_dict["description"],
+ robot_dict["unstructured_metadata"],
+ )
+ if include_permissions:
+ team_name = robot_tuple.get(TeamTable.name)
+ repository_name = robot_tuple.get(Repository.name)
- if team_name is not None:
- check_key = robot_name + ':' + team_name
- if check_key not in robot_teams:
- robot_teams.add(check_key)
+ if team_name is not None:
+ check_key = robot_name + ":" + team_name
+ if check_key not in robot_teams:
+ robot_teams.add(check_key)
- robot_dict['teams'].append(Team(
- team_name,
- avatar.get_data(team_name, team_name, 'team')
- ))
+ robot_dict["teams"].append(
+ Team(
+ team_name, avatar.get_data(team_name, team_name, "team")
+ )
+ )
- if repository_name is not None:
- if repository_name not in robot_dict['repositories']:
- robot_dict['repositories'].append(repository_name)
- robots[robot_name] = RobotWithPermissions(robot_dict['name'], robot_dict['token'],
- robot_dict['created'],
- (robot_dict['last_accessed']
- if features.USER_LAST_ACCESSED else None),
- robot_dict['teams'],
- robot_dict['repositories'],
- robot_dict['description'])
+ if repository_name is not None:
+ if repository_name not in robot_dict["repositories"]:
+ robot_dict["repositories"].append(repository_name)
+ robots[robot_name] = RobotWithPermissions(
+ robot_dict["name"],
+ robot_dict["token"],
+ robot_dict["created"],
+ (
+ robot_dict["last_accessed"]
+ if features.USER_LAST_ACCESSED
+ else None
+ ),
+ robot_dict["teams"],
+ robot_dict["repositories"],
+ robot_dict["description"],
+ )
- return robots.values()
+ return robots.values()
- def regenerate_user_robot_token(self, robot_shortname, owning_user):
- robot, password, metadata = model.user.regenerate_robot_token(robot_shortname, owning_user)
- return Robot(robot.username, password, robot.creation_date, robot.last_accessed,
- metadata.description, metadata.unstructured_json)
+ def regenerate_user_robot_token(self, robot_shortname, owning_user):
+ robot, password, metadata = model.user.regenerate_robot_token(
+ robot_shortname, owning_user
+ )
+ return Robot(
+ robot.username,
+ password,
+ robot.creation_date,
+ robot.last_accessed,
+ metadata.description,
+ metadata.unstructured_json,
+ )
- def regenerate_org_robot_token(self, robot_shortname, orgname):
- parent = model.organization.get_organization(orgname)
- robot, password, metadata = model.user.regenerate_robot_token(robot_shortname, parent)
- return Robot(robot.username, password, robot.creation_date, robot.last_accessed,
- metadata.description, metadata.unstructured_json)
+ def regenerate_org_robot_token(self, robot_shortname, orgname):
+ parent = model.organization.get_organization(orgname)
+ robot, password, metadata = model.user.regenerate_robot_token(
+ robot_shortname, parent
+ )
+ return Robot(
+ robot.username,
+ password,
+ robot.creation_date,
+ robot.last_accessed,
+ metadata.description,
+ metadata.unstructured_json,
+ )
- def delete_robot(self, robot_username):
- model.user.delete_robot(robot_username)
+ def delete_robot(self, robot_username):
+ model.user.delete_robot(robot_username)
- def create_user_robot(self, robot_shortname, owning_user, description, unstructured_metadata):
- robot, password = model.user.create_robot(robot_shortname, owning_user, description or '',
- unstructured_metadata)
- return Robot(robot.username, password, robot.creation_date, robot.last_accessed,
- description or '', unstructured_metadata)
+ def create_user_robot(
+ self, robot_shortname, owning_user, description, unstructured_metadata
+ ):
+ robot, password = model.user.create_robot(
+ robot_shortname, owning_user, description or "", unstructured_metadata
+ )
+ return Robot(
+ robot.username,
+ password,
+ robot.creation_date,
+ robot.last_accessed,
+ description or "",
+ unstructured_metadata,
+ )
- def create_org_robot(self, robot_shortname, orgname, description, unstructured_metadata):
- parent = model.organization.get_organization(orgname)
- robot, password = model.user.create_robot(robot_shortname, parent, description or '',
- unstructured_metadata)
- return Robot(robot.username, password, robot.creation_date, robot.last_accessed,
- description or '', unstructured_metadata)
+ def create_org_robot(
+ self, robot_shortname, orgname, description, unstructured_metadata
+ ):
+ parent = model.organization.get_organization(orgname)
+ robot, password = model.user.create_robot(
+ robot_shortname, parent, description or "", unstructured_metadata
+ )
+ return Robot(
+ robot.username,
+ password,
+ robot.creation_date,
+ robot.last_accessed,
+ description or "",
+ unstructured_metadata,
+ )
- def get_org_robot(self, robot_shortname, orgname):
- parent = model.organization.get_organization(orgname)
- robot, password, metadata = model.user.get_robot_and_metadata(robot_shortname, parent)
- return Robot(robot.username, password, robot.creation_date, robot.last_accessed,
- metadata.description, metadata.unstructured_json)
+ def get_org_robot(self, robot_shortname, orgname):
+ parent = model.organization.get_organization(orgname)
+ robot, password, metadata = model.user.get_robot_and_metadata(
+ robot_shortname, parent
+ )
+ return Robot(
+ robot.username,
+ password,
+ robot.creation_date,
+ robot.last_accessed,
+ metadata.description,
+ metadata.unstructured_json,
+ )
- def get_user_robot(self, robot_shortname, owning_user):
- robot, password, metadata = model.user.get_robot_and_metadata(robot_shortname, owning_user)
- return Robot(robot.username, password, robot.creation_date, robot.last_accessed,
- metadata.description, metadata.unstructured_json)
+ def get_user_robot(self, robot_shortname, owning_user):
+ robot, password, metadata = model.user.get_robot_and_metadata(
+ robot_shortname, owning_user
+ )
+ return Robot(
+ robot.username,
+ password,
+ robot.creation_date,
+ robot.last_accessed,
+ metadata.description,
+ metadata.unstructured_json,
+ )
pre_oci_model = RobotPreOCIModel()
diff --git a/endpoints/api/search.py b/endpoints/api/search.py
index 0ddbbc3fa..efb028e55 100644
--- a/endpoints/api/search.py
+++ b/endpoints/api/search.py
@@ -2,15 +2,30 @@
import features
-from endpoints.api import (ApiResource, parse_args, query_param, truthy_bool, nickname, resource,
- require_scope, path_param, internal_only, Unauthorized, InvalidRequest,
- show_if)
+from endpoints.api import (
+ ApiResource,
+ parse_args,
+ query_param,
+ truthy_bool,
+ nickname,
+ resource,
+ require_scope,
+ path_param,
+ internal_only,
+ Unauthorized,
+ InvalidRequest,
+ show_if,
+)
from data.database import Repository
from data import model
from data.registry_model import registry_model
-from auth.permissions import (OrganizationMemberPermission, ReadRepositoryPermission,
- UserAdminPermission, AdministerOrganizationPermission,
- ReadRepositoryPermission)
+from auth.permissions import (
+ OrganizationMemberPermission,
+ ReadRepositoryPermission,
+ UserAdminPermission,
+ AdministerOrganizationPermission,
+ ReadRepositoryPermission,
+)
from auth.auth_context import get_authenticated_user
from auth import scopes
from app import app, avatar, authentication
@@ -19,7 +34,7 @@ from operator import itemgetter
from stringscore import liquidmetal
from util.names import parse_robot_username
-import anunidecode # Don't listen to pylint's lies. This import is required.
+import anunidecode # Don't listen to pylint's lies. This import is required.
import math
@@ -28,355 +43,416 @@ TEAM_SEARCH_SCORE = 2
REPOSITORY_SEARCH_SCORE = 4
-@resource('/v1/entities/link/<username>')
+@resource("/v1/entities/link/<username>")
@internal_only
class LinkExternalEntity(ApiResource):
- """ Resource for linking external entities to internal users. """
- @nickname('linkExternalUser')
- def post(self, username):
- if not authentication.federated_service:
- abort(404)
+ """ Resource for linking external entities to internal users. """
- # Only allowed if there is a logged in user.
- if not get_authenticated_user():
- raise Unauthorized()
+ @nickname("linkExternalUser")
+ def post(self, username):
+ if not authentication.federated_service:
+ abort(404)
- # Try to link the user with the given *external* username, to an internal record.
- (user, err_msg) = authentication.link_user(username)
- if user is None:
- raise InvalidRequest(err_msg, payload={'username': username})
+ # Only allowed if there is a logged in user.
+ if not get_authenticated_user():
+ raise Unauthorized()
- return {
- 'entity': {
- 'name': user.username,
- 'kind': 'user',
- 'is_robot': False,
- 'avatar': avatar.get_data_for_user(user)
- }
- }
+ # Try to link the user with the given *external* username, to an internal record.
+ (user, err_msg) = authentication.link_user(username)
+ if user is None:
+ raise InvalidRequest(err_msg, payload={"username": username})
+
+ return {
+ "entity": {
+ "name": user.username,
+ "kind": "user",
+ "is_robot": False,
+ "avatar": avatar.get_data_for_user(user),
+ }
+ }
-@resource('/v1/entities/<prefix>')
+@resource("/v1/entities/<prefix>")
class EntitySearch(ApiResource):
- """ Resource for searching entities. """
- @path_param('prefix', 'The prefix of the entities being looked up')
- @parse_args()
- @query_param('namespace', 'Namespace to use when querying for org entities.', type=str,
- default='')
- @query_param('includeTeams', 'Whether to include team names.', type=truthy_bool, default=False)
- @query_param('includeOrgs', 'Whether to include orgs names.', type=truthy_bool, default=False)
- @nickname('getMatchingEntities')
- def get(self, prefix, parsed_args):
- """ Get a list of entities that match the specified prefix. """
+ """ Resource for searching entities. """
- # Ensure we don't have any unicode characters in the search, as it breaks the search. Nothing
- # being searched can have unicode in it anyway, so this is a safe operation.
- prefix = prefix.encode('unidecode', 'ignore').replace(' ', '').lower()
+ @path_param("prefix", "The prefix of the entities being looked up")
+ @parse_args()
+ @query_param(
+ "namespace",
+ "Namespace to use when querying for org entities.",
+ type=str,
+ default="",
+ )
+ @query_param(
+ "includeTeams",
+ "Whether to include team names.",
+ type=truthy_bool,
+ default=False,
+ )
+ @query_param(
+ "includeOrgs", "Whether to include orgs names.", type=truthy_bool, default=False
+ )
+ @nickname("getMatchingEntities")
+ def get(self, prefix, parsed_args):
+ """ Get a list of entities that match the specified prefix. """
- teams = []
- org_data = []
+ # Ensure we don't have any unicode characters in the search, as it breaks the search. Nothing
+ # being searched can have unicode in it anyway, so this is a safe operation.
+ prefix = prefix.encode("unidecode", "ignore").replace(" ", "").lower()
- namespace_name = parsed_args['namespace']
- robot_namespace = None
- organization = None
+ teams = []
+ org_data = []
- try:
- organization = model.organization.get_organization(namespace_name)
+ namespace_name = parsed_args["namespace"]
+ robot_namespace = None
+ organization = None
- # namespace name was an org
- permission = OrganizationMemberPermission(namespace_name)
- if permission.can():
- robot_namespace = namespace_name
+ try:
+ organization = model.organization.get_organization(namespace_name)
- if parsed_args['includeTeams']:
- teams = model.team.get_matching_teams(prefix, organization)
+ # namespace name was an org
+ permission = OrganizationMemberPermission(namespace_name)
+ if permission.can():
+ robot_namespace = namespace_name
- if (parsed_args['includeOrgs'] and AdministerOrganizationPermission(namespace_name) and
- namespace_name.startswith(prefix)):
- org_data = [{
- 'name': namespace_name,
- 'kind': 'org',
- 'is_org_member': True,
- 'avatar': avatar.get_data_for_org(organization),
- }]
+ if parsed_args["includeTeams"]:
+ teams = model.team.get_matching_teams(prefix, organization)
- except model.organization.InvalidOrganizationException:
- # namespace name was a user
- user = get_authenticated_user()
- if user and user.username == namespace_name:
- # Check if there is admin user permissions (login only)
- admin_permission = UserAdminPermission(user.username)
- if admin_permission.can():
- robot_namespace = namespace_name
+ if (
+ parsed_args["includeOrgs"]
+                and AdministerOrganizationPermission(namespace_name).can()
+ and namespace_name.startswith(prefix)
+ ):
+ org_data = [
+ {
+ "name": namespace_name,
+ "kind": "org",
+ "is_org_member": True,
+ "avatar": avatar.get_data_for_org(organization),
+ }
+ ]
- # Lookup users in the database for the prefix query.
- users = model.user.get_matching_users(prefix, robot_namespace, organization, limit=10,
- exact_matches_only=not features.PARTIAL_USER_AUTOCOMPLETE)
+ except model.organization.InvalidOrganizationException:
+ # namespace name was a user
+ user = get_authenticated_user()
+ if user and user.username == namespace_name:
+ # Check if there is admin user permissions (login only)
+ admin_permission = UserAdminPermission(user.username)
+ if admin_permission.can():
+ robot_namespace = namespace_name
- # Lookup users via the user system for the prefix query. We'll filter out any users that
- # already exist in the database.
- external_users, federated_id, _ = authentication.query_users(prefix, limit=10)
- filtered_external_users = []
- if external_users and federated_id is not None:
- users = list(users)
- user_ids = [user.id for user in users]
+ # Lookup users in the database for the prefix query.
+ users = model.user.get_matching_users(
+ prefix,
+ robot_namespace,
+ organization,
+ limit=10,
+ exact_matches_only=not features.PARTIAL_USER_AUTOCOMPLETE,
+ )
- # Filter the users if any are already found via the database. We do so by looking up all
- # the found users in the federated user system.
- federated_query = model.user.get_federated_logins(user_ids, federated_id)
- found = {result.service_ident for result in federated_query}
- filtered_external_users = [user for user in external_users if not user.username in found]
+ # Lookup users via the user system for the prefix query. We'll filter out any users that
+ # already exist in the database.
+ external_users, federated_id, _ = authentication.query_users(prefix, limit=10)
+ filtered_external_users = []
+ if external_users and federated_id is not None:
+ users = list(users)
+ user_ids = [user.id for user in users]
- def entity_team_view(team):
- result = {
- 'name': team.name,
- 'kind': 'team',
- 'is_org_member': True,
- 'avatar': avatar.get_data_for_team(team)
- }
- return result
+ # Filter the users if any are already found via the database. We do so by looking up all
+ # the found users in the federated user system.
+ federated_query = model.user.get_federated_logins(user_ids, federated_id)
+ found = {result.service_ident for result in federated_query}
+ filtered_external_users = [
+            user for user in external_users if user.username not in found
+ ]
- def user_view(user):
- user_json = {
- 'name': user.username,
- 'kind': 'user',
- 'is_robot': user.robot,
- 'avatar': avatar.get_data_for_user(user)
- }
+ def entity_team_view(team):
+ result = {
+ "name": team.name,
+ "kind": "team",
+ "is_org_member": True,
+ "avatar": avatar.get_data_for_team(team),
+ }
+ return result
- if organization is not None:
- user_json['is_org_member'] = user.robot or user.is_org_member
+ def user_view(user):
+ user_json = {
+ "name": user.username,
+ "kind": "user",
+ "is_robot": user.robot,
+ "avatar": avatar.get_data_for_user(user),
+ }
- return user_json
+ if organization is not None:
+ user_json["is_org_member"] = user.robot or user.is_org_member
- def external_view(user):
- result = {
- 'name': user.username,
- 'kind': 'external',
- 'title': user.email or '',
- 'avatar': avatar.get_data_for_external_user(user)
- }
- return result
+ return user_json
- team_data = [entity_team_view(team) for team in teams]
- user_data = [user_view(user) for user in users]
- external_data = [external_view(user) for user in filtered_external_users]
+ def external_view(user):
+ result = {
+ "name": user.username,
+ "kind": "external",
+ "title": user.email or "",
+ "avatar": avatar.get_data_for_external_user(user),
+ }
+ return result
- return {
- 'results': team_data + user_data + org_data + external_data
- }
+ team_data = [entity_team_view(team) for team in teams]
+ user_data = [user_view(user) for user in users]
+ external_data = [external_view(user) for user in filtered_external_users]
+
+ return {"results": team_data + user_data + org_data + external_data}
def search_entity_view(username, entity, get_short_name=None):
- kind = 'user'
- title = 'user'
- avatar_data = avatar.get_data_for_user(entity)
- href = '/user/' + entity.username
+ kind = "user"
+ title = "user"
+ avatar_data = avatar.get_data_for_user(entity)
+ href = "/user/" + entity.username
- if entity.organization:
- kind = 'organization'
- title = 'org'
- avatar_data = avatar.get_data_for_org(entity)
- href = '/organization/' + entity.username
- elif entity.robot:
- parts = parse_robot_username(entity.username)
- if parts[0] == username:
- href = '/user/' + username + '?tab=robots&showRobot=' + entity.username
- else:
- href = '/organization/' + parts[0] + '?tab=robots&showRobot=' + entity.username
+ if entity.organization:
+ kind = "organization"
+ title = "org"
+ avatar_data = avatar.get_data_for_org(entity)
+ href = "/organization/" + entity.username
+ elif entity.robot:
+ parts = parse_robot_username(entity.username)
+ if parts[0] == username:
+ href = "/user/" + username + "?tab=robots&showRobot=" + entity.username
+ else:
+ href = (
+ "/organization/" + parts[0] + "?tab=robots&showRobot=" + entity.username
+ )
- kind = 'robot'
- title = 'robot'
- avatar_data = None
+ kind = "robot"
+ title = "robot"
+ avatar_data = None
- data = {
- 'title': title,
- 'kind': kind,
- 'avatar': avatar_data,
- 'name': entity.username,
- 'score': ENTITY_SEARCH_SCORE,
- 'href': href
- }
+ data = {
+ "title": title,
+ "kind": kind,
+ "avatar": avatar_data,
+ "name": entity.username,
+ "score": ENTITY_SEARCH_SCORE,
+ "href": href,
+ }
- if get_short_name:
- data['short_name'] = get_short_name(entity.username)
+ if get_short_name:
+ data["short_name"] = get_short_name(entity.username)
- return data
+ return data
def conduct_team_search(username, query, encountered_teams, results):
- """ Finds the matching teams where the user is a member. """
- matching_teams = model.team.get_matching_user_teams(query, get_authenticated_user(), limit=5)
- for team in matching_teams:
- if team.id in encountered_teams:
- continue
+ """ Finds the matching teams where the user is a member. """
+ matching_teams = model.team.get_matching_user_teams(
+ query, get_authenticated_user(), limit=5
+ )
+ for team in matching_teams:
+ if team.id in encountered_teams:
+ continue
- encountered_teams.add(team.id)
+ encountered_teams.add(team.id)
- results.append({
- 'kind': 'team',
- 'name': team.name,
- 'organization': search_entity_view(username, team.organization),
- 'avatar': avatar.get_data_for_team(team),
- 'score': TEAM_SEARCH_SCORE,
- 'href': '/organization/' + team.organization.username + '/teams/' + team.name
- })
+ results.append(
+ {
+ "kind": "team",
+ "name": team.name,
+ "organization": search_entity_view(username, team.organization),
+ "avatar": avatar.get_data_for_team(team),
+ "score": TEAM_SEARCH_SCORE,
+ "href": "/organization/"
+ + team.organization.username
+ + "/teams/"
+ + team.name,
+ }
+ )
def conduct_admined_team_search(username, query, encountered_teams, results):
- """ Finds matching teams in orgs admined by the user. """
- matching_teams = model.team.get_matching_admined_teams(query, get_authenticated_user(), limit=5)
- for team in matching_teams:
- if team.id in encountered_teams:
- continue
+ """ Finds matching teams in orgs admined by the user. """
+ matching_teams = model.team.get_matching_admined_teams(
+ query, get_authenticated_user(), limit=5
+ )
+ for team in matching_teams:
+ if team.id in encountered_teams:
+ continue
- encountered_teams.add(team.id)
+ encountered_teams.add(team.id)
- results.append({
- 'kind': 'team',
- 'name': team.name,
- 'organization': search_entity_view(username, team.organization),
- 'avatar': avatar.get_data_for_team(team),
- 'score': TEAM_SEARCH_SCORE,
- 'href': '/organization/' + team.organization.username + '/teams/' + team.name
- })
+ results.append(
+ {
+ "kind": "team",
+ "name": team.name,
+ "organization": search_entity_view(username, team.organization),
+ "avatar": avatar.get_data_for_team(team),
+ "score": TEAM_SEARCH_SCORE,
+ "href": "/organization/"
+ + team.organization.username
+ + "/teams/"
+ + team.name,
+ }
+ )
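
conduct_team_search and conduct_admined_team_search share the encountered_teams set, so a team matched by both queries is reported only once. A small sketch of that pattern, with simplified item and view shapes that are not the real Team objects:

def collect_unique(results, seen_ids, items, view):
    # Shared-set deduplication: ids already seen by an earlier search are skipped.
    for item in items:
        if item["id"] in seen_ids:
            continue
        seen_ids.add(item["id"])
        results.append(view(item))

results, seen = [], set()
collect_unique(results, seen, [{"id": 1, "name": "devs"}], lambda t: t["name"])
collect_unique(results, seen, [{"id": 1, "name": "devs"}, {"id": 2, "name": "ops"}], lambda t: t["name"])
print(results)  # ['devs', 'ops']
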
def conduct_repo_search(username, query, results, offset=0, limit=5):
- """ Finds matching repositories. """
- matching_repos = model.repository.get_filtered_matching_repositories(query, username, limit=limit,
- repo_kind=None,
- offset=offset)
+ """ Finds matching repositories. """
+ matching_repos = model.repository.get_filtered_matching_repositories(
+ query, username, limit=limit, repo_kind=None, offset=offset
+ )
- for repo in matching_repos:
- # TODO: make sure the repo.kind.name doesn't cause extra queries
- results.append(repo_result_view(repo, username))
+ for repo in matching_repos:
+ # TODO: make sure the repo.kind.name doesn't cause extra queries
+ results.append(repo_result_view(repo, username))
def conduct_namespace_search(username, query, results):
- """ Finds matching users and organizations. """
- matching_entities = model.user.get_matching_user_namespaces(query, username, limit=5)
- for entity in matching_entities:
- results.append(search_entity_view(username, entity))
+ """ Finds matching users and organizations. """
+ matching_entities = model.user.get_matching_user_namespaces(
+ query, username, limit=5
+ )
+ for entity in matching_entities:
+ results.append(search_entity_view(username, entity))
def conduct_robot_search(username, query, results):
- """ Finds matching robot accounts. """
- def get_short_name(name):
- return parse_robot_username(name)[1]
+ """ Finds matching robot accounts. """
- matching_robots = model.user.get_matching_robots(query, username, limit=5)
- for robot in matching_robots:
- results.append(search_entity_view(username, robot, get_short_name))
+ def get_short_name(name):
+ return parse_robot_username(name)[1]
+
+ matching_robots = model.user.get_matching_robots(query, username, limit=5)
+ for robot in matching_robots:
+ results.append(search_entity_view(username, robot, get_short_name))
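
get_short_name leans on parse_robot_username from util.names, which is not part of this diff. A hedged sketch of the assumed behaviour, treating robot names as namespace+shortname pairs:

def parse_robot_username_sketch(robot_username):
    # Assumption: robot usernames look like "<namespace>+<shortname>"; the real
    # helper in util.names may validate the pieces more strictly.
    namespace, _, shortname = robot_username.partition("+")
    return namespace, shortname

print(parse_robot_username_sketch("acme+deploy"))  # ('acme', 'deploy')
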
def repo_result_view(repo, username, last_modified=None, stars=None, popularity=None):
- kind = 'application' if Repository.kind.get_name(repo.kind_id) == 'application' else 'repository'
- view = {
- 'kind': kind,
- 'title': 'app' if kind == 'application' else 'repo',
- 'namespace': search_entity_view(username, repo.namespace_user),
- 'name': repo.name,
- 'description': repo.description,
- 'is_public': model.repository.is_repository_public(repo),
- 'score': REPOSITORY_SEARCH_SCORE,
- 'href': '/' + kind + '/' + repo.namespace_user.username + '/' + repo.name
- }
-
- if last_modified is not None:
- view['last_modified'] = last_modified
-
- if stars is not None:
- view['stars'] = stars
-
- if popularity is not None:
- view['popularity'] = popularity
-
- return view
-
-@resource('/v1/find/all')
-class ConductSearch(ApiResource):
- """ Resource for finding users, repositories, teams, etc. """
- @parse_args()
- @query_param('query', 'The search query.', type=str, default='')
- @require_scope(scopes.READ_REPO)
- @nickname('conductSearch')
- def get(self, parsed_args):
- """ Get a list of entities and resources that match the specified query. """
- query = parsed_args['query']
- if not query:
- return {'results': []}
-
- username = None
- results = []
-
- if get_authenticated_user():
- username = get_authenticated_user().username
-
- # Search for teams.
- encountered_teams = set()
- conduct_team_search(username, query, encountered_teams, results)
- conduct_admined_team_search(username, query, encountered_teams, results)
-
- # Search for robot accounts.
- conduct_robot_search(username, query, results)
-
- # Search for repos.
- conduct_repo_search(username, query, results)
-
- # Search for users and orgs.
- conduct_namespace_search(username, query, results)
-
- # Modify the results' scores via how close the query term is to each result's name.
- for result in results:
- name = result.get('short_name', result['name'])
- lm_score = liquidmetal.score(name, query) or 0.5
- result['score'] = result['score'] * lm_score
-
- return {'results': sorted(results, key=itemgetter('score'), reverse=True)}
-
-
-MAX_PER_PAGE = app.config.get('SEARCH_RESULTS_PER_PAGE', 10)
-MAX_RESULT_PAGE_COUNT = app.config.get('SEARCH_MAX_RESULT_PAGE_COUNT', 10)
-
-@resource('/v1/find/repositories')
-class ConductRepositorySearch(ApiResource):
- """ Resource for finding repositories. """
- @parse_args()
- @query_param('query', 'The search query.', type=str, default='')
- @query_param('page', 'The page.', type=int, default=1)
- @nickname('conductRepoSearch')
- def get(self, parsed_args):
- """ Get a list of apps and repositories that match the specified query. """
- query = parsed_args['query']
- page = min(max(1, parsed_args['page']), MAX_RESULT_PAGE_COUNT)
- offset = (page - 1) * MAX_PER_PAGE
- limit = offset + MAX_PER_PAGE + 1
-
- username = get_authenticated_user().username if get_authenticated_user() else None
-
- # Lookup matching repositories.
- matching_repos = list(model.repository.get_filtered_matching_repositories(query, username,
- repo_kind=None,
- limit=limit,
- offset=offset))
-
- # Load secondary information such as last modified time, star count and action count.
- repository_ids = [repo.id for repo in matching_repos]
- last_modified_map = registry_model.get_most_recent_tag_lifetime_start(matching_repos)
- star_map = model.repository.get_stars(repository_ids)
- action_sum_map = model.log.get_repositories_action_sums(repository_ids)
-
- # Build the results list.
- results = [repo_result_view(repo, username, last_modified_map.get(repo.id),
- star_map.get(repo.id, 0),
- float(action_sum_map.get(repo.id, 0)))
- for repo in matching_repos]
-
- return {
- 'results': results[0:MAX_PER_PAGE],
- 'has_additional': len(results) > MAX_PER_PAGE,
- 'page': page,
- 'page_size': MAX_PER_PAGE,
- 'start_index': offset,
+ kind = (
+ "application"
+ if Repository.kind.get_name(repo.kind_id) == "application"
+ else "repository"
+ )
+ view = {
+ "kind": kind,
+ "title": "app" if kind == "application" else "repo",
+ "namespace": search_entity_view(username, repo.namespace_user),
+ "name": repo.name,
+ "description": repo.description,
+ "is_public": model.repository.is_repository_public(repo),
+ "score": REPOSITORY_SEARCH_SCORE,
+ "href": "/" + kind + "/" + repo.namespace_user.username + "/" + repo.name,
}
+
+ if last_modified is not None:
+ view["last_modified"] = last_modified
+
+ if stars is not None:
+ view["stars"] = stars
+
+ if popularity is not None:
+ view["popularity"] = popularity
+
+ return view
+
+
+@resource("/v1/find/all")
+class ConductSearch(ApiResource):
+ """ Resource for finding users, repositories, teams, etc. """
+
+ @parse_args()
+ @query_param("query", "The search query.", type=str, default="")
+ @require_scope(scopes.READ_REPO)
+ @nickname("conductSearch")
+ def get(self, parsed_args):
+ """ Get a list of entities and resources that match the specified query. """
+ query = parsed_args["query"]
+ if not query:
+ return {"results": []}
+
+ username = None
+ results = []
+
+ if get_authenticated_user():
+ username = get_authenticated_user().username
+
+ # Search for teams.
+ encountered_teams = set()
+ conduct_team_search(username, query, encountered_teams, results)
+ conduct_admined_team_search(username, query, encountered_teams, results)
+
+ # Search for robot accounts.
+ conduct_robot_search(username, query, results)
+
+ # Search for repos.
+ conduct_repo_search(username, query, results)
+
+ # Search for users and orgs.
+ conduct_namespace_search(username, query, results)
+
+ # Modify the results' scores via how close the query term is to each result's name.
+ for result in results:
+ name = result.get("short_name", result["name"])
+ lm_score = liquidmetal.score(name, query) or 0.5
+ result["score"] = result["score"] * lm_score
+
+ return {"results": sorted(results, key=itemgetter("score"), reverse=True)}
+
+
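
The final loop in ConductSearch.get multiplies each base score by a liquidmetal string-similarity score, falling back to 0.5 when the scorer gives nothing. A self-contained sketch of that re-ranking step, with a stub similarity function standing in for liquidmetal.score:

def similarity(name, query):
    # Stub scorer returning a value in [0, 1]; the endpoint uses liquidmetal.score.
    name, query = name.lower(), query.lower()
    return 1.0 if name.startswith(query) else (0.6 if query in name else 0.0)

def rescore(results, query):
    # Weight each result's base score by how well its name matches the query,
    # then sort best-first, mirroring the loop and sorted() call above.
    for result in results:
        name = result.get("short_name", result["name"])
        result["score"] = result["score"] * (similarity(name, query) or 0.5)
    return sorted(results, key=lambda r: r["score"], reverse=True)

hits = [{"name": "quay", "score": 4}, {"name": "backup", "score": 4}]
print(rescore(hits, "qu")[0]["name"])  # quay
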
+MAX_PER_PAGE = app.config.get("SEARCH_RESULTS_PER_PAGE", 10)
+MAX_RESULT_PAGE_COUNT = app.config.get("SEARCH_MAX_RESULT_PAGE_COUNT", 10)
+
+
+@resource("/v1/find/repositories")
+class ConductRepositorySearch(ApiResource):
+ """ Resource for finding repositories. """
+
+ @parse_args()
+ @query_param("query", "The search query.", type=str, default="")
+ @query_param("page", "The page.", type=int, default=1)
+ @nickname("conductRepoSearch")
+ def get(self, parsed_args):
+ """ Get a list of apps and repositories that match the specified query. """
+ query = parsed_args["query"]
+ page = min(max(1, parsed_args["page"]), MAX_RESULT_PAGE_COUNT)
+ offset = (page - 1) * MAX_PER_PAGE
+ limit = offset + MAX_PER_PAGE + 1
+
+ username = (
+ get_authenticated_user().username if get_authenticated_user() else None
+ )
+
+ # Lookup matching repositories.
+ matching_repos = list(
+ model.repository.get_filtered_matching_repositories(
+ query, username, repo_kind=None, limit=limit, offset=offset
+ )
+ )
+
+ # Load secondary information such as last modified time, star count and action count.
+ repository_ids = [repo.id for repo in matching_repos]
+ last_modified_map = registry_model.get_most_recent_tag_lifetime_start(
+ matching_repos
+ )
+ star_map = model.repository.get_stars(repository_ids)
+ action_sum_map = model.log.get_repositories_action_sums(repository_ids)
+
+ # Build the results list.
+ results = [
+ repo_result_view(
+ repo,
+ username,
+ last_modified_map.get(repo.id),
+ star_map.get(repo.id, 0),
+ float(action_sum_map.get(repo.id, 0)),
+ )
+ for repo in matching_repos
+ ]
+
+ return {
+ "results": results[0:MAX_PER_PAGE],
+ "has_additional": len(results) > MAX_PER_PAGE,
+ "page": page,
+ "page_size": MAX_PER_PAGE,
+ "start_index": offset,
+ }
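
ConductRepositorySearch pages through results by clamping the requested page, deriving an offset, and asking the query for more rows than fit on one page so has_additional can be computed without a second COUNT query. The arithmetic in isolation, using the default config values as assumptions:

MAX_PER_PAGE = 10           # SEARCH_RESULTS_PER_PAGE default
MAX_RESULT_PAGE_COUNT = 10  # SEARCH_MAX_RESULT_PAGE_COUNT default

def page_window(page):
    # Clamp the page into [1, MAX_RESULT_PAGE_COUNT], then compute the offset of
    # the first row and a limit that reaches at least one row past the page.
    page = min(max(1, page), MAX_RESULT_PAGE_COUNT)
    offset = (page - 1) * MAX_PER_PAGE
    limit = offset + MAX_PER_PAGE + 1
    return page, offset, limit

print(page_window(3))    # (3, 20, 31)
print(page_window(-5))   # (1, 0, 11)
print(page_window(999))  # (10, 90, 101)
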
diff --git a/endpoints/api/secscan.py b/endpoints/api/secscan.py
index 71422184f..a17aeb2d0 100644
--- a/endpoints/api/secscan.py
+++ b/endpoints/api/secscan.py
@@ -7,9 +7,18 @@ from app import app, secscan_api
from auth.decorators import process_basic_auth_no_pass
from data.registry_model import registry_model
from data.registry_model.datatypes import SecurityScanStatus
-from endpoints.api import (require_repo_read, path_param,
- RepositoryParamResource, resource, nickname, show_if, parse_args,
- query_param, truthy_bool, disallow_for_app_repositories)
+from endpoints.api import (
+ require_repo_read,
+ path_param,
+ RepositoryParamResource,
+ resource,
+ nickname,
+ show_if,
+ parse_args,
+ query_param,
+ truthy_bool,
+ disallow_for_app_repositories,
+)
from endpoints.exception import NotFound, DownstreamIssue
from endpoints.api.manifest import MANIFEST_DIGEST_ROUTE
from util.secscan.api import APIRequestFailure
@@ -17,92 +26,100 @@ from util.secscan.api import APIRequestFailure
logger = logging.getLogger(__name__)
+
def _security_info(manifest_or_legacy_image, include_vulnerabilities=True):
- """ Returns a dict representing the result of a call to the security status API for the given
+ """ Returns a dict representing the result of a call to the security status API for the given
manifest or image.
"""
- status = registry_model.get_security_status(manifest_or_legacy_image)
- if status is None:
- raise NotFound()
+ status = registry_model.get_security_status(manifest_or_legacy_image)
+ if status is None:
+ raise NotFound()
- if status != SecurityScanStatus.SCANNED:
- return {
- 'status': status.value,
- }
+ if status != SecurityScanStatus.SCANNED:
+ return {"status": status.value}
- try:
- if include_vulnerabilities:
- data = secscan_api.get_layer_data(manifest_or_legacy_image, include_vulnerabilities=True)
- else:
- data = secscan_api.get_layer_data(manifest_or_legacy_image, include_features=True)
- except APIRequestFailure as arf:
- raise DownstreamIssue(arf.message)
+ try:
+ if include_vulnerabilities:
+ data = secscan_api.get_layer_data(
+ manifest_or_legacy_image, include_vulnerabilities=True
+ )
+ else:
+ data = secscan_api.get_layer_data(
+ manifest_or_legacy_image, include_features=True
+ )
+ except APIRequestFailure as arf:
+ raise DownstreamIssue(arf.message)
- if data is None:
- # If no data was found but we reached this point, then it indicates we have incorrect security
- # status for the manifest or legacy image. Mark the manifest or legacy image as unindexed
- # so it automatically gets re-indexed.
- if app.config.get('REGISTRY_STATE', 'normal') == 'normal':
- registry_model.reset_security_status(manifest_or_legacy_image)
+ if data is None:
+ # If no data was found but we reached this point, then it indicates we have incorrect security
+ # status for the manifest or legacy image. Mark the manifest or legacy image as unindexed
+ # so it automatically gets re-indexed.
+ if app.config.get("REGISTRY_STATE", "normal") == "normal":
+ registry_model.reset_security_status(manifest_or_legacy_image)
- return {
- 'status': SecurityScanStatus.QUEUED.value,
- }
+ return {"status": SecurityScanStatus.QUEUED.value}
- return {
- 'status': status.value,
- 'data': data,
- }
+ return {"status": status.value, "data": data}
-@resource('/v1/repository/<apirepopath:repository>/image/<imageid>/security')
+@resource("/v1/repository/<apirepopath:repository>/image/<imageid>/security")
@show_if(features.SECURITY_SCANNER)
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
-@path_param('imageid', 'The image ID')
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
+@path_param("imageid", "The image ID")
class RepositoryImageSecurity(RepositoryParamResource):
- """ Operations for managing the vulnerabilities in a repository image. """
+ """ Operations for managing the vulnerabilities in a repository image. """
- @process_basic_auth_no_pass
- @require_repo_read
- @nickname('getRepoImageSecurity')
- @disallow_for_app_repositories
- @parse_args()
- @query_param('vulnerabilities', 'Include vulnerabilities informations', type=truthy_bool,
- default=False)
- def get(self, namespace, repository, imageid, parsed_args):
- """ Fetches the features and vulnerabilities (if any) for a repository image. """
- repo_ref = registry_model.lookup_repository(namespace, repository)
- if repo_ref is None:
- raise NotFound()
+ @process_basic_auth_no_pass
+ @require_repo_read
+ @nickname("getRepoImageSecurity")
+ @disallow_for_app_repositories
+ @parse_args()
+ @query_param(
+ "vulnerabilities",
+ "Include vulnerabilities informations",
+ type=truthy_bool,
+ default=False,
+ )
+ def get(self, namespace, repository, imageid, parsed_args):
+ """ Fetches the features and vulnerabilities (if any) for a repository image. """
+ repo_ref = registry_model.lookup_repository(namespace, repository)
+ if repo_ref is None:
+ raise NotFound()
- legacy_image = registry_model.get_legacy_image(repo_ref, imageid)
- if legacy_image is None:
- raise NotFound()
+ legacy_image = registry_model.get_legacy_image(repo_ref, imageid)
+ if legacy_image is None:
+ raise NotFound()
- return _security_info(legacy_image, parsed_args.vulnerabilities)
+ return _security_info(legacy_image, parsed_args.vulnerabilities)
-@resource(MANIFEST_DIGEST_ROUTE + '/security')
+@resource(MANIFEST_DIGEST_ROUTE + "/security")
@show_if(features.SECURITY_SCANNER)
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
-@path_param('manifestref', 'The digest of the manifest')
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
+@path_param("manifestref", "The digest of the manifest")
class RepositoryManifestSecurity(RepositoryParamResource):
- """ Operations for managing the vulnerabilities in a repository manifest. """
+ """ Operations for managing the vulnerabilities in a repository manifest. """
- @process_basic_auth_no_pass
- @require_repo_read
- @nickname('getRepoManifestSecurity')
- @disallow_for_app_repositories
- @parse_args()
- @query_param('vulnerabilities', 'Include vulnerabilities informations', type=truthy_bool,
- default=False)
- def get(self, namespace, repository, manifestref, parsed_args):
- repo_ref = registry_model.lookup_repository(namespace, repository)
- if repo_ref is None:
- raise NotFound()
+ @process_basic_auth_no_pass
+ @require_repo_read
+ @nickname("getRepoManifestSecurity")
+ @disallow_for_app_repositories
+ @parse_args()
+ @query_param(
+ "vulnerabilities",
+ "Include vulnerabilities informations",
+ type=truthy_bool,
+ default=False,
+ )
+ def get(self, namespace, repository, manifestref, parsed_args):
+ repo_ref = registry_model.lookup_repository(namespace, repository)
+ if repo_ref is None:
+ raise NotFound()
- manifest = registry_model.lookup_manifest_by_digest(repo_ref, manifestref, allow_dead=True)
- if manifest is None:
- raise NotFound()
+ manifest = registry_model.lookup_manifest_by_digest(
+ repo_ref, manifestref, allow_dead=True
+ )
+ if manifest is None:
+ raise NotFound()
- return _security_info(manifest, parsed_args.vulnerabilities)
+ return _security_info(manifest, parsed_args.vulnerabilities)
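
_security_info can produce three payload shapes: a bare status while no scan data is available, a re-queued status when the stored data has gone missing, and a full status-plus-data document once a scan exists. A compact sketch of that branching, with stand-in status values (the real SecurityScanStatus enum may differ):

from enum import Enum

class ScanStatus(Enum):
    # Stand-ins for the SecurityScanStatus values referenced above.
    QUEUED = "queued"
    SCANNED = "scanned"

def security_payload(status, data=None):
    if status is not ScanStatus.SCANNED:
        return {"status": status.value}
    if data is None:
        # Mirrors the re-queue branch: report QUEUED so the caller retries later.
        return {"status": ScanStatus.QUEUED.value}
    return {"status": status.value, "data": data}

print(security_payload(ScanStatus.QUEUED))                     # {'status': 'queued'}
print(security_payload(ScanStatus.SCANNED, {"Features": []}))  # status plus data
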
diff --git a/endpoints/api/signing.py b/endpoints/api/signing.py
index eb2e942ec..a14c91523 100644
--- a/endpoints/api/signing.py
+++ b/endpoints/api/signing.py
@@ -4,26 +4,37 @@ import logging
import features
from app import tuf_metadata_api
-from endpoints.api import (require_repo_read, path_param,
- RepositoryParamResource, resource, nickname, show_if,
- disallow_for_app_repositories, NotFound)
+from endpoints.api import (
+ require_repo_read,
+ path_param,
+ RepositoryParamResource,
+ resource,
+ nickname,
+ show_if,
+ disallow_for_app_repositories,
+ NotFound,
+)
from endpoints.api.signing_models_pre_oci import pre_oci_model as model
logger = logging.getLogger(__name__)
-@resource('/v1/repository/<apirepopath:repository>/signatures')
+@resource("/v1/repository/<apirepopath:repository>/signatures")
@show_if(features.SIGNING)
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
class RepositorySignatures(RepositoryParamResource):
- """ Operations for managing the signatures in a repository image. """
+ """ Operations for managing the signatures in a repository image. """
- @require_repo_read
- @nickname('getRepoSignatures')
- @disallow_for_app_repositories
- def get(self, namespace, repository):
- """ Fetches the list of signed tags for the repository. """
- if not model.is_trust_enabled(namespace, repository):
- raise NotFound()
+ @require_repo_read
+ @nickname("getRepoSignatures")
+ @disallow_for_app_repositories
+ def get(self, namespace, repository):
+ """ Fetches the list of signed tags for the repository. """
+ if not model.is_trust_enabled(namespace, repository):
+ raise NotFound()
- return {'delegations': tuf_metadata_api.get_all_tags_with_expiration(namespace, repository)}
+ return {
+ "delegations": tuf_metadata_api.get_all_tags_with_expiration(
+ namespace, repository
+ )
+ }
diff --git a/endpoints/api/signing_models_interface.py b/endpoints/api/signing_models_interface.py
index 6e5ce4ca4..ff8b9e369 100644
--- a/endpoints/api/signing_models_interface.py
+++ b/endpoints/api/signing_models_interface.py
@@ -1,14 +1,16 @@
from abc import ABCMeta, abstractmethod
from six import add_metaclass
+
@add_metaclass(ABCMeta)
class SigningInterface(object):
- """
+ """
Interface that represents all data store interactions required by the signing API endpoint.
"""
- @abstractmethod
- def is_trust_enabled(self, namespace_name, repo_name):
- """
+
+ @abstractmethod
+ def is_trust_enabled(self, namespace_name, repo_name):
+ """
Returns whether the repository with the given namespace name and repository name exists and
has trust enabled.
"""
diff --git a/endpoints/api/signing_models_pre_oci.py b/endpoints/api/signing_models_pre_oci.py
index 03afb1104..4093d76de 100644
--- a/endpoints/api/signing_models_pre_oci.py
+++ b/endpoints/api/signing_models_pre_oci.py
@@ -3,16 +3,17 @@ from endpoints.api.signing_models_interface import SigningInterface
class PreOCIModel(SigningInterface):
- """
+ """
PreOCIModel implements the data model for signing using a database schema
before it was changed to support the OCI specification.
"""
- def is_trust_enabled(self, namespace_name, repo_name):
- repo = model.repository.get_repository(namespace_name, repo_name)
- if repo is None:
- return False
- return repo.trust_enabled
+ def is_trust_enabled(self, namespace_name, repo_name):
+ repo = model.repository.get_repository(namespace_name, repo_name)
+ if repo is None:
+ return False
+
+ return repo.trust_enabled
pre_oci_model = PreOCIModel()
diff --git a/endpoints/api/subscribe.py b/endpoints/api/subscribe.py
index b526e25d2..5ae8cbbef 100644
--- a/endpoints/api/subscribe.py
+++ b/endpoints/api/subscribe.py
@@ -13,123 +13,129 @@ logger = logging.getLogger(__name__)
def check_repository_usage(user_or_org, plan_found):
- private_repos = model.get_private_repo_count(user_or_org.username)
- if plan_found is None:
- repos_allowed = 0
- else:
- repos_allowed = plan_found['privateRepos']
+ private_repos = model.get_private_repo_count(user_or_org.username)
+ if plan_found is None:
+ repos_allowed = 0
+ else:
+ repos_allowed = plan_found["privateRepos"]
- if private_repos > repos_allowed:
- model.create_unique_notification('over_private_usage', user_or_org.username, {'namespace': user_or_org.username})
- else:
- model.delete_notifications_by_kind(user_or_org.username, 'over_private_usage')
+ if private_repos > repos_allowed:
+ model.create_unique_notification(
+ "over_private_usage",
+ user_or_org.username,
+ {"namespace": user_or_org.username},
+ )
+ else:
+ model.delete_notifications_by_kind(user_or_org.username, "over_private_usage")
def carderror_response(exc):
- return {'carderror': exc.message}, 402
+ return {"carderror": exc.message}, 402
+
def connection_response(exc):
- return {'message': 'Could not contact Stripe. Please try again.'}, 503
+ return {"message": "Could not contact Stripe. Please try again."}, 503
def subscription_view(stripe_subscription, used_repos):
- view = {
- 'hasSubscription': True,
- 'isExistingCustomer': True,
- 'currentPeriodStart': stripe_subscription.current_period_start,
- 'currentPeriodEnd': stripe_subscription.current_period_end,
- 'plan': stripe_subscription.plan.id,
- 'usedPrivateRepos': used_repos,
- 'trialStart': stripe_subscription.trial_start,
- 'trialEnd': stripe_subscription.trial_end
- }
+ view = {
+ "hasSubscription": True,
+ "isExistingCustomer": True,
+ "currentPeriodStart": stripe_subscription.current_period_start,
+ "currentPeriodEnd": stripe_subscription.current_period_end,
+ "plan": stripe_subscription.plan.id,
+ "usedPrivateRepos": used_repos,
+ "trialStart": stripe_subscription.trial_start,
+ "trialEnd": stripe_subscription.trial_end,
+ }
- return view
+ return view
def subscribe(user, plan, token, require_business_plan):
- if not features.BILLING:
- return
+ if not features.BILLING:
+ return
- plan_found = None
- for plan_obj in PLANS:
- if plan_obj['stripeId'] == plan:
- plan_found = plan_obj
+ plan_found = None
+ for plan_obj in PLANS:
+ if plan_obj["stripeId"] == plan:
+ plan_found = plan_obj
- if not plan_found or plan_found['deprecated']:
- logger.warning('Plan not found or deprecated: %s', plan)
- raise NotFound()
+ if not plan_found or plan_found["deprecated"]:
+ logger.warning("Plan not found or deprecated: %s", plan)
+ raise NotFound()
- if (require_business_plan and not plan_found['bus_features'] and not
- plan_found['price'] == 0):
- logger.warning('Business attempting to subscribe to personal plan: %s',
- user.username)
- raise request_error(message='No matching plan found')
+ if (
+ require_business_plan
+ and not plan_found["bus_features"]
+ and not plan_found["price"] == 0
+ ):
+ logger.warning(
+ "Business attempting to subscribe to personal plan: %s", user.username
+ )
+ raise request_error(message="No matching plan found")
- private_repos = model.get_private_repo_count(user.username)
+ private_repos = model.get_private_repo_count(user.username)
- # This is the default response
- response_json = {
- 'plan': plan,
- 'usedPrivateRepos': private_repos,
- }
- status_code = 200
+ # This is the default response
+ response_json = {"plan": plan, "usedPrivateRepos": private_repos}
+ status_code = 200
- if not user.stripe_id:
- # Check if a non-paying user is trying to subscribe to a free plan
- if not plan_found['price'] == 0:
- # They want a real paying plan, create the customer and plan
- # simultaneously
- card = token
+ if not user.stripe_id:
+ # Check if a non-paying user is trying to subscribe to a free plan
+ if not plan_found["price"] == 0:
+ # They want a real paying plan, create the customer and plan
+ # simultaneously
+ card = token
- try:
- cus = billing.Customer.create(email=user.email, plan=plan, card=card)
- user.stripe_id = cus.id
- user.save()
- check_repository_usage(user, plan_found)
- log_action('account_change_plan', user.username, {'plan': plan})
- except stripe.error.CardError as e:
- return carderror_response(e)
- except stripe.error.APIConnectionError as e:
- return connection_response(e)
+ try:
+ cus = billing.Customer.create(email=user.email, plan=plan, card=card)
+ user.stripe_id = cus.id
+ user.save()
+ check_repository_usage(user, plan_found)
+ log_action("account_change_plan", user.username, {"plan": plan})
+ except stripe.error.CardError as e:
+ return carderror_response(e)
+ except stripe.error.APIConnectionError as e:
+ return connection_response(e)
- response_json = subscription_view(cus.subscription, private_repos)
- status_code = 201
-
- else:
- # Change the plan
- try:
- cus = billing.Customer.retrieve(user.stripe_id)
- except stripe.error.APIConnectionError as e:
- return connection_response(e)
-
- if plan_found['price'] == 0:
- if cus.subscription is not None:
- # We only have to cancel the subscription if they actually have one
- try:
- cus.subscription.delete()
- except stripe.error.APIConnectionError as e:
- return connection_response(e)
-
- check_repository_usage(user, plan_found)
- log_action('account_change_plan', user.username, {'plan': plan})
+ response_json = subscription_view(cus.subscription, private_repos)
+ status_code = 201
else:
- # User may have been a previous customer who is resubscribing
- if token:
- cus.card = token
+ # Change the plan
+ try:
+ cus = billing.Customer.retrieve(user.stripe_id)
+ except stripe.error.APIConnectionError as e:
+ return connection_response(e)
- cus.plan = plan
+ if plan_found["price"] == 0:
+ if cus.subscription is not None:
+ # We only have to cancel the subscription if they actually have one
+ try:
+ cus.subscription.delete()
+ except stripe.error.APIConnectionError as e:
+ return connection_response(e)
- try:
- cus.save()
- except stripe.error.CardError as e:
- return carderror_response(e)
- except stripe.error.APIConnectionError as e:
- return connection_response(e)
+ check_repository_usage(user, plan_found)
+ log_action("account_change_plan", user.username, {"plan": plan})
- response_json = subscription_view(cus.subscription, private_repos)
- check_repository_usage(user, plan_found)
- log_action('account_change_plan', user.username, {'plan': plan})
+ else:
+ # User may have been a previous customer who is resubscribing
+ if token:
+ cus.card = token
- return response_json, status_code
+ cus.plan = plan
+
+ try:
+ cus.save()
+ except stripe.error.CardError as e:
+ return carderror_response(e)
+ except stripe.error.APIConnectionError as e:
+ return connection_response(e)
+
+ response_json = subscription_view(cus.subscription, private_repos)
+ check_repository_usage(user, plan_found)
+ log_action("account_change_plan", user.username, {"plan": plan})
+
+ return response_json, status_code
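
check_repository_usage boils down to a single comparison: the namespace's private-repository count against the plan's allowance (zero when no plan is found), with the result deciding whether the over_private_usage notification is created or cleared. The rule in isolation:

def over_private_usage(private_repos, plan):
    # No plan means no private repositories are allowed; otherwise use the
    # plan's privateRepos cap, exactly as check_repository_usage does above.
    allowed = 0 if plan is None else plan["privateRepos"]
    return private_repos > allowed

print(over_private_usage(3, None))                  # True  -> create notification
print(over_private_usage(3, {"privateRepos": 10}))  # False -> clear notification
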
diff --git a/endpoints/api/subscribe_models_interface.py b/endpoints/api/subscribe_models_interface.py
index fbc7a8a70..e1668602f 100644
--- a/endpoints/api/subscribe_models_interface.py
+++ b/endpoints/api/subscribe_models_interface.py
@@ -4,23 +4,24 @@ from six import add_metaclass
@add_metaclass(ABCMeta)
class SubscribeInterface(object):
- """
+ """
Interface that represents all data store interactions required by the subscribe API endpoint.
"""
- @abstractmethod
- def get_private_repo_count(self, username):
- """
+
+ @abstractmethod
+ def get_private_repo_count(self, username):
+ """
Returns the number of private repositories for a given username or namespace.
"""
- @abstractmethod
- def create_unique_notification(self, kind_name, target_username, metadata={}):
- """
+ @abstractmethod
+ def create_unique_notification(self, kind_name, target_username, metadata={}):
+ """
Creates a notification using the given parameters.
"""
- @abstractmethod
- def delete_notifications_by_kind(self, target_username, kind_name):
- """
+ @abstractmethod
+ def delete_notifications_by_kind(self, target_username, kind_name):
+ """
Remove notifications for a target based on given kind.
"""
diff --git a/endpoints/api/subscribe_models_pre_oci.py b/endpoints/api/subscribe_models_pre_oci.py
index a5ca83149..8c226494e 100644
--- a/endpoints/api/subscribe_models_pre_oci.py
+++ b/endpoints/api/subscribe_models_pre_oci.py
@@ -1,23 +1,27 @@
-from data.model.notification import create_unique_notification, delete_notifications_by_kind
+from data.model.notification import (
+ create_unique_notification,
+ delete_notifications_by_kind,
+)
from data.model.user import get_private_repo_count, get_user_or_org
from endpoints.api.subscribe_models_interface import SubscribeInterface
class PreOCIModel(SubscribeInterface):
- """
+ """
PreOCIModel implements the data model for build triggers using a database schema
before it was changed to support the OCI specification.
"""
- def get_private_repo_count(self, username):
- return get_private_repo_count(username)
- def create_unique_notification(self, kind_name, target_username, metadata={}):
- target = get_user_or_org(target_username)
- create_unique_notification(kind_name, target, metadata)
+ def get_private_repo_count(self, username):
+ return get_private_repo_count(username)
- def delete_notifications_by_kind(self, target_username, kind_name):
- target = get_user_or_org(target_username)
- delete_notifications_by_kind(target, kind_name)
+ def create_unique_notification(self, kind_name, target_username, metadata={}):
+ target = get_user_or_org(target_username)
+ create_unique_notification(kind_name, target, metadata)
+
+ def delete_notifications_by_kind(self, target_username, kind_name):
+ target = get_user_or_org(target_username)
+ delete_notifications_by_kind(target, kind_name)
data_model = PreOCIModel()
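
Both the interface and this implementation declare metadata={} as a default, which in Python is a single dict shared by every call that omits the argument. It is harmless as long as the dict is never mutated, but for reference, a hedged sketch of the conventional safer signature (not what the file does):

def create_unique_notification(kind_name, target_username, metadata=None):
    # A fresh dict per call instead of one module-level default shared by all callers.
    metadata = metadata if metadata is not None else {}
    return kind_name, target_username, metadata

print(create_unique_notification("over_private_usage", "acme"))
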
diff --git a/endpoints/api/suconfig.py b/endpoints/api/suconfig.py
index a96a7356b..bfe837345 100644
--- a/endpoints/api/suconfig.py
+++ b/endpoints/api/suconfig.py
@@ -10,7 +10,14 @@ from flask import abort
from app import app, config_provider
from auth.permissions import SuperUserPermission
from endpoints.api.suconfig_models_pre_oci import pre_oci_model as model
-from endpoints.api import (ApiResource, nickname, resource, internal_only, show_if, verify_not_prod)
+from endpoints.api import (
+ ApiResource,
+ nickname,
+ resource,
+ internal_only,
+ show_if,
+ verify_not_prod,
+)
import features
@@ -19,51 +26,45 @@ logger = logging.getLogger(__name__)
def database_is_valid():
- """ Returns whether the database, as configured, is valid. """
- if app.config['TESTING']:
- return False
+ """ Returns whether the database, as configured, is valid. """
+ if app.config["TESTING"]:
+ return False
- return model.is_valid()
+ return model.is_valid()
def database_has_users():
- """ Returns whether the database has any users defined. """
- return model.has_users()
+ """ Returns whether the database has any users defined. """
+ return model.has_users()
-@resource('/v1/superuser/registrystatus')
+@resource("/v1/superuser/registrystatus")
@internal_only
@show_if(features.SUPER_USERS)
class SuperUserRegistryStatus(ApiResource):
- """ Resource for determining the status of the registry, such as if config exists,
+ """ Resource for determining the status of the registry, such as if config exists,
if a database is configured, and if it has any defined users.
"""
- @nickname('scRegistryStatus')
- @verify_not_prod
- def get(self):
- """ Returns the status of the registry. """
- # If we have SETUP_COMPLETE, then we're ready to go!
- if app.config.get('SETUP_COMPLETE', False):
- return {
- 'provider_id': config_provider.provider_id,
- 'status': 'ready'
- }
- return {
- 'status': 'setup-incomplete'
- }
+ @nickname("scRegistryStatus")
+ @verify_not_prod
+ def get(self):
+ """ Returns the status of the registry. """
+ # If we have SETUP_COMPLETE, then we're ready to go!
+ if app.config.get("SETUP_COMPLETE", False):
+ return {"provider_id": config_provider.provider_id, "status": "ready"}
+
+ return {"status": "setup-incomplete"}
class _AlembicLogHandler(logging.Handler):
- def __init__(self):
- super(_AlembicLogHandler, self).__init__()
- self.records = []
+ def __init__(self):
+ super(_AlembicLogHandler, self).__init__()
+ self.records = []
+
+ def emit(self, record):
+ self.records.append({"level": record.levelname, "message": record.getMessage()})
- def emit(self, record):
- self.records.append({
- 'level': record.levelname,
- 'message': record.getMessage()
- })
# From: https://stackoverflow.com/a/44712205
def get_process_id(name):
@@ -76,29 +77,33 @@ def get_process_id(name):
>>> get_process_id('non-existent process')
[]
"""
- child = subprocess.Popen(['pgrep', name], stdout=subprocess.PIPE, shell=False)
+ child = subprocess.Popen(["pgrep", name], stdout=subprocess.PIPE, shell=False)
response = child.communicate()[0]
return [int(pid) for pid in response.split()]
-@resource('/v1/superuser/shutdown')
+@resource("/v1/superuser/shutdown")
@internal_only
@show_if(features.SUPER_USERS)
class SuperUserShutdown(ApiResource):
- """ Resource for sending a shutdown signal to the container. """
+ """ Resource for sending a shutdown signal to the container. """
- @verify_not_prod
- @nickname('scShutdownContainer')
- def post(self):
- """ Sends a signal to the phusion init system to shut down the container. """
- # Note: This method is called to set the database configuration before super users exists,
- # so we also allow it to be called if there is no valid registry configuration setup.
- if app.config['TESTING'] or not database_has_users() or SuperUserPermission().can():
- # Note: We skip if debugging locally.
- if app.config.get('DEBUGGING') == True:
- return {}
+ @verify_not_prod
+ @nickname("scShutdownContainer")
+ def post(self):
+ """ Sends a signal to the phusion init system to shut down the container. """
+        # Note: This method is called to set the database configuration before any super user exists,
+ # so we also allow it to be called if there is no valid registry configuration setup.
+ if (
+ app.config["TESTING"]
+ or not database_has_users()
+ or SuperUserPermission().can()
+ ):
+ # Note: We skip if debugging locally.
+ if app.config.get("DEBUGGING") == True:
+ return {}
- os.kill(get_process_id('my_init')[0], signal.SIGINT)
- return {}
+ os.kill(get_process_id("my_init")[0], signal.SIGINT)
+ return {}
- abort(403)
+ abort(403)
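
The guard in SuperUserShutdown.post allows the call in three cases (test mode, an empty user table during initial setup, or a superuser) and downgrades it to a no-op when DEBUGGING is set. The decision table as a small sketch:

def shutdown_decision(testing, has_users, is_superuser, debugging):
    # Mirrors the branch above: reject unless the caller is allowed, then skip
    # the actual signal when running with local debugging enabled.
    if not (testing or not has_users or is_superuser):
        return "forbidden"
    return "skipped" if debugging else "shutdown"

print(shutdown_decision(False, False, False, False))  # setup flow -> shutdown
print(shutdown_decision(False, True, False, False))   # regular user -> forbidden
print(shutdown_decision(False, True, True, True))     # superuser, debugging -> skipped
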
diff --git a/endpoints/api/suconfig_models_interface.py b/endpoints/api/suconfig_models_interface.py
index 9f8cbd0cb..d41a97d11 100644
--- a/endpoints/api/suconfig_models_interface.py
+++ b/endpoints/api/suconfig_models_interface.py
@@ -4,36 +4,36 @@ from six import add_metaclass
@add_metaclass(ABCMeta)
class SuperuserConfigDataInterface(object):
- """
+ """
Interface that represents all data store interactions required by the superuser config API.
"""
- @abstractmethod
- def is_valid(self):
- """
+ @abstractmethod
+ def is_valid(self):
+ """
Returns true if the configured database is valid.
"""
- @abstractmethod
- def has_users(self):
- """
+ @abstractmethod
+ def has_users(self):
+ """
Returns true if there are any users defined.
"""
- @abstractmethod
- def create_superuser(self, username, password, email):
- """
+ @abstractmethod
+ def create_superuser(self, username, password, email):
+ """
Creates a new superuser with the given username, password and email. Returns the user's UUID.
"""
- @abstractmethod
- def has_federated_login(self, username, service_name):
- """
+ @abstractmethod
+ def has_federated_login(self, username, service_name):
+ """
Returns true if the matching user has a federated login under the matching service.
"""
- @abstractmethod
- def attach_federated_login(self, username, service_name, federated_username):
- """
+ @abstractmethod
+ def attach_federated_login(self, username, service_name, federated_username):
+ """
Attaches a federatated login to the matching user, under the given service.
"""
diff --git a/endpoints/api/suconfig_models_pre_oci.py b/endpoints/api/suconfig_models_pre_oci.py
index 9bcb40acd..e27ed4770 100644
--- a/endpoints/api/suconfig_models_pre_oci.py
+++ b/endpoints/api/suconfig_models_pre_oci.py
@@ -2,32 +2,34 @@ from data import model
from data.database import User
from endpoints.api.suconfig_models_interface import SuperuserConfigDataInterface
+
class PreOCIModel(SuperuserConfigDataInterface):
- def is_valid(self):
- try:
- list(User.select().limit(1))
- return True
- except:
- return False
+ def is_valid(self):
+ try:
+ list(User.select().limit(1))
+ return True
+ except:
+ return False
- def has_users(self):
- return bool(list(User.select().limit(1)))
+ def has_users(self):
+ return bool(list(User.select().limit(1)))
- def create_superuser(self, username, password, email):
- return model.user.create_user(username, password, email, auto_verify=True).uuid
+ def create_superuser(self, username, password, email):
+ return model.user.create_user(username, password, email, auto_verify=True).uuid
- def has_federated_login(self, username, service_name):
- user = model.user.get_user(username)
- if user is None:
- return False
+ def has_federated_login(self, username, service_name):
+ user = model.user.get_user(username)
+ if user is None:
+ return False
- return bool(model.user.lookup_federated_login(user, service_name))
+ return bool(model.user.lookup_federated_login(user, service_name))
- def attach_federated_login(self, username, service_name, federated_username):
- user = model.user.get_user(username)
- if user is None:
- return False
+ def attach_federated_login(self, username, service_name, federated_username):
+ user = model.user.get_user(username)
+ if user is None:
+ return False
+
+ model.user.attach_federated_login(user, service_name, federated_username)
- model.user.attach_federated_login(user, service_name, federated_username)
pre_oci_model = PreOCIModel()
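
is_valid probes the database with a bare except:, which also swallows KeyboardInterrupt and SystemExit. A sketch of the narrower variant usually preferred, written against a generic probe callable rather than the peewee User model used above:

def database_is_reachable(run_probe):
    # Catch Exception rather than everything, so interpreter-level signals
    # such as KeyboardInterrupt still propagate to the caller.
    try:
        run_probe()
        return True
    except Exception:
        return False

def failing_probe():
    raise OSError("database unreachable")

print(database_is_reachable(lambda: None))   # True
print(database_is_reachable(failing_probe))  # False
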
diff --git a/endpoints/api/superuser.py b/endpoints/api/superuser.py
index ec1a4992f..15a6e4af9 100644
--- a/endpoints/api/superuser.py
+++ b/endpoints/api/superuser.py
@@ -17,15 +17,35 @@ from auth.auth_context import get_authenticated_user
from auth.permissions import SuperUserPermission
from data.database import ServiceKeyApprovalType
from data.logs_model import logs_model
-from endpoints.api import (ApiResource, nickname, resource, validate_json_request,
- internal_only, require_scope, show_if, parse_args,
- query_param, require_fresh_login, path_param, verify_not_prod,
- page_support, log_action, format_date, truthy_bool,
- InvalidRequest, NotFound, Unauthorized, InvalidResponse)
+from endpoints.api import (
+ ApiResource,
+ nickname,
+ resource,
+ validate_json_request,
+ internal_only,
+ require_scope,
+ show_if,
+ parse_args,
+ query_param,
+ require_fresh_login,
+ path_param,
+ verify_not_prod,
+ page_support,
+ log_action,
+ format_date,
+ truthy_bool,
+ InvalidRequest,
+ NotFound,
+ Unauthorized,
+ InvalidResponse,
+)
from endpoints.api.build import get_logs_or_log_url
-from endpoints.api.superuser_models_pre_oci import (pre_oci_model, ServiceKeyDoesNotExist,
- ServiceKeyAlreadyApproved,
- InvalidRepositoryBuildException)
+from endpoints.api.superuser_models_pre_oci import (
+ pre_oci_model,
+ ServiceKeyDoesNotExist,
+ ServiceKeyAlreadyApproved,
+ InvalidRepositoryBuildException,
+)
from endpoints.api.logs import _validate_logs_arguments
from util.request import get_request_ip
from util.useremails import send_confirmation_email, send_recovery_email
@@ -36,821 +56,871 @@ logger = logging.getLogger(__name__)
def get_immediate_subdirectories(directory):
- return [name for name in os.listdir(directory) if os.path.isdir(os.path.join(directory, name))]
+ return [
+ name
+ for name in os.listdir(directory)
+ if os.path.isdir(os.path.join(directory, name))
+ ]
def get_services():
- services = set(get_immediate_subdirectories(app.config['SYSTEM_SERVICES_PATH']))
- services = services - set(app.config['SYSTEM_SERVICE_BLACKLIST'])
- return services
+ services = set(get_immediate_subdirectories(app.config["SYSTEM_SERVICES_PATH"]))
+ services = services - set(app.config["SYSTEM_SERVICE_BLACKLIST"])
+ return services
-@resource('/v1/superuser/aggregatelogs')
+@resource("/v1/superuser/aggregatelogs")
@internal_only
class SuperUserAggregateLogs(ApiResource):
- """ Resource for fetching aggregated logs for the current user. """
+ """ Resource for fetching aggregated logs for the current user. """
- @require_fresh_login
- @verify_not_prod
- @nickname('listAllAggregateLogs')
- @parse_args()
- @query_param('starttime', 'Earliest time from which to get logs. (%m/%d/%Y %Z)', type=str)
- @query_param('endtime', 'Latest time to which to get logs. (%m/%d/%Y %Z)', type=str)
- def get(self, parsed_args):
- """ Returns the aggregated logs for the current system. """
- if SuperUserPermission().can():
- (start_time, end_time) = _validate_logs_arguments(parsed_args['starttime'],
- parsed_args['endtime'])
- aggregated_logs = logs_model.get_aggregated_log_counts(start_time, end_time)
- return {
- 'aggregated': [log.to_dict() for log in aggregated_logs]
- }
+ @require_fresh_login
+ @verify_not_prod
+ @nickname("listAllAggregateLogs")
+ @parse_args()
+ @query_param(
+ "starttime", "Earliest time from which to get logs. (%m/%d/%Y %Z)", type=str
+ )
+ @query_param("endtime", "Latest time to which to get logs. (%m/%d/%Y %Z)", type=str)
+ def get(self, parsed_args):
+ """ Returns the aggregated logs for the current system. """
+ if SuperUserPermission().can():
+ (start_time, end_time) = _validate_logs_arguments(
+ parsed_args["starttime"], parsed_args["endtime"]
+ )
+ aggregated_logs = logs_model.get_aggregated_log_counts(start_time, end_time)
+ return {"aggregated": [log.to_dict() for log in aggregated_logs]}
+
+ raise Unauthorized()
- raise Unauthorized()
LOGS_PER_PAGE = 20
-@resource('/v1/superuser/logs')
+
+@resource("/v1/superuser/logs")
@internal_only
@show_if(features.SUPER_USERS)
class SuperUserLogs(ApiResource):
- """ Resource for fetching all logs in the system. """
+ """ Resource for fetching all logs in the system. """
- @require_fresh_login
- @verify_not_prod
- @nickname('listAllLogs')
- @parse_args()
- @query_param('starttime', 'Earliest time from which to get logs (%m/%d/%Y %Z)', type=str)
- @query_param('endtime', 'Latest time to which to get logs (%m/%d/%Y %Z)', type=str)
- @query_param('page', 'The page number for the logs', type=int, default=1)
- @page_support()
- @require_scope(scopes.SUPERUSER)
- def get(self, parsed_args, page_token):
- """ List the usage logs for the current system. """
- if SuperUserPermission().can():
- start_time = parsed_args['starttime']
- end_time = parsed_args['endtime']
+ @require_fresh_login
+ @verify_not_prod
+ @nickname("listAllLogs")
+ @parse_args()
+ @query_param(
+ "starttime", "Earliest time from which to get logs (%m/%d/%Y %Z)", type=str
+ )
+ @query_param("endtime", "Latest time to which to get logs (%m/%d/%Y %Z)", type=str)
+ @query_param("page", "The page number for the logs", type=int, default=1)
+ @page_support()
+ @require_scope(scopes.SUPERUSER)
+ def get(self, parsed_args, page_token):
+ """ List the usage logs for the current system. """
+ if SuperUserPermission().can():
+ start_time = parsed_args["starttime"]
+ end_time = parsed_args["endtime"]
- (start_time, end_time) = _validate_logs_arguments(start_time, end_time)
- log_entry_page = logs_model.lookup_logs(start_time, end_time, page_token=page_token)
- return {
- 'start_time': format_date(start_time),
- 'end_time': format_date(end_time),
- 'logs': [log.to_dict(avatar, include_namespace=True) for log in log_entry_page.logs],
- }, log_entry_page.next_page_token
+ (start_time, end_time) = _validate_logs_arguments(start_time, end_time)
+ log_entry_page = logs_model.lookup_logs(
+ start_time, end_time, page_token=page_token
+ )
+ return (
+ {
+ "start_time": format_date(start_time),
+ "end_time": format_date(end_time),
+ "logs": [
+ log.to_dict(avatar, include_namespace=True)
+ for log in log_entry_page.logs
+ ],
+ },
+ log_entry_page.next_page_token,
+ )
- raise Unauthorized()
+ raise Unauthorized()
def org_view(org):
- return {
- 'name': org.username,
- 'email': org.email,
- 'avatar': avatar.get_data_for_org(org),
- }
+ return {
+ "name": org.username,
+ "email": org.email,
+ "avatar": avatar.get_data_for_org(org),
+ }
def user_view(user, password=None):
- user_data = {
- 'kind': 'user',
- 'name': user.username,
- 'username': user.username,
- 'email': user.email,
- 'verified': user.verified,
- 'avatar': avatar.get_data_for_user(user),
- 'super_user': superusers.is_superuser(user.username),
- 'enabled': user.enabled,
- }
+ user_data = {
+ "kind": "user",
+ "name": user.username,
+ "username": user.username,
+ "email": user.email,
+ "verified": user.verified,
+ "avatar": avatar.get_data_for_user(user),
+ "super_user": superusers.is_superuser(user.username),
+ "enabled": user.enabled,
+ }
- if password is not None:
- user_data['encrypted_password'] = authentication.encrypt_user_password(password)
+ if password is not None:
+ user_data["encrypted_password"] = authentication.encrypt_user_password(password)
- return user_data
+ return user_data
-@resource('/v1/superuser/changelog/')
+@resource("/v1/superuser/changelog/")
@internal_only
@show_if(features.SUPER_USERS)
class ChangeLog(ApiResource):
- """ Resource for returning the change log for enterprise customers. """
+ """ Resource for returning the change log for enterprise customers. """
- @require_fresh_login
- @verify_not_prod
- @nickname('getChangeLog')
- @require_scope(scopes.SUPERUSER)
- def get(self):
- """ Returns the change log for this installation. """
- if SuperUserPermission().can():
- with open(os.path.join(ROOT_DIR, 'CHANGELOG.md'), 'r') as f:
- return {
- 'log': f.read()
- }
+ @require_fresh_login
+ @verify_not_prod
+ @nickname("getChangeLog")
+ @require_scope(scopes.SUPERUSER)
+ def get(self):
+ """ Returns the change log for this installation. """
+ if SuperUserPermission().can():
+ with open(os.path.join(ROOT_DIR, "CHANGELOG.md"), "r") as f:
+ return {"log": f.read()}
- raise Unauthorized()
+ raise Unauthorized()
-@resource('/v1/superuser/organizations/')
+@resource("/v1/superuser/organizations/")
@internal_only
@show_if(features.SUPER_USERS)
class SuperUserOrganizationList(ApiResource):
- """ Resource for listing organizations in the system. """
+ """ Resource for listing organizations in the system. """
- @require_fresh_login
- @verify_not_prod
- @nickname('listAllOrganizations')
- @require_scope(scopes.SUPERUSER)
- def get(self):
- """ Returns a list of all organizations in the system. """
- if SuperUserPermission().can():
- return {
- 'organizations': [org.to_dict() for org in pre_oci_model.get_organizations()]
- }
+ @require_fresh_login
+ @verify_not_prod
+ @nickname("listAllOrganizations")
+ @require_scope(scopes.SUPERUSER)
+ def get(self):
+ """ Returns a list of all organizations in the system. """
+ if SuperUserPermission().can():
+ return {
+ "organizations": [
+ org.to_dict() for org in pre_oci_model.get_organizations()
+ ]
+ }
- raise Unauthorized()
+ raise Unauthorized()
-@resource('/v1/superuser/users/')
+@resource("/v1/superuser/users/")
@show_if(features.SUPER_USERS)
class SuperUserList(ApiResource):
- """ Resource for listing users in the system. """
- schemas = {
- 'CreateInstallUser': {
- 'id': 'CreateInstallUser',
- 'description': 'Data for creating a user',
- 'required': ['username'],
- 'properties': {
- 'username': {
- 'type': 'string',
- 'description': 'The username of the user being created'
- },
+ """ Resource for listing users in the system. """
- 'email': {
- 'type': 'string',
- 'description': 'The email address of the user being created'
+ schemas = {
+ "CreateInstallUser": {
+ "id": "CreateInstallUser",
+ "description": "Data for creating a user",
+ "required": ["username"],
+ "properties": {
+ "username": {
+ "type": "string",
+ "description": "The username of the user being created",
+ },
+ "email": {
+ "type": "string",
+ "description": "The email address of the user being created",
+ },
+ },
}
- }
}
- }
- @require_fresh_login
- @verify_not_prod
- @nickname('listAllUsers')
- @parse_args()
- @query_param('disabled', 'If false, only enabled users will be returned.', type=truthy_bool,
- default=True)
- @require_scope(scopes.SUPERUSER)
- def get(self, parsed_args):
- """ Returns a list of all users in the system. """
- if SuperUserPermission().can():
- users = pre_oci_model.get_active_users(disabled=parsed_args['disabled'])
- return {
- 'users': [user.to_dict() for user in users]
- }
+ @require_fresh_login
+ @verify_not_prod
+ @nickname("listAllUsers")
+ @parse_args()
+ @query_param(
+ "disabled",
+ "If false, only enabled users will be returned.",
+ type=truthy_bool,
+ default=True,
+ )
+ @require_scope(scopes.SUPERUSER)
+ def get(self, parsed_args):
+ """ Returns a list of all users in the system. """
+ if SuperUserPermission().can():
+ users = pre_oci_model.get_active_users(disabled=parsed_args["disabled"])
+ return {"users": [user.to_dict() for user in users]}
- raise Unauthorized()
+ raise Unauthorized()
- @require_fresh_login
- @verify_not_prod
- @nickname('createInstallUser')
- @validate_json_request('CreateInstallUser')
- @require_scope(scopes.SUPERUSER)
- def post(self):
- """ Creates a new user. """
- # Ensure that we are using database auth.
- if app.config['AUTHENTICATION_TYPE'] != 'Database':
- raise InvalidRequest('Cannot create a user in a non-database auth system')
+ @require_fresh_login
+ @verify_not_prod
+ @nickname("createInstallUser")
+ @validate_json_request("CreateInstallUser")
+ @require_scope(scopes.SUPERUSER)
+ def post(self):
+ """ Creates a new user. """
+ # Ensure that we are using database auth.
+ if app.config["AUTHENTICATION_TYPE"] != "Database":
+ raise InvalidRequest("Cannot create a user in a non-database auth system")
- user_information = request.get_json()
- if SuperUserPermission().can():
- # Generate a temporary password for the user.
- random = SystemRandom()
- password = ''.join([random.choice(string.ascii_uppercase + string.digits) for _ in range(32)])
+ user_information = request.get_json()
+ if SuperUserPermission().can():
+ # Generate a temporary password for the user.
+ random = SystemRandom()
+ password = "".join(
+ [
+ random.choice(string.ascii_uppercase + string.digits)
+ for _ in range(32)
+ ]
+ )
- # Create the user.
- username = user_information['username']
- email = user_information.get('email')
- install_user, confirmation_code = pre_oci_model.create_install_user(username, password, email)
- if features.MAILING:
- send_confirmation_email(install_user.username, install_user.email, confirmation_code)
+ # Create the user.
+ username = user_information["username"]
+ email = user_information.get("email")
+ install_user, confirmation_code = pre_oci_model.create_install_user(
+ username, password, email
+ )
+ if features.MAILING:
+ send_confirmation_email(
+ install_user.username, install_user.email, confirmation_code
+ )
- return {
- 'username': username,
- 'email': email,
- 'password': password,
- 'encrypted_password': authentication.encrypt_user_password(password),
- }
+ return {
+ "username": username,
+ "email": email,
+ "password": password,
+ "encrypted_password": authentication.encrypt_user_password(password),
+ }
- raise Unauthorized()
+ raise Unauthorized()
-@resource('/v1/superusers/users/<username>/sendrecovery')
+@resource("/v1/superusers/users/<username>/sendrecovery")
@internal_only
@show_if(features.SUPER_USERS)
@show_if(features.MAILING)
class SuperUserSendRecoveryEmail(ApiResource):
- """ Resource for sending a recovery user on behalf of a user. """
+ """ Resource for sending a recovery user on behalf of a user. """
- @require_fresh_login
- @verify_not_prod
- @nickname('sendInstallUserRecoveryEmail')
- @require_scope(scopes.SUPERUSER)
- def post(self, username):
- # Ensure that we are using database auth.
- if app.config['AUTHENTICATION_TYPE'] != 'Database':
- raise InvalidRequest('Cannot send a recovery e-mail for non-database auth')
+ @require_fresh_login
+ @verify_not_prod
+ @nickname("sendInstallUserRecoveryEmail")
+ @require_scope(scopes.SUPERUSER)
+ def post(self, username):
+ # Ensure that we are using database auth.
+ if app.config["AUTHENTICATION_TYPE"] != "Database":
+ raise InvalidRequest("Cannot send a recovery e-mail for non-database auth")
- if SuperUserPermission().can():
- user = pre_oci_model.get_nonrobot_user(username)
- if user is None:
- raise NotFound()
+ if SuperUserPermission().can():
+ user = pre_oci_model.get_nonrobot_user(username)
+ if user is None:
+ raise NotFound()
- if superusers.is_superuser(username):
- raise InvalidRequest('Cannot send a recovery email for a superuser')
+ if superusers.is_superuser(username):
+ raise InvalidRequest("Cannot send a recovery email for a superuser")
- code = pre_oci_model.create_reset_password_email_code(user.email)
- send_recovery_email(user.email, code)
- return {
- 'email': user.email
- }
+ code = pre_oci_model.create_reset_password_email_code(user.email)
+ send_recovery_email(user.email, code)
+ return {"email": user.email}
- raise Unauthorized()
+ raise Unauthorized()
-@resource('/v1/superuser/users/<username>')
-@path_param('username', 'The username of the user being managed')
+@resource("/v1/superuser/users/")
+@path_param("username", "The username of the user being managed")
@internal_only
@show_if(features.SUPER_USERS)
class SuperUserManagement(ApiResource):
- """ Resource for managing users in the system. """
- schemas = {
- 'UpdateUser': {
- 'id': 'UpdateUser',
- 'type': 'object',
- 'description': 'Description of updates for a user',
- 'properties': {
- 'password': {
- 'type': 'string',
- 'description': 'The new password for the user',
- },
- 'email': {
- 'type': 'string',
- 'description': 'The new e-mail address for the user',
- },
- 'enabled': {
- 'type': 'boolean',
- 'description': 'Whether the user is enabled'
+ """ Resource for managing users in the system. """
+
+ schemas = {
+ "UpdateUser": {
+ "id": "UpdateUser",
+ "type": "object",
+ "description": "Description of updates for a user",
+ "properties": {
+ "password": {
+ "type": "string",
+ "description": "The new password for the user",
+ },
+ "email": {
+ "type": "string",
+ "description": "The new e-mail address for the user",
+ },
+ "enabled": {
+ "type": "boolean",
+ "description": "Whether the user is enabled",
+ },
+ },
}
- },
- },
- }
+ }
- @require_fresh_login
- @verify_not_prod
- @nickname('getInstallUser')
- @require_scope(scopes.SUPERUSER)
- def get(self, username):
- """ Returns information about the specified user. """
- if SuperUserPermission().can():
- user = pre_oci_model.get_nonrobot_user(username)
- if user is None:
- raise NotFound()
+ @require_fresh_login
+ @verify_not_prod
+ @nickname("getInstallUser")
+ @require_scope(scopes.SUPERUSER)
+ def get(self, username):
+ """ Returns information about the specified user. """
+ if SuperUserPermission().can():
+ user = pre_oci_model.get_nonrobot_user(username)
+ if user is None:
+ raise NotFound()
- return user.to_dict()
+ return user.to_dict()
- raise Unauthorized()
+ raise Unauthorized()
- @require_fresh_login
- @verify_not_prod
- @nickname('deleteInstallUser')
- @require_scope(scopes.SUPERUSER)
- def delete(self, username):
- """ Deletes the specified user. """
- if SuperUserPermission().can():
- user = pre_oci_model.get_nonrobot_user(username)
- if user is None:
- raise NotFound()
+ @require_fresh_login
+ @verify_not_prod
+ @nickname("deleteInstallUser")
+ @require_scope(scopes.SUPERUSER)
+ def delete(self, username):
+ """ Deletes the specified user. """
+ if SuperUserPermission().can():
+ user = pre_oci_model.get_nonrobot_user(username)
+ if user is None:
+ raise NotFound()
- if superusers.is_superuser(username):
- raise InvalidRequest('Cannot delete a superuser')
+ if superusers.is_superuser(username):
+ raise InvalidRequest("Cannot delete a superuser")
- pre_oci_model.mark_user_for_deletion(username)
- return '', 204
+ pre_oci_model.mark_user_for_deletion(username)
+ return "", 204
- raise Unauthorized()
+ raise Unauthorized()
- @require_fresh_login
- @verify_not_prod
- @nickname('changeInstallUser')
- @validate_json_request('UpdateUser')
- @require_scope(scopes.SUPERUSER)
- def put(self, username):
- """ Updates information about the specified user. """
- if SuperUserPermission().can():
- user = pre_oci_model.get_nonrobot_user(username)
- if user is None:
- raise NotFound()
+ @require_fresh_login
+ @verify_not_prod
+ @nickname("changeInstallUser")
+ @validate_json_request("UpdateUser")
+ @require_scope(scopes.SUPERUSER)
+ def put(self, username):
+ """ Updates information about the specified user. """
+ if SuperUserPermission().can():
+ user = pre_oci_model.get_nonrobot_user(username)
+ if user is None:
+ raise NotFound()
- if superusers.is_superuser(username):
- raise InvalidRequest('Cannot update a superuser')
+ if superusers.is_superuser(username):
+ raise InvalidRequest("Cannot update a superuser")
- user_data = request.get_json()
- if 'password' in user_data:
- # Ensure that we are using database auth.
- if app.config['AUTHENTICATION_TYPE'] != 'Database':
- raise InvalidRequest('Cannot change password in non-database auth')
+ user_data = request.get_json()
+ if "password" in user_data:
+ # Ensure that we are using database auth.
+ if app.config["AUTHENTICATION_TYPE"] != "Database":
+ raise InvalidRequest("Cannot change password in non-database auth")
- pre_oci_model.change_password(username, user_data['password'])
+ pre_oci_model.change_password(username, user_data["password"])
- if 'email' in user_data:
- # Ensure that we are using database auth.
- if app.config['AUTHENTICATION_TYPE'] not in ['Database', 'AppToken']:
- raise InvalidRequest('Cannot change e-mail in non-database auth')
+ if "email" in user_data:
+ # Ensure that we are using database auth.
+ if app.config["AUTHENTICATION_TYPE"] not in ["Database", "AppToken"]:
+ raise InvalidRequest("Cannot change e-mail in non-database auth")
- pre_oci_model.update_email(username, user_data['email'], auto_verify=True)
+ pre_oci_model.update_email(
+ username, user_data["email"], auto_verify=True
+ )
- if 'enabled' in user_data:
- # Disable/enable the user.
- pre_oci_model.update_enabled(username, bool(user_data['enabled']))
+ if "enabled" in user_data:
+ # Disable/enable the user.
+ pre_oci_model.update_enabled(username, bool(user_data["enabled"]))
- if 'superuser' in user_data:
- config_object = config_provider.get_config()
- superusers_set = set(config_object['SUPER_USERS'])
+ if "superuser" in user_data:
+ config_object = config_provider.get_config()
+ superusers_set = set(config_object["SUPER_USERS"])
- if user_data['superuser']:
- superusers_set.add(username)
- elif username in superusers_set:
- superusers_set.remove(username)
+ if user_data["superuser"]:
+ superusers_set.add(username)
+ elif username in superusers_set:
+ superusers_set.remove(username)
- config_object['SUPER_USERS'] = list(superusers_set)
- config_provider.save_config(config_object)
+ config_object["SUPER_USERS"] = list(superusers_set)
+ config_provider.save_config(config_object)
- return_value = user.to_dict()
- if user_data.get('password') is not None:
- password = user_data.get('password')
- return_value['encrypted_password'] = authentication.encrypt_user_password(password)
- if user_data.get('email') is not None:
- return_value['email'] = user_data.get('email')
+ return_value = user.to_dict()
+ if user_data.get("password") is not None:
+ password = user_data.get("password")
+ return_value[
+ "encrypted_password"
+ ] = authentication.encrypt_user_password(password)
+ if user_data.get("email") is not None:
+ return_value["email"] = user_data.get("email")
- return return_value
+ return return_value
- raise Unauthorized()
+ raise Unauthorized()
-@resource('/v1/superuser/takeownership/<namespace>')
-@path_param('namespace', 'The namespace of the user or organization being managed')
+@resource("/v1/superuser/takeownership/")
+@path_param("namespace", "The namespace of the user or organization being managed")
@internal_only
@show_if(features.SUPER_USERS)
class SuperUserTakeOwnership(ApiResource):
- """ Resource for a superuser to take ownership of a namespace. """
+ """ Resource for a superuser to take ownership of a namespace. """
- @require_fresh_login
- @verify_not_prod
- @nickname('takeOwnership')
- @require_scope(scopes.SUPERUSER)
- def post(self, namespace):
- """ Takes ownership of the specified organization or user. """
- if SuperUserPermission().can():
- # Disallow for superusers.
- if superusers.is_superuser(namespace):
- raise InvalidRequest('Cannot take ownership of a superuser')
+ @require_fresh_login
+ @verify_not_prod
+ @nickname("takeOwnership")
+ @require_scope(scopes.SUPERUSER)
+ def post(self, namespace):
+ """ Takes ownership of the specified organization or user. """
+ if SuperUserPermission().can():
+ # Disallow for superusers.
+ if superusers.is_superuser(namespace):
+ raise InvalidRequest("Cannot take ownership of a superuser")
- authed_user = get_authenticated_user()
- entity_id, was_user = pre_oci_model.take_ownership(namespace, authed_user)
- if entity_id is None:
- raise NotFound()
+ authed_user = get_authenticated_user()
+ entity_id, was_user = pre_oci_model.take_ownership(namespace, authed_user)
+ if entity_id is None:
+ raise NotFound()
- # Log the change.
- log_metadata = {
- 'entity_id': entity_id,
- 'namespace': namespace,
- 'was_user': was_user,
- 'superuser': authed_user.username,
- }
+ # Log the change.
+ log_metadata = {
+ "entity_id": entity_id,
+ "namespace": namespace,
+ "was_user": was_user,
+ "superuser": authed_user.username,
+ }
- log_action('take_ownership', authed_user.username, log_metadata)
+ log_action("take_ownership", authed_user.username, log_metadata)
- return jsonify({
- 'namespace': namespace
- })
+ return jsonify({"namespace": namespace})
- raise Unauthorized()
+ raise Unauthorized()
-@resource('/v1/superuser/organizations/<name>')
-@path_param('name', 'The name of the organization being managed')
+@resource("/v1/superuser/organizations/<name>")
+@path_param("name", "The name of the organization being managed")
@show_if(features.SUPER_USERS)
class SuperUserOrganizationManagement(ApiResource):
- """ Resource for managing organizations in the system. """
- schemas = {
- 'UpdateOrg': {
- 'id': 'UpdateOrg',
- 'type': 'object',
- 'description': 'Description of updates for an organization',
- 'properties': {
- 'name': {
- 'type': 'string',
- 'description': 'The new name for the organization',
+ """ Resource for managing organizations in the system. """
+
+ schemas = {
+ "UpdateOrg": {
+ "id": "UpdateOrg",
+ "type": "object",
+ "description": "Description of updates for an organization",
+ "properties": {
+ "name": {
+ "type": "string",
+ "description": "The new name for the organization",
+ }
+ },
}
- },
- },
- }
+ }
- @require_fresh_login
- @verify_not_prod
- @nickname('deleteOrganization')
- @require_scope(scopes.SUPERUSER)
- def delete(self, name):
- """ Deletes the specified organization. """
- if SuperUserPermission().can():
- pre_oci_model.mark_organization_for_deletion(name)
- return '', 204
+ @require_fresh_login
+ @verify_not_prod
+ @nickname("deleteOrganization")
+ @require_scope(scopes.SUPERUSER)
+ def delete(self, name):
+ """ Deletes the specified organization. """
+ if SuperUserPermission().can():
+ pre_oci_model.mark_organization_for_deletion(name)
+ return "", 204
- raise Unauthorized()
+ raise Unauthorized()
- @require_fresh_login
- @verify_not_prod
- @nickname('changeOrganization')
- @validate_json_request('UpdateOrg')
- @require_scope(scopes.SUPERUSER)
- def put(self, name):
- """ Updates information about the specified user. """
- if SuperUserPermission().can():
- org_data = request.get_json()
- old_name = org_data['name'] if 'name' in org_data else None
- org = pre_oci_model.change_organization_name(name, old_name)
- return org.to_dict()
+ @require_fresh_login
+ @verify_not_prod
+ @nickname("changeOrganization")
+ @validate_json_request("UpdateOrg")
+ @require_scope(scopes.SUPERUSER)
+ def put(self, name):
+ """ Updates information about the specified user. """
+ if SuperUserPermission().can():
+ org_data = request.get_json()
+ old_name = org_data["name"] if "name" in org_data else None
+ org = pre_oci_model.change_organization_name(name, old_name)
+ return org.to_dict()
- raise Unauthorized()
+ raise Unauthorized()
def key_view(key):
- return {
- 'name': key.name,
- 'kid': key.kid,
- 'service': key.service,
- 'jwk': key.jwk,
- 'metadata': key.metadata,
- 'created_date': key.created_date,
- 'expiration_date': key.expiration_date,
- 'rotation_duration': key.rotation_duration,
- 'approval': approval_view(key.approval) if key.approval is not None else None,
- }
+ return {
+ "name": key.name,
+ "kid": key.kid,
+ "service": key.service,
+ "jwk": key.jwk,
+ "metadata": key.metadata,
+ "created_date": key.created_date,
+ "expiration_date": key.expiration_date,
+ "rotation_duration": key.rotation_duration,
+ "approval": approval_view(key.approval) if key.approval is not None else None,
+ }
def approval_view(approval):
- return {
- 'approver': user_view(approval.approver) if approval.approver else None,
- 'approval_type': approval.approval_type,
- 'approved_date': approval.approved_date,
- 'notes': approval.notes,
- }
+ return {
+ "approver": user_view(approval.approver) if approval.approver else None,
+ "approval_type": approval.approval_type,
+ "approved_date": approval.approved_date,
+ "notes": approval.notes,
+ }
-@resource('/v1/superuser/keys')
+@resource("/v1/superuser/keys")
@show_if(features.SUPER_USERS)
class SuperUserServiceKeyManagement(ApiResource):
- """ Resource for managing service keys."""
- schemas = {
- 'CreateServiceKey': {
- 'id': 'CreateServiceKey',
- 'type': 'object',
- 'description': 'Description of creation of a service key',
- 'required': ['service', 'expiration'],
- 'properties': {
- 'service': {
- 'type': 'string',
- 'description': 'The service authenticating with this key',
- },
- 'name': {
- 'type': 'string',
- 'description': 'The friendly name of a service key',
- },
- 'metadata': {
- 'type': 'object',
- 'description': 'The key/value pairs of this key\'s metadata',
- },
- 'notes': {
- 'type': 'string',
- 'description': 'If specified, the extra notes for the key',
- },
- 'expiration': {
- 'description': 'The expiration date as a unix timestamp',
- 'anyOf': [{'type': 'number'}, {'type': 'null'}],
- },
- },
- },
- }
+ """ Resource for managing service keys."""
- @verify_not_prod
- @nickname('listServiceKeys')
- @require_scope(scopes.SUPERUSER)
- def get(self):
- if SuperUserPermission().can():
- keys = pre_oci_model.list_all_service_keys()
+ schemas = {
+ "CreateServiceKey": {
+ "id": "CreateServiceKey",
+ "type": "object",
+ "description": "Description of creation of a service key",
+ "required": ["service", "expiration"],
+ "properties": {
+ "service": {
+ "type": "string",
+ "description": "The service authenticating with this key",
+ },
+ "name": {
+ "type": "string",
+ "description": "The friendly name of a service key",
+ },
+ "metadata": {
+ "type": "object",
+ "description": "The key/value pairs of this key's metadata",
+ },
+ "notes": {
+ "type": "string",
+ "description": "If specified, the extra notes for the key",
+ },
+ "expiration": {
+ "description": "The expiration date as a unix timestamp",
+ "anyOf": [{"type": "number"}, {"type": "null"}],
+ },
+ },
+ }
+ }
- return jsonify({
- 'keys': [key.to_dict() for key in keys],
- })
+ @verify_not_prod
+ @nickname("listServiceKeys")
+ @require_scope(scopes.SUPERUSER)
+ def get(self):
+ if SuperUserPermission().can():
+ keys = pre_oci_model.list_all_service_keys()
- raise Unauthorized()
+ return jsonify({"keys": [key.to_dict() for key in keys]})
- @require_fresh_login
- @verify_not_prod
- @nickname('createServiceKey')
- @require_scope(scopes.SUPERUSER)
- @validate_json_request('CreateServiceKey')
- def post(self):
- if SuperUserPermission().can():
- body = request.get_json()
- key_name = body.get('name', '')
- if not validate_service_key_name(key_name):
- raise InvalidRequest('Invalid service key friendly name: %s' % key_name)
+ raise Unauthorized()
- # Ensure we have a valid expiration date if specified.
- expiration_date = body.get('expiration', None)
- if expiration_date is not None:
- try:
- expiration_date = datetime.utcfromtimestamp(float(expiration_date))
- except ValueError as ve:
- raise InvalidRequest('Invalid expiration date: %s' % ve)
+ @require_fresh_login
+ @verify_not_prod
+ @nickname("createServiceKey")
+ @require_scope(scopes.SUPERUSER)
+ @validate_json_request("CreateServiceKey")
+ def post(self):
+ if SuperUserPermission().can():
+ body = request.get_json()
+ key_name = body.get("name", "")
+ if not validate_service_key_name(key_name):
+ raise InvalidRequest("Invalid service key friendly name: %s" % key_name)
- if expiration_date <= datetime.now():
- raise InvalidRequest('Expiration date cannot be in the past')
+ # Ensure we have a valid expiration date if specified.
+ expiration_date = body.get("expiration", None)
+ if expiration_date is not None:
+ try:
+ expiration_date = datetime.utcfromtimestamp(float(expiration_date))
+ except ValueError as ve:
+ raise InvalidRequest("Invalid expiration date: %s" % ve)
- # Create the metadata for the key.
- user = get_authenticated_user()
- metadata = body.get('metadata', {})
- metadata.update({
- 'created_by': 'Quay Superuser Panel',
- 'creator': user.username,
- 'ip': get_request_ip(),
- })
+ if expiration_date <= datetime.now():
+ raise InvalidRequest("Expiration date cannot be in the past")
- # Generate a key with a private key that we *never save*.
- (private_key, key_id) = pre_oci_model.generate_service_key(body['service'], expiration_date,
- metadata=metadata,
- name=key_name)
- # Auto-approve the service key.
- pre_oci_model.approve_service_key(key_id, user, ServiceKeyApprovalType.SUPERUSER,
- notes=body.get('notes', ''))
+ # Create the metadata for the key.
+ user = get_authenticated_user()
+ metadata = body.get("metadata", {})
+ metadata.update(
+ {
+ "created_by": "Quay Superuser Panel",
+ "creator": user.username,
+ "ip": get_request_ip(),
+ }
+ )
- # Log the creation and auto-approval of the service key.
- key_log_metadata = {
- 'kid': key_id,
- 'preshared': True,
- 'service': body['service'],
- 'name': key_name,
- 'expiration_date': expiration_date,
- 'auto_approved': True,
- }
+ # Generate a key with a private key that we *never save*.
+ (private_key, key_id) = pre_oci_model.generate_service_key(
+ body["service"], expiration_date, metadata=metadata, name=key_name
+ )
+ # Auto-approve the service key.
+ pre_oci_model.approve_service_key(
+ key_id,
+ user,
+ ServiceKeyApprovalType.SUPERUSER,
+ notes=body.get("notes", ""),
+ )
- log_action('service_key_create', None, key_log_metadata)
- log_action('service_key_approve', None, key_log_metadata)
+ # Log the creation and auto-approval of the service key.
+ key_log_metadata = {
+ "kid": key_id,
+ "preshared": True,
+ "service": body["service"],
+ "name": key_name,
+ "expiration_date": expiration_date,
+ "auto_approved": True,
+ }
- return jsonify({
- 'kid': key_id,
- 'name': key_name,
- 'service': body['service'],
- 'public_key': private_key.publickey().exportKey('PEM'),
- 'private_key': private_key.exportKey('PEM'),
- })
+ log_action("service_key_create", None, key_log_metadata)
+ log_action("service_key_approve", None, key_log_metadata)
- raise Unauthorized()
+ return jsonify(
+ {
+ "kid": key_id,
+ "name": key_name,
+ "service": body["service"],
+ "public_key": private_key.publickey().exportKey("PEM"),
+ "private_key": private_key.exportKey("PEM"),
+ }
+ )
+
+ raise Unauthorized()
-@resource('/v1/superuser/keys/<kid>')
-@path_param('kid', 'The unique identifier for a service key')
+@resource("/v1/superuser/keys/")
+@path_param("kid", "The unique identifier for a service key")
@show_if(features.SUPER_USERS)
class SuperUserServiceKey(ApiResource):
- """ Resource for managing service keys. """
- schemas = {
- 'PutServiceKey': {
- 'id': 'PutServiceKey',
- 'type': 'object',
- 'description': 'Description of updates for a service key',
- 'properties': {
- 'name': {
- 'type': 'string',
- 'description': 'The friendly name of a service key',
- },
- 'metadata': {
- 'type': 'object',
- 'description': 'The key/value pairs of this key\'s metadata',
- },
- 'expiration': {
- 'description': 'The expiration date as a unix timestamp',
- 'anyOf': [{'type': 'number'}, {'type': 'null'}],
- },
- },
- },
- }
+ """ Resource for managing service keys. """
- @verify_not_prod
- @nickname('getServiceKey')
- @require_scope(scopes.SUPERUSER)
- def get(self, kid):
- if SuperUserPermission().can():
- try:
- key = pre_oci_model.get_service_key(kid, approved_only=False, alive_only=False)
- return jsonify(key.to_dict())
- except ServiceKeyDoesNotExist:
- raise NotFound()
+ schemas = {
+ "PutServiceKey": {
+ "id": "PutServiceKey",
+ "type": "object",
+ "description": "Description of updates for a service key",
+ "properties": {
+ "name": {
+ "type": "string",
+ "description": "The friendly name of a service key",
+ },
+ "metadata": {
+ "type": "object",
+ "description": "The key/value pairs of this key's metadata",
+ },
+ "expiration": {
+ "description": "The expiration date as a unix timestamp",
+ "anyOf": [{"type": "number"}, {"type": "null"}],
+ },
+ },
+ }
+ }
- raise Unauthorized()
+ @verify_not_prod
+ @nickname("getServiceKey")
+ @require_scope(scopes.SUPERUSER)
+ def get(self, kid):
+ if SuperUserPermission().can():
+ try:
+ key = pre_oci_model.get_service_key(
+ kid, approved_only=False, alive_only=False
+ )
+ return jsonify(key.to_dict())
+ except ServiceKeyDoesNotExist:
+ raise NotFound()
- @require_fresh_login
- @verify_not_prod
- @nickname('updateServiceKey')
- @require_scope(scopes.SUPERUSER)
- @validate_json_request('PutServiceKey')
- def put(self, kid):
- if SuperUserPermission().can():
- body = request.get_json()
- try:
- key = pre_oci_model.get_service_key(kid, approved_only=False, alive_only=False)
- except ServiceKeyDoesNotExist:
- raise NotFound()
+ raise Unauthorized()
- key_log_metadata = {
- 'kid': key.kid,
- 'service': key.service,
- 'name': body.get('name', key.name),
- 'expiration_date': key.expiration_date,
- }
+ @require_fresh_login
+ @verify_not_prod
+ @nickname("updateServiceKey")
+ @require_scope(scopes.SUPERUSER)
+ @validate_json_request("PutServiceKey")
+ def put(self, kid):
+ if SuperUserPermission().can():
+ body = request.get_json()
+ try:
+ key = pre_oci_model.get_service_key(
+ kid, approved_only=False, alive_only=False
+ )
+ except ServiceKeyDoesNotExist:
+ raise NotFound()
- if 'expiration' in body:
- expiration_date = body['expiration']
- if expiration_date is not None and expiration_date != '':
- try:
- expiration_date = datetime.utcfromtimestamp(float(expiration_date))
- except ValueError as ve:
- raise InvalidRequest('Invalid expiration date: %s' % ve)
+ key_log_metadata = {
+ "kid": key.kid,
+ "service": key.service,
+ "name": body.get("name", key.name),
+ "expiration_date": key.expiration_date,
+ }
- if expiration_date <= datetime.now():
- raise InvalidRequest('Cannot have an expiration date in the past')
+ if "expiration" in body:
+ expiration_date = body["expiration"]
+ if expiration_date is not None and expiration_date != "":
+ try:
+ expiration_date = datetime.utcfromtimestamp(
+ float(expiration_date)
+ )
+ except ValueError as ve:
+ raise InvalidRequest("Invalid expiration date: %s" % ve)
- key_log_metadata.update({
- 'old_expiration_date': key.expiration_date,
- 'expiration_date': expiration_date,
- })
+ if expiration_date <= datetime.now():
+ raise InvalidRequest(
+ "Cannot have an expiration date in the past"
+ )
- log_action('service_key_extend', None, key_log_metadata)
- pre_oci_model.set_key_expiration(kid, expiration_date)
+ key_log_metadata.update(
+ {
+ "old_expiration_date": key.expiration_date,
+ "expiration_date": expiration_date,
+ }
+ )
- if 'name' in body or 'metadata' in body:
- key_name = body.get('name')
- if not validate_service_key_name(key_name):
- raise InvalidRequest('Invalid service key friendly name: %s' % key_name)
+ log_action("service_key_extend", None, key_log_metadata)
+ pre_oci_model.set_key_expiration(kid, expiration_date)
- pre_oci_model.update_service_key(kid, key_name, body.get('metadata'))
- log_action('service_key_modify', None, key_log_metadata)
+ if "name" in body or "metadata" in body:
+ key_name = body.get("name")
+ if not validate_service_key_name(key_name):
+ raise InvalidRequest(
+ "Invalid service key friendly name: %s" % key_name
+ )
- updated_key = pre_oci_model.get_service_key(kid, approved_only=False, alive_only=False)
- return jsonify(updated_key.to_dict())
+ pre_oci_model.update_service_key(kid, key_name, body.get("metadata"))
+ log_action("service_key_modify", None, key_log_metadata)
- raise Unauthorized()
+ updated_key = pre_oci_model.get_service_key(
+ kid, approved_only=False, alive_only=False
+ )
+ return jsonify(updated_key.to_dict())
- @require_fresh_login
- @verify_not_prod
- @nickname('deleteServiceKey')
- @require_scope(scopes.SUPERUSER)
- def delete(self, kid):
- if SuperUserPermission().can():
- try:
- key = pre_oci_model.delete_service_key(kid)
- except ServiceKeyDoesNotExist:
- raise NotFound()
+ raise Unauthorized()
- key_log_metadata = {
- 'kid': kid,
- 'service': key.service,
- 'name': key.name,
- 'created_date': key.created_date,
- 'expiration_date': key.expiration_date,
- }
+ @require_fresh_login
+ @verify_not_prod
+ @nickname("deleteServiceKey")
+ @require_scope(scopes.SUPERUSER)
+ def delete(self, kid):
+ if SuperUserPermission().can():
+ try:
+ key = pre_oci_model.delete_service_key(kid)
+ except ServiceKeyDoesNotExist:
+ raise NotFound()
- log_action('service_key_delete', None, key_log_metadata)
- return make_response('', 204)
+ key_log_metadata = {
+ "kid": kid,
+ "service": key.service,
+ "name": key.name,
+ "created_date": key.created_date,
+ "expiration_date": key.expiration_date,
+ }
- raise Unauthorized()
+ log_action("service_key_delete", None, key_log_metadata)
+ return make_response("", 204)
+
+ raise Unauthorized()
-@resource('/v1/superuser/approvedkeys/<kid>')
-@path_param('kid', 'The unique identifier for a service key')
+@resource("/v1/superuser/approvedkeys/")
+@path_param("kid", "The unique identifier for a service key")
@show_if(features.SUPER_USERS)
class SuperUserServiceKeyApproval(ApiResource):
- """ Resource for approving service keys. """
+ """ Resource for approving service keys. """
- schemas = {
- 'ApproveServiceKey': {
- 'id': 'ApproveServiceKey',
- 'type': 'object',
- 'description': 'Information for approving service keys',
- 'properties': {
- 'notes': {
- 'type': 'string',
- 'description': 'Optional approval notes',
- },
- },
- },
- }
-
- @require_fresh_login
- @verify_not_prod
- @nickname('approveServiceKey')
- @require_scope(scopes.SUPERUSER)
- @validate_json_request('ApproveServiceKey')
- def post(self, kid):
- if SuperUserPermission().can():
- notes = request.get_json().get('notes', '')
- approver = get_authenticated_user()
- try:
- key = pre_oci_model.approve_service_key(kid, approver, ServiceKeyApprovalType.SUPERUSER,
- notes=notes)
-
- # Log the approval of the service key.
- key_log_metadata = {
- 'kid': kid,
- 'service': key.service,
- 'name': key.name,
- 'expiration_date': key.expiration_date,
+ schemas = {
+ "ApproveServiceKey": {
+ "id": "ApproveServiceKey",
+ "type": "object",
+ "description": "Information for approving service keys",
+ "properties": {
+ "notes": {"type": "string", "description": "Optional approval notes"}
+ },
}
+ }
- log_action('service_key_approve', None, key_log_metadata)
- except ServiceKeyDoesNotExist:
- raise NotFound()
- except ServiceKeyAlreadyApproved:
- pass
+ @require_fresh_login
+ @verify_not_prod
+ @nickname("approveServiceKey")
+ @require_scope(scopes.SUPERUSER)
+ @validate_json_request("ApproveServiceKey")
+ def post(self, kid):
+ if SuperUserPermission().can():
+ notes = request.get_json().get("notes", "")
+ approver = get_authenticated_user()
+ try:
+ key = pre_oci_model.approve_service_key(
+ kid, approver, ServiceKeyApprovalType.SUPERUSER, notes=notes
+ )
- return make_response('', 201)
+ # Log the approval of the service key.
+ key_log_metadata = {
+ "kid": kid,
+ "service": key.service,
+ "name": key.name,
+ "expiration_date": key.expiration_date,
+ }
- raise Unauthorized()
+ log_action("service_key_approve", None, key_log_metadata)
+ except ServiceKeyDoesNotExist:
+ raise NotFound()
+ except ServiceKeyAlreadyApproved:
+ pass
+
+ return make_response("", 201)
+
+ raise Unauthorized()
-@resource('/v1/superuser/<build_uuid>/logs')
-@path_param('build_uuid', 'The UUID of the build')
+@resource("/v1/superuser//logs")
+@path_param("build_uuid", "The UUID of the build")
@show_if(features.SUPER_USERS)
class SuperUserRepositoryBuildLogs(ApiResource):
- """ Resource for loading repository build logs for the superuser. """
+ """ Resource for loading repository build logs for the superuser. """
- @require_fresh_login
- @verify_not_prod
- @nickname('getRepoBuildLogsSuperUser')
- @require_scope(scopes.SUPERUSER)
- def get(self, build_uuid):
- """ Return the build logs for the build specified by the build uuid. """
- if SuperUserPermission().can():
- try:
- repo_build = pre_oci_model.get_repository_build(build_uuid)
- return get_logs_or_log_url(repo_build)
- except InvalidRepositoryBuildException as e:
- raise InvalidResponse(str(e))
+ @require_fresh_login
+ @verify_not_prod
+ @nickname("getRepoBuildLogsSuperUser")
+ @require_scope(scopes.SUPERUSER)
+ def get(self, build_uuid):
+ """ Return the build logs for the build specified by the build uuid. """
+ if SuperUserPermission().can():
+ try:
+ repo_build = pre_oci_model.get_repository_build(build_uuid)
+ return get_logs_or_log_url(repo_build)
+ except InvalidRepositoryBuildException as e:
+ raise InvalidResponse(str(e))
- raise Unauthorized()
+ raise Unauthorized()
-@resource('/v1/superuser/<repository>/<build_uuid>/status')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
-@path_param('build_uuid', 'The UUID of the build')
+@resource("/v1/superuser//status")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
+@path_param("build_uuid", "The UUID of the build")
@show_if(features.SUPER_USERS)
class SuperUserRepositoryBuildStatus(ApiResource):
- """ Resource for dealing with repository build status. """
+ """ Resource for dealing with repository build status. """
- @require_fresh_login
- @verify_not_prod
- @nickname('getRepoBuildStatusSuperUser')
- @require_scope(scopes.SUPERUSER)
- def get(self, build_uuid):
- """ Return the status for the builds specified by the build uuids. """
- if SuperUserPermission().can():
- try:
- build = pre_oci_model.get_repository_build(build_uuid)
- except InvalidRepositoryBuildException as e:
- raise InvalidResponse(str(e))
- return build.to_dict()
+ @require_fresh_login
+ @verify_not_prod
+ @nickname("getRepoBuildStatusSuperUser")
+ @require_scope(scopes.SUPERUSER)
+ def get(self, build_uuid):
+ """ Return the status for the builds specified by the build uuids. """
+ if SuperUserPermission().can():
+ try:
+ build = pre_oci_model.get_repository_build(build_uuid)
+ except InvalidRepositoryBuildException as e:
+ raise InvalidResponse(str(e))
+ return build.to_dict()
- raise Unauthorized()
+ raise Unauthorized()
-@resource('/v1/superuser/<repository>/<build_uuid>/build')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
-@path_param('build_uuid', 'The UUID of the build')
+@resource("/v1/superuser//build")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
+@path_param("build_uuid", "The UUID of the build")
@show_if(features.SUPER_USERS)
class SuperUserRepositoryBuildResource(ApiResource):
- """ Resource for dealing with repository builds as a super user. """
+ """ Resource for dealing with repository builds as a super user. """
- @require_fresh_login
- @verify_not_prod
- @nickname('getRepoBuildSuperUser')
- @require_scope(scopes.SUPERUSER)
- def get(self, build_uuid):
- """ Returns information about a build. """
- if SuperUserPermission().can():
- try:
- build = pre_oci_model.get_repository_build(build_uuid)
- except InvalidRepositoryBuildException:
- raise NotFound()
+ @require_fresh_login
+ @verify_not_prod
+ @nickname("getRepoBuildSuperUser")
+ @require_scope(scopes.SUPERUSER)
+ def get(self, build_uuid):
+ """ Returns information about a build. """
+ if SuperUserPermission().can():
+ try:
+ build = pre_oci_model.get_repository_build(build_uuid)
+ except InvalidRepositoryBuildException:
+ raise NotFound()
- return build.to_dict()
+ return build.to_dict()
- raise Unauthorized()
+ raise Unauthorized()
diff --git a/endpoints/api/superuser_models_interface.py b/endpoints/api/superuser_models_interface.py
index e03d98e8c..0e8211b02 100644
--- a/endpoints/api/superuser_models_interface.py
+++ b/endpoints/api/superuser_models_interface.py
@@ -15,16 +15,16 @@ from util.morecollections import AttrDict
def user_view(user):
- return {
- 'name': user.username,
- 'kind': 'user',
- 'is_robot': user.robot,
- }
+ return {"name": user.username, "kind": "user", "is_robot": user.robot}
class BuildTrigger(
- namedtuple('BuildTrigger', ['uuid', 'service_name', 'pull_robot', 'can_read', 'can_admin', 'for_build'])):
- """
+ namedtuple(
+ "BuildTrigger",
+ ["uuid", "service_name", "pull_robot", "can_read", "can_admin", "for_build"],
+ )
+):
+ """
    BuildTrigger represents a trigger that is associated with a build
:type uuid: string
:type service_name: string
@@ -34,39 +34,56 @@ class BuildTrigger(
:type for_build: boolean
"""
- def to_dict(self):
- if not self.uuid:
- return None
+ def to_dict(self):
+ if not self.uuid:
+ return None
- build_trigger = BuildTriggerHandler.get_handler(self)
- build_source = build_trigger.config.get('build_source')
+ build_trigger = BuildTriggerHandler.get_handler(self)
+ build_source = build_trigger.config.get("build_source")
- repo_url = build_trigger.get_repository_url() if build_source else None
- can_read = self.can_read or self.can_admin
+ repo_url = build_trigger.get_repository_url() if build_source else None
+ can_read = self.can_read or self.can_admin
- trigger_data = {
- 'id': self.uuid,
- 'service': self.service_name,
- 'is_active': build_trigger.is_active(),
+ trigger_data = {
+ "id": self.uuid,
+ "service": self.service_name,
+ "is_active": build_trigger.is_active(),
+ "build_source": build_source if can_read else None,
+ "repository_url": repo_url if can_read else None,
+ "config": build_trigger.config if self.can_admin else {},
+ "can_invoke": self.can_admin,
+ }
- 'build_source': build_source if can_read else None,
- 'repository_url': repo_url if can_read else None,
+ if not self.for_build and self.can_admin and self.pull_robot:
+ trigger_data["pull_robot"] = user_view(self.pull_robot)
- 'config': build_trigger.config if self.can_admin else {},
- 'can_invoke': self.can_admin,
- }
-
- if not self.for_build and self.can_admin and self.pull_robot:
- trigger_data['pull_robot'] = user_view(self.pull_robot)
-
- return trigger_data
+ return trigger_data
-class RepositoryBuild(namedtuple('RepositoryBuild',
- ['uuid', 'logs_archived', 'repository_namespace_user_username', 'repository_name',
- 'can_write', 'can_read', 'pull_robot', 'resource_key', 'trigger', 'display_name',
- 'started', 'job_config', 'phase', 'status', 'error', 'archive_url'])):
- """
+class RepositoryBuild(
+ namedtuple(
+ "RepositoryBuild",
+ [
+ "uuid",
+ "logs_archived",
+ "repository_namespace_user_username",
+ "repository_name",
+ "can_write",
+ "can_read",
+ "pull_robot",
+ "resource_key",
+ "trigger",
+ "display_name",
+ "started",
+ "job_config",
+ "phase",
+ "status",
+ "error",
+ "archive_url",
+ ],
+ )
+):
+ """
    RepositoryBuild represents a build associated with a repository
:type uuid: string
:type logs_archived: boolean
@@ -86,42 +103,46 @@ class RepositoryBuild(namedtuple('RepositoryBuild',
:type archive_url: string
"""
- def to_dict(self):
+ def to_dict(self):
- resp = {
- 'id': self.uuid,
- 'phase': self.phase,
- 'started': format_date(self.started),
- 'display_name': self.display_name,
- 'status': self.status or {},
- 'subdirectory': self.job_config.get('build_subdir', ''),
- 'dockerfile_path': self.job_config.get('build_subdir', ''),
- 'context': self.job_config.get('context', ''),
- 'tags': self.job_config.get('docker_tags', []),
- 'manual_user': self.job_config.get('manual_user', None),
- 'is_writer': self.can_write,
- 'trigger': self.trigger.to_dict(),
- 'trigger_metadata': self.job_config.get('trigger_metadata', None) if self.can_read else None,
- 'resource_key': self.resource_key,
- 'pull_robot': user_view(self.pull_robot) if self.pull_robot else None,
- 'repository': {
- 'namespace': self.repository_namespace_user_username,
- 'name': self.repository_name
- },
- 'error': self.error,
- }
+ resp = {
+ "id": self.uuid,
+ "phase": self.phase,
+ "started": format_date(self.started),
+ "display_name": self.display_name,
+ "status": self.status or {},
+ "subdirectory": self.job_config.get("build_subdir", ""),
+ "dockerfile_path": self.job_config.get("build_subdir", ""),
+ "context": self.job_config.get("context", ""),
+ "tags": self.job_config.get("docker_tags", []),
+ "manual_user": self.job_config.get("manual_user", None),
+ "is_writer": self.can_write,
+ "trigger": self.trigger.to_dict(),
+ "trigger_metadata": self.job_config.get("trigger_metadata", None)
+ if self.can_read
+ else None,
+ "resource_key": self.resource_key,
+ "pull_robot": user_view(self.pull_robot) if self.pull_robot else None,
+ "repository": {
+ "namespace": self.repository_namespace_user_username,
+ "name": self.repository_name,
+ },
+ "error": self.error,
+ }
- if self.can_write:
- if self.resource_key is not None:
- resp['archive_url'] = self.archive_url
- elif self.job_config.get('archive_url', None):
- resp['archive_url'] = self.job_config['archive_url']
+ if self.can_write:
+ if self.resource_key is not None:
+ resp["archive_url"] = self.archive_url
+ elif self.job_config.get("archive_url", None):
+ resp["archive_url"] = self.job_config["archive_url"]
- return resp
+ return resp
-class Approval(namedtuple('Approval', ['approver', 'approval_type', 'approved_date', 'notes'])):
- """
+class Approval(
+ namedtuple("Approval", ["approver", "approval_type", "approved_date", "notes"])
+):
+ """
Approval represents whether a key has been approved or not
:type approver: User
:type approval_type: string
@@ -129,18 +150,32 @@ class Approval(namedtuple('Approval', ['approver', 'approval_type', 'approved_da
:type notes: string
"""
- def to_dict(self):
- return {
- 'approver': self.approver.to_dict() if self.approver else None,
- 'approval_type': self.approval_type,
- 'approved_date': self.approved_date,
- 'notes': self.notes,
- }
+ def to_dict(self):
+ return {
+ "approver": self.approver.to_dict() if self.approver else None,
+ "approval_type": self.approval_type,
+ "approved_date": self.approved_date,
+ "notes": self.notes,
+ }
-class ServiceKey(namedtuple('ServiceKey', ['name', 'kid', 'service', 'jwk', 'metadata', 'created_date',
- 'expiration_date', 'rotation_duration', 'approval'])):
- """
+class ServiceKey(
+ namedtuple(
+ "ServiceKey",
+ [
+ "name",
+ "kid",
+ "service",
+ "jwk",
+ "metadata",
+ "created_date",
+ "expiration_date",
+ "rotation_duration",
+ "approval",
+ ],
+ )
+):
+ """
ServiceKey is an apostille signing key
:type name: string
:type kid: int
@@ -154,22 +189,22 @@ class ServiceKey(namedtuple('ServiceKey', ['name', 'kid', 'service', 'jwk', 'met
"""
- def to_dict(self):
- return {
- 'name': self.name,
- 'kid': self.kid,
- 'service': self.service,
- 'jwk': self.jwk,
- 'metadata': self.metadata,
- 'created_date': self.created_date,
- 'expiration_date': self.expiration_date,
- 'rotation_duration': self.rotation_duration,
- 'approval': self.approval.to_dict() if self.approval is not None else None,
- }
+ def to_dict(self):
+ return {
+ "name": self.name,
+ "kid": self.kid,
+ "service": self.service,
+ "jwk": self.jwk,
+ "metadata": self.metadata,
+ "created_date": self.created_date,
+ "expiration_date": self.expiration_date,
+ "rotation_duration": self.rotation_duration,
+ "approval": self.approval.to_dict() if self.approval is not None else None,
+ }
-class User(namedtuple('User', ['username', 'email', 'verified', 'enabled', 'robot'])):
- """
+class User(namedtuple("User", ["username", "email", "verified", "enabled", "robot"])):
+ """
User represents a single user.
:type username: string
:type email: string
@@ -178,158 +213,166 @@ class User(namedtuple('User', ['username', 'email', 'verified', 'enabled', 'robo
:type robot: User
"""
- def to_dict(self):
- user_data = {
- 'kind': 'user',
- 'name': self.username,
- 'username': self.username,
- 'email': self.email,
- 'verified': self.verified,
- 'avatar': avatar.get_data_for_user(self),
- 'super_user': superusers.is_superuser(self.username),
- 'enabled': self.enabled,
- }
+ def to_dict(self):
+ user_data = {
+ "kind": "user",
+ "name": self.username,
+ "username": self.username,
+ "email": self.email,
+ "verified": self.verified,
+ "avatar": avatar.get_data_for_user(self),
+ "super_user": superusers.is_superuser(self.username),
+ "enabled": self.enabled,
+ }
- return user_data
+ return user_data
-class Organization(namedtuple('Organization', ['username', 'email'])):
- """
+class Organization(namedtuple("Organization", ["username", "email"])):
+ """
Organization represents a single org.
:type username: string
:type email: string
"""
- def to_dict(self):
- return {
- 'name': self.username,
- 'email': self.email,
- 'avatar': avatar.get_data_for_org(self),
- }
+ def to_dict(self):
+ return {
+ "name": self.username,
+ "email": self.email,
+ "avatar": avatar.get_data_for_org(self),
+ }
@add_metaclass(ABCMeta)
class SuperuserDataInterface(object):
- """
+ """
Interface that represents all data store interactions required by a superuser api.
"""
- @abstractmethod
- def get_organizations(self):
- """
+ @abstractmethod
+ def get_organizations(self):
+ """
Returns a list of Organization
"""
- @abstractmethod
- def get_active_users(self):
- """
+ @abstractmethod
+ def get_active_users(self):
+ """
Returns a list of User
"""
- @abstractmethod
- def create_install_user(self, username, password, email):
- """
+ @abstractmethod
+ def create_install_user(self, username, password, email):
+ """
Returns the created user and confirmation code for email confirmation
"""
- @abstractmethod
- def get_nonrobot_user(self, username):
- """
+ @abstractmethod
+ def get_nonrobot_user(self, username):
+ """
Returns a User
"""
- @abstractmethod
- def create_reset_password_email_code(self, email):
- """
+ @abstractmethod
+ def create_reset_password_email_code(self, email):
+ """
Returns a recover password code
"""
- @abstractmethod
- def mark_user_for_deletion(self, username):
- """
+ @abstractmethod
+ def mark_user_for_deletion(self, username):
+ """
Returns None
"""
- @abstractmethod
- def change_password(self, username, password):
- """
+ @abstractmethod
+ def change_password(self, username, password):
+ """
Returns None
"""
- @abstractmethod
- def update_email(self, username, email, auto_verify):
- """
+ @abstractmethod
+ def update_email(self, username, email, auto_verify):
+ """
Returns None
"""
- @abstractmethod
- def update_enabled(self, username, enabled):
- """
+ @abstractmethod
+ def update_enabled(self, username, enabled):
+ """
Returns None
"""
- @abstractmethod
- def take_ownership(self, namespace, authed_user):
- """
+ @abstractmethod
+ def take_ownership(self, namespace, authed_user):
+ """
Returns id of entity and whether the entity was a user
"""
- @abstractmethod
- def mark_organization_for_deletion(self, name):
- """
+ @abstractmethod
+ def mark_organization_for_deletion(self, name):
+ """
Returns None
"""
- @abstractmethod
- def change_organization_name(self, old_org_name, new_org_name):
- """
+ @abstractmethod
+ def change_organization_name(self, old_org_name, new_org_name):
+ """
Returns updated Organization
"""
- @abstractmethod
- def list_all_service_keys(self):
- """
+ @abstractmethod
+ def list_all_service_keys(self):
+ """
Returns a list of service keys
"""
- @abstractmethod
- def generate_service_key(self, service, expiration_date, kid=None, name='', metadata=None, rotation_duration=None):
- """
+ @abstractmethod
+ def generate_service_key(
+ self,
+ service,
+ expiration_date,
+ kid=None,
+ name="",
+ metadata=None,
+ rotation_duration=None,
+ ):
+ """
Returns a tuple of private key and public key id
"""
- @abstractmethod
- def approve_service_key(self, kid, approver, approval_type, notes=''):
- """
+ @abstractmethod
+ def approve_service_key(self, kid, approver, approval_type, notes=""):
+ """
Returns the approved Key
"""
- @abstractmethod
- def get_service_key(self, kid, service=None, alive_only=True, approved_only=True):
- """
+ @abstractmethod
+ def get_service_key(self, kid, service=None, alive_only=True, approved_only=True):
+ """
Returns ServiceKey
"""
- @abstractmethod
- def set_key_expiration(self, kid, expiration_date):
- """
+ @abstractmethod
+ def set_key_expiration(self, kid, expiration_date):
+ """
Returns None
"""
- @abstractmethod
- def update_service_key(self, kid, name=None, metadata=None):
- """
+ @abstractmethod
+ def update_service_key(self, kid, name=None, metadata=None):
+ """
Returns None
"""
- @abstractmethod
- def delete_service_key(self, kid):
- """
+ @abstractmethod
+ def delete_service_key(self, kid):
+ """
Returns deleted ServiceKey
"""
- @abstractmethod
- def get_repository_build(self, uuid):
- """
+ @abstractmethod
+ def get_repository_build(self, uuid):
+ """
Returns RepositoryBuild
"""
diff --git a/endpoints/api/superuser_models_pre_oci.py b/endpoints/api/superuser_models_pre_oci.py
index 0458f9226..f48de7c24 100644
--- a/endpoints/api/superuser_models_pre_oci.py
+++ b/endpoints/api/superuser_models_pre_oci.py
@@ -3,180 +3,253 @@ import features
from flask import request
from app import all_queues, userfiles, namespace_gc_queue
-from auth.permissions import ReadRepositoryPermission, ModifyRepositoryPermission, AdministerRepositoryPermission
+from auth.permissions import (
+ ReadRepositoryPermission,
+ ModifyRepositoryPermission,
+ AdministerRepositoryPermission,
+)
from data import model, database
from endpoints.api.build import get_job_config, _get_build_status
from endpoints.api.superuser_models_interface import BuildTrigger
-from endpoints.api.superuser_models_interface import SuperuserDataInterface, Organization, User, \
- ServiceKey, Approval, RepositoryBuild
+from endpoints.api.superuser_models_interface import (
+ SuperuserDataInterface,
+ Organization,
+ User,
+ ServiceKey,
+ Approval,
+ RepositoryBuild,
+)
from util.request import get_request_ip
def _create_user(user):
- if user is None:
- return None
- return User(user.username, user.email, user.verified, user.enabled, user.robot)
+ if user is None:
+ return None
+ return User(user.username, user.email, user.verified, user.enabled, user.robot)
def _create_key(key):
- approval = None
- if key.approval is not None:
- approval = Approval(_create_user(key.approval.approver), key.approval.approval_type, key.approval.approved_date,
- key.approval.notes)
+ approval = None
+ if key.approval is not None:
+ approval = Approval(
+ _create_user(key.approval.approver),
+ key.approval.approval_type,
+ key.approval.approved_date,
+ key.approval.notes,
+ )
- return ServiceKey(key.name, key.kid, key.service, key.jwk, key.metadata, key.created_date, key.expiration_date,
- key.rotation_duration, approval)
+ return ServiceKey(
+ key.name,
+ key.kid,
+ key.service,
+ key.jwk,
+ key.metadata,
+ key.created_date,
+ key.expiration_date,
+ key.rotation_duration,
+ approval,
+ )
class ServiceKeyDoesNotExist(Exception):
- pass
+ pass
class ServiceKeyAlreadyApproved(Exception):
- pass
+ pass
class InvalidRepositoryBuildException(Exception):
- pass
+ pass
class PreOCIModel(SuperuserDataInterface):
- """
+ """
PreOCIModel implements the data model for the SuperUser using a database schema
before it was changed to support the OCI specification.
"""
- def get_repository_build(self, uuid):
- try:
- build = model.build.get_repository_build(uuid)
- except model.InvalidRepositoryBuildException as e:
- raise InvalidRepositoryBuildException(str(e))
+ def get_repository_build(self, uuid):
+ try:
+ build = model.build.get_repository_build(uuid)
+ except model.InvalidRepositoryBuildException as e:
+ raise InvalidRepositoryBuildException(str(e))
- repo_namespace = build.repository_namespace_user_username
- repo_name = build.repository_name
+ repo_namespace = build.repository_namespace_user_username
+ repo_name = build.repository_name
- can_read = ReadRepositoryPermission(repo_namespace, repo_name).can()
- can_write = ModifyRepositoryPermission(repo_namespace, repo_name).can()
- can_admin = AdministerRepositoryPermission(repo_namespace, repo_name).can()
- job_config = get_job_config(build.job_config)
- phase, status, error = _get_build_status(build)
- url = userfiles.get_file_url(self.resource_key, get_request_ip(), requires_cors=True)
+ can_read = ReadRepositoryPermission(repo_namespace, repo_name).can()
+ can_write = ModifyRepositoryPermission(repo_namespace, repo_name).can()
+ can_admin = AdministerRepositoryPermission(repo_namespace, repo_name).can()
+ job_config = get_job_config(build.job_config)
+ phase, status, error = _get_build_status(build)
+ url = userfiles.get_file_url(
+ self.resource_key, get_request_ip(), requires_cors=True
+ )
- return RepositoryBuild(build.uuid, build.logs_archived, repo_namespace, repo_name, can_write, can_read,
- _create_user(build.pull_robot), build.resource_key,
- BuildTrigger(build.trigger.uuid, build.trigger.service.name,
- _create_user(build.trigger.pull_robot), can_read, can_admin, True),
- build.display_name, build.display_name, build.started, job_config, phase, status, error, url)
+ return RepositoryBuild(
+ build.uuid,
+ build.logs_archived,
+ repo_namespace,
+ repo_name,
+ can_write,
+ can_read,
+ _create_user(build.pull_robot),
+ build.resource_key,
+ BuildTrigger(
+ build.trigger.uuid,
+ build.trigger.service.name,
+ _create_user(build.trigger.pull_robot),
+ can_read,
+ can_admin,
+ True,
+ ),
+ build.display_name,
+ build.display_name,
+ build.started,
+ job_config,
+ phase,
+ status,
+ error,
+ url,
+ )
- def delete_service_key(self, kid):
- try:
- key = model.service_keys.delete_service_key(kid)
- except model.ServiceKeyDoesNotExist:
- raise ServiceKeyDoesNotExist
- return _create_key(key)
+ def delete_service_key(self, kid):
+ try:
+ key = model.service_keys.delete_service_key(kid)
+ except model.ServiceKeyDoesNotExist:
+ raise ServiceKeyDoesNotExist
+ return _create_key(key)
- def update_service_key(self, kid, name=None, metadata=None):
- model.service_keys.update_service_key(kid, name, metadata)
+ def update_service_key(self, kid, name=None, metadata=None):
+ model.service_keys.update_service_key(kid, name, metadata)
- def set_key_expiration(self, kid, expiration_date):
- model.service_keys.set_key_expiration(kid, expiration_date)
+ def set_key_expiration(self, kid, expiration_date):
+ model.service_keys.set_key_expiration(kid, expiration_date)
- def get_service_key(self, kid, service=None, alive_only=True, approved_only=True):
- try:
- key = model.service_keys.get_service_key(kid, approved_only=approved_only, alive_only=alive_only)
- return _create_key(key)
- except model.ServiceKeyDoesNotExist:
- raise ServiceKeyDoesNotExist
+ def get_service_key(self, kid, service=None, alive_only=True, approved_only=True):
+ try:
+ key = model.service_keys.get_service_key(
+ kid, approved_only=approved_only, alive_only=alive_only
+ )
+ return _create_key(key)
+ except model.ServiceKeyDoesNotExist:
+ raise ServiceKeyDoesNotExist
- def approve_service_key(self, kid, approver, approval_type, notes=''):
- try:
- key = model.service_keys.approve_service_key(kid, approval_type, approver=approver, notes=notes)
- return _create_key(key)
- except model.ServiceKeyDoesNotExist:
- raise ServiceKeyDoesNotExist
- except model.ServiceKeyAlreadyApproved:
- raise ServiceKeyAlreadyApproved
+ def approve_service_key(self, kid, approver, approval_type, notes=""):
+ try:
+ key = model.service_keys.approve_service_key(
+ kid, approval_type, approver=approver, notes=notes
+ )
+ return _create_key(key)
+ except model.ServiceKeyDoesNotExist:
+ raise ServiceKeyDoesNotExist
+ except model.ServiceKeyAlreadyApproved:
+ raise ServiceKeyAlreadyApproved
- def generate_service_key(self, service, expiration_date, kid=None, name='', metadata=None, rotation_duration=None):
- (private_key, key) = model.service_keys.generate_service_key(service, expiration_date, metadata=metadata, name=name)
+ def generate_service_key(
+ self,
+ service,
+ expiration_date,
+ kid=None,
+ name="",
+ metadata=None,
+ rotation_duration=None,
+ ):
+ (private_key, key) = model.service_keys.generate_service_key(
+ service, expiration_date, metadata=metadata, name=name
+ )
- return private_key, key.kid
+ return private_key, key.kid
- def list_all_service_keys(self):
- keys = model.service_keys.list_all_keys()
- return [_create_key(key) for key in keys]
+ def list_all_service_keys(self):
+ keys = model.service_keys.list_all_keys()
+ return [_create_key(key) for key in keys]
- def change_organization_name(self, old_org_name, new_org_name):
- org = model.organization.get_organization(old_org_name)
- if new_org_name is not None:
- org = model.user.change_username(org.id, new_org_name)
+ def change_organization_name(self, old_org_name, new_org_name):
+ org = model.organization.get_organization(old_org_name)
+ if new_org_name is not None:
+ org = model.user.change_username(org.id, new_org_name)
- return Organization(org.username, org.email)
+ return Organization(org.username, org.email)
- def mark_organization_for_deletion(self, name):
- org = model.organization.get_organization(name)
- model.user.mark_namespace_for_deletion(org, all_queues, namespace_gc_queue, force=True)
+ def mark_organization_for_deletion(self, name):
+ org = model.organization.get_organization(name)
+ model.user.mark_namespace_for_deletion(
+ org, all_queues, namespace_gc_queue, force=True
+ )
- def take_ownership(self, namespace, authed_user):
- entity = model.user.get_user_or_org(namespace)
- if entity is None:
- return None, False
+ def take_ownership(self, namespace, authed_user):
+ entity = model.user.get_user_or_org(namespace)
+ if entity is None:
+ return None, False
- was_user = not entity.organization
- if entity.organization:
- # Add the superuser as an admin to the owners team of the org.
- model.organization.add_user_as_admin(authed_user, entity)
- else:
- # If the entity is a user, convert it to an organization and add the current superuser
- # as the admin.
- model.organization.convert_user_to_organization(entity, authed_user)
- return entity.id, was_user
+ was_user = not entity.organization
+ if entity.organization:
+ # Add the superuser as an admin to the owners team of the org.
+ model.organization.add_user_as_admin(authed_user, entity)
+ else:
+ # If the entity is a user, convert it to an organization and add the current superuser
+ # as the admin.
+ model.organization.convert_user_to_organization(entity, authed_user)
+ return entity.id, was_user
- def update_enabled(self, username, enabled):
- user = model.user.get_nonrobot_user(username)
- model.user.update_enabled(user, bool(enabled))
+ def update_enabled(self, username, enabled):
+ user = model.user.get_nonrobot_user(username)
+ model.user.update_enabled(user, bool(enabled))
- def update_email(self, username, email, auto_verify):
- user = model.user.get_nonrobot_user(username)
- model.user.update_email(user, email, auto_verify)
+ def update_email(self, username, email, auto_verify):
+ user = model.user.get_nonrobot_user(username)
+ model.user.update_email(user, email, auto_verify)
- def change_password(self, username, password):
- user = model.user.get_nonrobot_user(username)
- model.user.change_password(user, password)
+ def change_password(self, username, password):
+ user = model.user.get_nonrobot_user(username)
+ model.user.change_password(user, password)
- def mark_user_for_deletion(self, username):
- user = model.user.get_nonrobot_user(username)
- model.user.mark_namespace_for_deletion(user, all_queues, namespace_gc_queue, force=True)
+ def mark_user_for_deletion(self, username):
+ user = model.user.get_nonrobot_user(username)
+ model.user.mark_namespace_for_deletion(
+ user, all_queues, namespace_gc_queue, force=True
+ )
- def create_reset_password_email_code(self, email):
- code = model.user.create_reset_password_email_code(email)
- return code
+ def create_reset_password_email_code(self, email):
+ code = model.user.create_reset_password_email_code(email)
+ return code
- def get_nonrobot_user(self, username):
- user = model.user.get_nonrobot_user(username)
- if user is None:
- return None
- return _create_user(user)
+ def get_nonrobot_user(self, username):
+ user = model.user.get_nonrobot_user(username)
+ if user is None:
+ return None
+ return _create_user(user)
- def create_install_user(self, username, password, email):
- prompts = model.user.get_default_user_prompts(features)
- user = model.user.create_user(username, password, email, auto_verify=not features.MAILING,
- email_required=features.MAILING, prompts=prompts)
+ def create_install_user(self, username, password, email):
+ prompts = model.user.get_default_user_prompts(features)
+ user = model.user.create_user(
+ username,
+ password,
+ email,
+ auto_verify=not features.MAILING,
+ email_required=features.MAILING,
+ prompts=prompts,
+ )
- return_user = _create_user(user)
- # If mailing is turned on, send the user a verification email.
- if features.MAILING:
- confirmation_code = model.user.create_confirm_email_code(user)
- return return_user, confirmation_code
+ return_user = _create_user(user)
+ # If mailing is turned on, send the user a verification email.
+ if features.MAILING:
+ confirmation_code = model.user.create_confirm_email_code(user)
+ return return_user, confirmation_code
- return return_user, ''
+ return return_user, ""
- def get_active_users(self, disabled=True):
- users = model.user.get_active_users(disabled=disabled)
- return [_create_user(user) for user in users]
+ def get_active_users(self, disabled=True):
+ users = model.user.get_active_users(disabled=disabled)
+ return [_create_user(user) for user in users]
- def get_organizations(self):
- return [Organization(org.username, org.email) for org in model.organization.get_organizations()]
+ def get_organizations(self):
+ return [
+ Organization(org.username, org.email)
+ for org in model.organization.get_organizations()
+ ]
pre_oci_model = PreOCIModel()
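
PreOCIModel deliberately re-raises data-layer errors (model.ServiceKeyDoesNotExist, model.ServiceKeyAlreadyApproved) as the module-level exceptions defined above, so callers can map them to HTTP responses without importing the data layer. A minimal caller sketch, assuming the names from this hunk plus Quay's NotFound API exception; get_key_or_404 is a hypothetical helper, not part of this diff:

def get_key_or_404(kid):
    # Look the key up through the pre-OCI model and translate the
    # model-level error into an HTTP 404.
    try:
        return pre_oci_model.get_service_key(kid, alive_only=False, approved_only=False)
    except ServiceKeyDoesNotExist:
        raise NotFound()  # endpoints.exception.NotFound, as imported elsewhere in the API
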
diff --git a/endpoints/api/tag.py b/endpoints/api/tag.py
index 573f0fc97..66aaa1400 100644
--- a/endpoints/api/tag.py
+++ b/endpoints/api/tag.py
@@ -5,332 +5,393 @@ from flask import request, abort
from app import storage, docker_v2_signing_key
from auth.auth_context import get_authenticated_user
from data.registry_model import registry_model
-from endpoints.api import (resource, nickname, require_repo_read, require_repo_write,
- RepositoryParamResource, log_action, validate_json_request, path_param,
- parse_args, query_param, truthy_bool, disallow_for_app_repositories,
- format_date, disallow_for_non_normal_repositories)
+from endpoints.api import (
+ resource,
+ nickname,
+ require_repo_read,
+ require_repo_write,
+ RepositoryParamResource,
+ log_action,
+ validate_json_request,
+ path_param,
+ parse_args,
+ query_param,
+ truthy_bool,
+ disallow_for_app_repositories,
+ format_date,
+ disallow_for_non_normal_repositories,
+)
from endpoints.api.image import image_dict
from endpoints.exception import NotFound, InvalidRequest
from util.names import TAG_ERROR, TAG_REGEX
def _tag_dict(tag):
- tag_info = {
- 'name': tag.name,
- 'reversion': tag.reversion,
- }
+ tag_info = {"name": tag.name, "reversion": tag.reversion}
- if tag.lifetime_start_ts > 0:
- tag_info['start_ts'] = tag.lifetime_start_ts
+ if tag.lifetime_start_ts > 0:
+ tag_info["start_ts"] = tag.lifetime_start_ts
- if tag.lifetime_end_ts > 0:
- tag_info['end_ts'] = tag.lifetime_end_ts
+ if tag.lifetime_end_ts > 0:
+ tag_info["end_ts"] = tag.lifetime_end_ts
- # TODO: Remove this once fully on OCI data model.
- if tag.legacy_image_if_present:
- tag_info['docker_image_id'] = tag.legacy_image.docker_image_id
- tag_info['image_id'] = tag.legacy_image.docker_image_id
- tag_info['size'] = tag.legacy_image.aggregate_size
+ # TODO: Remove this once fully on OCI data model.
+ if tag.legacy_image_if_present:
+ tag_info["docker_image_id"] = tag.legacy_image.docker_image_id
+ tag_info["image_id"] = tag.legacy_image.docker_image_id
+ tag_info["size"] = tag.legacy_image.aggregate_size
- # TODO: Remove this check once fully on OCI data model.
- if tag.manifest_digest:
- tag_info['manifest_digest'] = tag.manifest_digest
+ # TODO: Remove this check once fully on OCI data model.
+ if tag.manifest_digest:
+ tag_info["manifest_digest"] = tag.manifest_digest
- if tag.manifest:
- tag_info['is_manifest_list'] = tag.manifest.is_manifest_list
+ if tag.manifest:
+ tag_info["is_manifest_list"] = tag.manifest.is_manifest_list
- if tag.lifetime_start_ts > 0:
- last_modified = format_date(datetime.utcfromtimestamp(tag.lifetime_start_ts))
- tag_info['last_modified'] = last_modified
+ if tag.lifetime_start_ts > 0:
+ last_modified = format_date(datetime.utcfromtimestamp(tag.lifetime_start_ts))
+ tag_info["last_modified"] = last_modified
- if tag.lifetime_end_ts is not None:
- expiration = format_date(datetime.utcfromtimestamp(tag.lifetime_end_ts))
- tag_info['expiration'] = expiration
+ if tag.lifetime_end_ts is not None:
+ expiration = format_date(datetime.utcfromtimestamp(tag.lifetime_end_ts))
+ tag_info["expiration"] = expiration
- return tag_info
+ return tag_info
-@resource('/v1/repository/<apirepopath:repository>/tag/')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
+@resource("/v1/repository//tag/")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
class ListRepositoryTags(RepositoryParamResource):
- """ Resource for listing full repository tag history, alive *and dead*. """
+ """ Resource for listing full repository tag history, alive *and dead*. """
- @require_repo_read
- @disallow_for_app_repositories
- @parse_args()
- @query_param('specificTag', 'Filters the tags to the specific tag.', type=str, default='')
- @query_param('limit', 'Limit to the number of results to return per page. Max 100.', type=int,
- default=50)
- @query_param('page', 'Page index for the results. Default 1.', type=int, default=1)
- @query_param('onlyActiveTags', 'Filter to only active tags.', type=truthy_bool, default=False)
- @nickname('listRepoTags')
- def get(self, namespace, repository, parsed_args):
- specific_tag = parsed_args.get('specificTag') or None
- page = max(1, parsed_args.get('page', 1))
- limit = min(100, max(1, parsed_args.get('limit', 50)))
- active_tags_only = parsed_args.get('onlyActiveTags')
+ @require_repo_read
+ @disallow_for_app_repositories
+ @parse_args()
+ @query_param(
+ "specificTag", "Filters the tags to the specific tag.", type=str, default=""
+ )
+ @query_param(
+ "limit",
+ "Limit to the number of results to return per page. Max 100.",
+ type=int,
+ default=50,
+ )
+ @query_param("page", "Page index for the results. Default 1.", type=int, default=1)
+ @query_param(
+ "onlyActiveTags", "Filter to only active tags.", type=truthy_bool, default=False
+ )
+ @nickname("listRepoTags")
+ def get(self, namespace, repository, parsed_args):
+ specific_tag = parsed_args.get("specificTag") or None
+ page = max(1, parsed_args.get("page", 1))
+ limit = min(100, max(1, parsed_args.get("limit", 50)))
+ active_tags_only = parsed_args.get("onlyActiveTags")
- repo_ref = registry_model.lookup_repository(namespace, repository)
- if repo_ref is None:
- raise NotFound()
+ repo_ref = registry_model.lookup_repository(namespace, repository)
+ if repo_ref is None:
+ raise NotFound()
- history, has_more = registry_model.list_repository_tag_history(repo_ref, page=page,
- size=limit,
- specific_tag_name=specific_tag,
- active_tags_only=active_tags_only)
- return {
- 'tags': [_tag_dict(tag) for tag in history],
- 'page': page,
- 'has_additional': has_more,
- }
+ history, has_more = registry_model.list_repository_tag_history(
+ repo_ref,
+ page=page,
+ size=limit,
+ specific_tag_name=specific_tag,
+ active_tags_only=active_tags_only,
+ )
+ return {
+ "tags": [_tag_dict(tag) for tag in history],
+ "page": page,
+ "has_additional": has_more,
+ }
-@resource('/v1/repository/<apirepopath:repository>/tag/<tag>')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
-@path_param('tag', 'The name of the tag')
+@resource("/v1/repository//tag/")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
+@path_param("tag", "The name of the tag")
class RepositoryTag(RepositoryParamResource):
- """ Resource for managing repository tags. """
- schemas = {
- 'ChangeTag': {
- 'type': 'object',
- 'description': 'Makes changes to a specific tag',
- 'properties': {
- 'image': {
- 'type': ['string', 'null'],
- 'description': '(Deprecated: Use `manifest_digest`) Image to which the tag should point.',
- },
- 'manifest_digest': {
- 'type': ['string', 'null'],
- 'description': '(If specified) The manifest digest to which the tag should point',
- },
- 'expiration': {
- 'type': ['number', 'null'],
- 'description': '(If specified) The expiration for the image',
- },
- },
- },
- }
+ """ Resource for managing repository tags. """
- @require_repo_write
- @disallow_for_app_repositories
- @disallow_for_non_normal_repositories
- @nickname('changeTag')
- @validate_json_request('ChangeTag')
- def put(self, namespace, repository, tag):
- """ Change which image a tag points to or create a new tag."""
- if not TAG_REGEX.match(tag):
- abort(400, TAG_ERROR)
+ schemas = {
+ "ChangeTag": {
+ "type": "object",
+ "description": "Makes changes to a specific tag",
+ "properties": {
+ "image": {
+ "type": ["string", "null"],
+ "description": "(Deprecated: Use `manifest_digest`) Image to which the tag should point.",
+ },
+ "manifest_digest": {
+ "type": ["string", "null"],
+ "description": "(If specified) The manifest digest to which the tag should point",
+ },
+ "expiration": {
+ "type": ["number", "null"],
+ "description": "(If specified) The expiration for the image",
+ },
+ },
+ }
+ }
- repo_ref = registry_model.lookup_repository(namespace, repository)
- if repo_ref is None:
- raise NotFound()
+ @require_repo_write
+ @disallow_for_app_repositories
+ @disallow_for_non_normal_repositories
+ @nickname("changeTag")
+ @validate_json_request("ChangeTag")
+ def put(self, namespace, repository, tag):
+ """ Change which image a tag points to or create a new tag."""
+ if not TAG_REGEX.match(tag):
+ abort(400, TAG_ERROR)
- if 'expiration' in request.get_json():
- tag_ref = registry_model.get_repo_tag(repo_ref, tag)
- if tag_ref is None:
- raise NotFound()
+ repo_ref = registry_model.lookup_repository(namespace, repository)
+ if repo_ref is None:
+ raise NotFound()
- expiration = request.get_json().get('expiration')
- expiration_date = None
- if expiration is not None:
- try:
- expiration_date = datetime.utcfromtimestamp(float(expiration))
- except ValueError:
- abort(400)
+ if "expiration" in request.get_json():
+ tag_ref = registry_model.get_repo_tag(repo_ref, tag)
+ if tag_ref is None:
+ raise NotFound()
- if expiration_date <= datetime.now():
- abort(400)
+ expiration = request.get_json().get("expiration")
+ expiration_date = None
+ if expiration is not None:
+ try:
+ expiration_date = datetime.utcfromtimestamp(float(expiration))
+ except ValueError:
+ abort(400)
- existing_end_ts, ok = registry_model.change_repository_tag_expiration(tag_ref,
- expiration_date)
- if ok:
- if not (existing_end_ts is None and expiration_date is None):
- log_action('change_tag_expiration', namespace, {
- 'username': get_authenticated_user().username,
- 'repo': repository,
- 'tag': tag,
- 'namespace': namespace,
- 'expiration_date': expiration_date,
- 'old_expiration_date': existing_end_ts
- }, repo_name=repository)
- else:
- raise InvalidRequest('Could not update tag expiration; Tag has probably changed')
+ if expiration_date <= datetime.now():
+ abort(400)
- if 'image' in request.get_json() or 'manifest_digest' in request.get_json():
- existing_tag = registry_model.get_repo_tag(repo_ref, tag, include_legacy_image=True)
+ existing_end_ts, ok = registry_model.change_repository_tag_expiration(
+ tag_ref, expiration_date
+ )
+ if ok:
+ if not (existing_end_ts is None and expiration_date is None):
+ log_action(
+ "change_tag_expiration",
+ namespace,
+ {
+ "username": get_authenticated_user().username,
+ "repo": repository,
+ "tag": tag,
+ "namespace": namespace,
+ "expiration_date": expiration_date,
+ "old_expiration_date": existing_end_ts,
+ },
+ repo_name=repository,
+ )
+ else:
+ raise InvalidRequest(
+ "Could not update tag expiration; Tag has probably changed"
+ )
- manifest_or_image = None
- image_id = None
- manifest_digest = None
+ if "image" in request.get_json() or "manifest_digest" in request.get_json():
+ existing_tag = registry_model.get_repo_tag(
+ repo_ref, tag, include_legacy_image=True
+ )
- if 'image' in request.get_json():
- image_id = request.get_json()['image']
- manifest_or_image = registry_model.get_legacy_image(repo_ref, image_id)
- else:
- manifest_digest = request.get_json()['manifest_digest']
- manifest_or_image = registry_model.lookup_manifest_by_digest(repo_ref, manifest_digest,
- require_available=True)
+ manifest_or_image = None
+ image_id = None
+ manifest_digest = None
- if manifest_or_image is None:
- raise NotFound()
+ if "image" in request.get_json():
+ image_id = request.get_json()["image"]
+ manifest_or_image = registry_model.get_legacy_image(repo_ref, image_id)
+ else:
+ manifest_digest = request.get_json()["manifest_digest"]
+ manifest_or_image = registry_model.lookup_manifest_by_digest(
+ repo_ref, manifest_digest, require_available=True
+ )
- # TODO: Remove this check once fully on V22
- existing_manifest_digest = None
- if existing_tag:
- existing_manifest = registry_model.get_manifest_for_tag(existing_tag)
- existing_manifest_digest = existing_manifest.digest if existing_manifest else None
+ if manifest_or_image is None:
+ raise NotFound()
- if not registry_model.retarget_tag(repo_ref, tag, manifest_or_image, storage,
- docker_v2_signing_key):
- raise InvalidRequest('Could not move tag')
+ # TODO: Remove this check once fully on V22
+ existing_manifest_digest = None
+ if existing_tag:
+ existing_manifest = registry_model.get_manifest_for_tag(existing_tag)
+ existing_manifest_digest = (
+ existing_manifest.digest if existing_manifest else None
+ )
- username = get_authenticated_user().username
+ if not registry_model.retarget_tag(
+ repo_ref, tag, manifest_or_image, storage, docker_v2_signing_key
+ ):
+ raise InvalidRequest("Could not move tag")
- log_action('move_tag' if existing_tag else 'create_tag', namespace, {
- 'username': username,
- 'repo': repository,
- 'tag': tag,
- 'namespace': namespace,
- 'image': image_id,
- 'manifest_digest': manifest_digest,
- 'original_image': (existing_tag.legacy_image.docker_image_id
- if existing_tag and existing_tag.legacy_image_if_present
- else None),
- 'original_manifest_digest': existing_manifest_digest,
- }, repo_name=repository)
+ username = get_authenticated_user().username
- return 'Updated', 201
+ log_action(
+ "move_tag" if existing_tag else "create_tag",
+ namespace,
+ {
+ "username": username,
+ "repo": repository,
+ "tag": tag,
+ "namespace": namespace,
+ "image": image_id,
+ "manifest_digest": manifest_digest,
+ "original_image": (
+ existing_tag.legacy_image.docker_image_id
+ if existing_tag and existing_tag.legacy_image_if_present
+ else None
+ ),
+ "original_manifest_digest": existing_manifest_digest,
+ },
+ repo_name=repository,
+ )
- @require_repo_write
- @disallow_for_app_repositories
- @disallow_for_non_normal_repositories
- @nickname('deleteFullTag')
- def delete(self, namespace, repository, tag):
- """ Delete the specified repository tag. """
- repo_ref = registry_model.lookup_repository(namespace, repository)
- if repo_ref is None:
- raise NotFound()
+ return "Updated", 201
- registry_model.delete_tag(repo_ref, tag)
+ @require_repo_write
+ @disallow_for_app_repositories
+ @disallow_for_non_normal_repositories
+ @nickname("deleteFullTag")
+ def delete(self, namespace, repository, tag):
+ """ Delete the specified repository tag. """
+ repo_ref = registry_model.lookup_repository(namespace, repository)
+ if repo_ref is None:
+ raise NotFound()
- username = get_authenticated_user().username
- log_action('delete_tag', namespace,
- {'username': username,
- 'repo': repository,
- 'namespace': namespace,
- 'tag': tag}, repo_name=repository)
+ registry_model.delete_tag(repo_ref, tag)
- return '', 204
+ username = get_authenticated_user().username
+ log_action(
+ "delete_tag",
+ namespace,
+ {
+ "username": username,
+ "repo": repository,
+ "namespace": namespace,
+ "tag": tag,
+ },
+ repo_name=repository,
+ )
+
+ return "", 204
-@resource('/v1/repository/<apirepopath:repository>/tag/<tag>/images')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
-@path_param('tag', 'The name of the tag')
+@resource("/v1/repository//tag//images")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
+@path_param("tag", "The name of the tag")
class RepositoryTagImages(RepositoryParamResource):
- """ Resource for listing the images in a specific repository tag. """
+ """ Resource for listing the images in a specific repository tag. """
- @require_repo_read
- @nickname('listTagImages')
- @disallow_for_app_repositories
- @parse_args()
- @query_param('owned', 'If specified, only images wholely owned by this tag are returned.',
- type=truthy_bool, default=False)
- def get(self, namespace, repository, tag, parsed_args):
- """ List the images for the specified repository tag. """
- repo_ref = registry_model.lookup_repository(namespace, repository)
- if repo_ref is None:
- raise NotFound()
+ @require_repo_read
+ @nickname("listTagImages")
+ @disallow_for_app_repositories
+ @parse_args()
+ @query_param(
+ "owned",
+ "If specified, only images wholely owned by this tag are returned.",
+ type=truthy_bool,
+ default=False,
+ )
+ def get(self, namespace, repository, tag, parsed_args):
+ """ List the images for the specified repository tag. """
+ repo_ref = registry_model.lookup_repository(namespace, repository)
+ if repo_ref is None:
+ raise NotFound()
- tag_ref = registry_model.get_repo_tag(repo_ref, tag, include_legacy_image=True)
- if tag_ref is None:
- raise NotFound()
+ tag_ref = registry_model.get_repo_tag(repo_ref, tag, include_legacy_image=True)
+ if tag_ref is None:
+ raise NotFound()
- if tag_ref.legacy_image_if_present is None:
- return {'images': []}
+ if tag_ref.legacy_image_if_present is None:
+ return {"images": []}
- image_id = tag_ref.legacy_image.docker_image_id
+ image_id = tag_ref.legacy_image.docker_image_id
- all_images = None
- if parsed_args['owned']:
- # TODO: Remove the `owned` image concept once we are fully on V2_2.
- all_images = registry_model.get_legacy_images_owned_by_tag(tag_ref)
- else:
- image_with_parents = registry_model.get_legacy_image(repo_ref, image_id, include_parents=True)
- if image_with_parents is None:
- raise NotFound()
+ all_images = None
+ if parsed_args["owned"]:
+ # TODO: Remove the `owned` image concept once we are fully on V2_2.
+ all_images = registry_model.get_legacy_images_owned_by_tag(tag_ref)
+ else:
+ image_with_parents = registry_model.get_legacy_image(
+ repo_ref, image_id, include_parents=True
+ )
+ if image_with_parents is None:
+ raise NotFound()
- all_images = [image_with_parents] + image_with_parents.parents
+ all_images = [image_with_parents] + image_with_parents.parents
- return {
- 'images': [image_dict(image) for image in all_images],
- }
+ return {"images": [image_dict(image) for image in all_images]}
-@resource('/v1/repository/<apirepopath:repository>/tag/<tag>/restore')
-@path_param('repository', 'The full path of the repository. e.g. namespace/name')
-@path_param('tag', 'The name of the tag')
+@resource("/v1/repository//tag//restore")
+@path_param("repository", "The full path of the repository. e.g. namespace/name")
+@path_param("tag", "The name of the tag")
class RestoreTag(RepositoryParamResource):
- """ Resource for restoring a repository tag back to a previous image. """
- schemas = {
- 'RestoreTag': {
- 'type': 'object',
- 'description': 'Restores a tag to a specific image',
- 'properties': {
- 'image': {
- 'type': 'string',
- 'description': '(Deprecated: use `manifest_digest`) Image to which the tag should point',
- },
- 'manifest_digest': {
- 'type': 'string',
- 'description': 'If specified, the manifest digest that should be used',
- },
- },
- },
- }
+ """ Resource for restoring a repository tag back to a previous image. """
- @require_repo_write
- @disallow_for_app_repositories
- @disallow_for_non_normal_repositories
- @nickname('restoreTag')
- @validate_json_request('RestoreTag')
- def post(self, namespace, repository, tag):
- """ Restores a repository tag back to a previous image in the repository. """
- repo_ref = registry_model.lookup_repository(namespace, repository)
- if repo_ref is None:
- raise NotFound()
-
- # Restore the tag back to the previous image.
- image_id = request.get_json().get('image', None)
- manifest_digest = request.get_json().get('manifest_digest', None)
-
- if image_id is None and manifest_digest is None:
- raise InvalidRequest('Missing manifest_digest')
-
- # Data for logging the reversion/restoration.
- username = get_authenticated_user().username
- log_data = {
- 'username': username,
- 'repo': repository,
- 'tag': tag,
- 'image': image_id,
- 'manifest_digest': manifest_digest,
+ schemas = {
+ "RestoreTag": {
+ "type": "object",
+ "description": "Restores a tag to a specific image",
+ "properties": {
+ "image": {
+ "type": "string",
+ "description": "(Deprecated: use `manifest_digest`) Image to which the tag should point",
+ },
+ "manifest_digest": {
+ "type": "string",
+ "description": "If specified, the manifest digest that should be used",
+ },
+ },
+ }
}
- manifest_or_legacy_image = None
- if manifest_digest is not None:
- manifest_or_legacy_image = registry_model.lookup_manifest_by_digest(repo_ref, manifest_digest,
- allow_dead=True,
- require_available=True)
- elif image_id is not None:
- manifest_or_legacy_image = registry_model.get_legacy_image(repo_ref, image_id)
+ @require_repo_write
+ @disallow_for_app_repositories
+ @disallow_for_non_normal_repositories
+ @nickname("restoreTag")
+ @validate_json_request("RestoreTag")
+ def post(self, namespace, repository, tag):
+ """ Restores a repository tag back to a previous image in the repository. """
+ repo_ref = registry_model.lookup_repository(namespace, repository)
+ if repo_ref is None:
+ raise NotFound()
- if manifest_or_legacy_image is None:
- raise NotFound()
+ # Restore the tag back to the previous image.
+ image_id = request.get_json().get("image", None)
+ manifest_digest = request.get_json().get("manifest_digest", None)
- if not registry_model.retarget_tag(repo_ref, tag, manifest_or_legacy_image, storage,
- docker_v2_signing_key, is_reversion=True):
- raise InvalidRequest('Could not restore tag')
+ if image_id is None and manifest_digest is None:
+ raise InvalidRequest("Missing manifest_digest")
- log_action('revert_tag', namespace, log_data, repo_name=repository)
+ # Data for logging the reversion/restoration.
+ username = get_authenticated_user().username
+ log_data = {
+ "username": username,
+ "repo": repository,
+ "tag": tag,
+ "image": image_id,
+ "manifest_digest": manifest_digest,
+ }
- return {}
+ manifest_or_legacy_image = None
+ if manifest_digest is not None:
+ manifest_or_legacy_image = registry_model.lookup_manifest_by_digest(
+ repo_ref, manifest_digest, allow_dead=True, require_available=True
+ )
+ elif image_id is not None:
+ manifest_or_legacy_image = registry_model.get_legacy_image(
+ repo_ref, image_id
+ )
+
+ if manifest_or_legacy_image is None:
+ raise NotFound()
+
+ if not registry_model.retarget_tag(
+ repo_ref,
+ tag,
+ manifest_or_legacy_image,
+ storage,
+ docker_v2_signing_key,
+ is_reversion=True,
+ ):
+ raise InvalidRequest("Could not restore tag")
+
+ log_action("revert_tag", namespace, log_data, repo_name=repository)
+
+ return {}
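
For reference, the reformatted changeTag endpoint above still accepts a JSON body with manifest_digest (or the deprecated image) and returns "Updated" with HTTP 201. A rough client-side sketch, assuming the API is mounted under /api as in a default Quay deployment; the host, repository, digest and token below are placeholders:

import requests

QUAY_HOST = "https://quay.example.com"  # placeholder host
AUTH = {"Authorization": "Bearer <oauth-token-with-repo:write-scope>"}  # placeholder token

# Point the tag "latest" of myorg/myrepo at a specific manifest digest;
# RepositoryTag.put above responds with "Updated" and HTTP 201 on success.
resp = requests.put(
    QUAY_HOST + "/api/v1/repository/myorg/myrepo/tag/latest",
    headers=AUTH,
    json={"manifest_digest": "sha256:0123abcd"},
)
print(resp.status_code, resp.text)
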
diff --git a/endpoints/api/team.py b/endpoints/api/team.py
index b00a14393..cda23dc96 100644
--- a/endpoints/api/team.py
+++ b/endpoints/api/team.py
@@ -9,526 +9,599 @@ from flask import request
import features
from app import avatar, authentication
-from auth.permissions import (AdministerOrganizationPermission, ViewTeamPermission,
- SuperUserPermission)
+from auth.permissions import (
+ AdministerOrganizationPermission,
+ ViewTeamPermission,
+ SuperUserPermission,
+)
from auth.auth_context import get_authenticated_user
from auth import scopes
from data import model
from data.database import Team
-from endpoints.api import (resource, nickname, ApiResource, validate_json_request, request_error,
- log_action, internal_only, require_scope, path_param, query_param,
- truthy_bool, parse_args, require_user_admin, show_if, format_date,
- verify_not_prod, require_fresh_login)
+from endpoints.api import (
+ resource,
+ nickname,
+ ApiResource,
+ validate_json_request,
+ request_error,
+ log_action,
+ internal_only,
+ require_scope,
+ path_param,
+ query_param,
+ truthy_bool,
+ parse_args,
+ require_user_admin,
+ show_if,
+ format_date,
+ verify_not_prod,
+ require_fresh_login,
+)
from endpoints.exception import Unauthorized, NotFound, InvalidRequest
from util.useremails import send_org_invite_email
from util.names import parse_robot_username
+
def permission_view(permission):
- return {
- 'repository': {
- 'name': permission.repository.name,
- 'is_public': permission.repository.visibility.name == 'public'
- },
- 'role': permission.role.name
- }
-
-def try_accept_invite(code, user):
- (team, inviter) = model.team.confirm_team_invite(code, user)
-
- model.notification.delete_matching_notifications(user, 'org_team_invite',
- org=team.organization.username)
-
- orgname = team.organization.username
- log_action('org_team_member_invite_accepted', orgname, {
- 'member': user.username,
- 'team': team.name,
- 'inviter': inviter.username
- })
-
- return team
-
-def handle_addinvite_team(inviter, team, user=None, email=None):
- requires_invite = features.MAILING and features.REQUIRE_TEAM_INVITE
- invite = model.team.add_or_invite_to_team(inviter, team, user, email,
- requires_invite=requires_invite)
- if not invite:
- # User was added to the team directly.
- return
-
- orgname = team.organization.username
- if user:
- model.notification.create_notification('org_team_invite', user, metadata={
- 'code': invite.invite_token,
- 'inviter': inviter.username,
- 'org': orgname,
- 'team': team.name
- })
-
- send_org_invite_email(user.username if user else email, user.email if user else email,
- orgname, team.name, inviter.username, invite.invite_token)
- return invite
-
-def team_view(orgname, team, is_new_team=False):
- view_permission = ViewTeamPermission(orgname, team.name)
- return {
- 'name': team.name,
- 'description': team.description,
- 'can_view': view_permission.can(),
- 'role': Team.role.get_name(team.role_id),
- 'avatar': avatar.get_data_for_team(team),
- 'new_team': is_new_team,
- }
-
-def member_view(member, invited=False):
- return {
- 'name': member.username,
- 'kind': 'user',
- 'is_robot': member.robot,
- 'avatar': avatar.get_data_for_user(member),
- 'invited': invited,
- }
-
-def invite_view(invite):
- if invite.user:
- return member_view(invite.user, invited=True)
- else:
return {
- 'email': invite.email,
- 'kind': 'invite',
- 'avatar': avatar.get_data(invite.email, invite.email, 'user'),
- 'invited': True
+ "repository": {
+ "name": permission.repository.name,
+ "is_public": permission.repository.visibility.name == "public",
+ },
+ "role": permission.role.name,
}
+
+def try_accept_invite(code, user):
+ (team, inviter) = model.team.confirm_team_invite(code, user)
+
+ model.notification.delete_matching_notifications(
+ user, "org_team_invite", org=team.organization.username
+ )
+
+ orgname = team.organization.username
+ log_action(
+ "org_team_member_invite_accepted",
+ orgname,
+ {"member": user.username, "team": team.name, "inviter": inviter.username},
+ )
+
+ return team
+
+
+def handle_addinvite_team(inviter, team, user=None, email=None):
+ requires_invite = features.MAILING and features.REQUIRE_TEAM_INVITE
+ invite = model.team.add_or_invite_to_team(
+ inviter, team, user, email, requires_invite=requires_invite
+ )
+ if not invite:
+ # User was added to the team directly.
+ return
+
+ orgname = team.organization.username
+ if user:
+ model.notification.create_notification(
+ "org_team_invite",
+ user,
+ metadata={
+ "code": invite.invite_token,
+ "inviter": inviter.username,
+ "org": orgname,
+ "team": team.name,
+ },
+ )
+
+ send_org_invite_email(
+ user.username if user else email,
+ user.email if user else email,
+ orgname,
+ team.name,
+ inviter.username,
+ invite.invite_token,
+ )
+ return invite
+
+
+def team_view(orgname, team, is_new_team=False):
+ view_permission = ViewTeamPermission(orgname, team.name)
+ return {
+ "name": team.name,
+ "description": team.description,
+ "can_view": view_permission.can(),
+ "role": Team.role.get_name(team.role_id),
+ "avatar": avatar.get_data_for_team(team),
+ "new_team": is_new_team,
+ }
+
+
+def member_view(member, invited=False):
+ return {
+ "name": member.username,
+ "kind": "user",
+ "is_robot": member.robot,
+ "avatar": avatar.get_data_for_user(member),
+ "invited": invited,
+ }
+
+
+def invite_view(invite):
+ if invite.user:
+ return member_view(invite.user, invited=True)
+ else:
+ return {
+ "email": invite.email,
+ "kind": "invite",
+ "avatar": avatar.get_data(invite.email, invite.email, "user"),
+ "invited": True,
+ }
+
+
def disallow_for_synced_team(except_robots=False):
- """ Disallows the decorated operation for a team that is marked as being synced from an internal
+ """ Disallows the decorated operation for a team that is marked as being synced from an internal
auth provider such as LDAP. If except_robots is True, then the operation is allowed if the
member specified on the operation is a robot account.
"""
- def inner(func):
- @wraps(func)
- def wrapper(self, *args, **kwargs):
- # Team syncing can only be enabled if we have a federated service.
- if features.TEAM_SYNCING and authentication.federated_service:
- orgname = kwargs['orgname']
- teamname = kwargs['teamname']
- if model.team.get_team_sync_information(orgname, teamname):
- if not except_robots or not parse_robot_username(kwargs.get('membername', '')):
- raise InvalidRequest('Cannot call this method on an auth-synced team')
- return func(self, *args, **kwargs)
- return wrapper
- return inner
+ def inner(func):
+ @wraps(func)
+ def wrapper(self, *args, **kwargs):
+ # Team syncing can only be enabled if we have a federated service.
+ if features.TEAM_SYNCING and authentication.federated_service:
+ orgname = kwargs["orgname"]
+ teamname = kwargs["teamname"]
+ if model.team.get_team_sync_information(orgname, teamname):
+ if not except_robots or not parse_robot_username(
+ kwargs.get("membername", "")
+ ):
+ raise InvalidRequest(
+ "Cannot call this method on an auth-synced team"
+ )
+
+ return func(self, *args, **kwargs)
+
+ return wrapper
+
+ return inner
disallow_nonrobots_for_synced_team = disallow_for_synced_team(except_robots=True)
disallow_all_for_synced_team = disallow_for_synced_team(except_robots=False)
-@resource('/v1/organization/<orgname>/team/<teamname>')
-@path_param('orgname', 'The name of the organization')
-@path_param('teamname', 'The name of the team')
+@resource("/v1/organization//team/")
+@path_param("orgname", "The name of the organization")
+@path_param("teamname", "The name of the team")
class OrganizationTeam(ApiResource):
- """ Resource for manging an organization's teams. """
- schemas = {
- 'TeamDescription': {
- 'type': 'object',
- 'description': 'Description of a team',
- 'required': [
- 'role',
- ],
- 'properties': {
- 'role': {
- 'type': 'string',
- 'description': 'Org wide permissions that should apply to the team',
- 'enum': [
- 'member',
- 'creator',
- 'admin',
- ],
- },
- 'description': {
- 'type': 'string',
- 'description': 'Markdown description for the team',
- },
- },
- },
- }
+ """ Resource for manging an organization's teams. """
- @require_scope(scopes.ORG_ADMIN)
- @nickname('updateOrganizationTeam')
- @validate_json_request('TeamDescription')
- def put(self, orgname, teamname):
- """ Update the org-wide permission for the specified team. """
- edit_permission = AdministerOrganizationPermission(orgname)
- if edit_permission.can():
- team = None
+ schemas = {
+ "TeamDescription": {
+ "type": "object",
+ "description": "Description of a team",
+ "required": ["role"],
+ "properties": {
+ "role": {
+ "type": "string",
+ "description": "Org wide permissions that should apply to the team",
+ "enum": ["member", "creator", "admin"],
+ },
+ "description": {
+ "type": "string",
+ "description": "Markdown description for the team",
+ },
+ },
+ }
+ }
- details = request.get_json()
- is_existing = False
- try:
- team = model.team.get_organization_team(orgname, teamname)
- is_existing = True
- except model.InvalidTeamException:
- # Create the new team.
- description = details['description'] if 'description' in details else ''
- role = details['role'] if 'role' in details else 'member'
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("updateOrganizationTeam")
+ @validate_json_request("TeamDescription")
+ def put(self, orgname, teamname):
+ """ Update the org-wide permission for the specified team. """
+ edit_permission = AdministerOrganizationPermission(orgname)
+ if edit_permission.can():
+ team = None
- org = model.organization.get_organization(orgname)
- team = model.team.create_team(teamname, org, role, description)
- log_action('org_create_team', orgname, {'team': teamname})
+ details = request.get_json()
+ is_existing = False
+ try:
+ team = model.team.get_organization_team(orgname, teamname)
+ is_existing = True
+ except model.InvalidTeamException:
+ # Create the new team.
+ description = details["description"] if "description" in details else ""
+ role = details["role"] if "role" in details else "member"
- if is_existing:
- if ('description' in details and
- team.description != details['description']):
- team.description = details['description']
- team.save()
- log_action('org_set_team_description', orgname,
- {'team': teamname, 'description': team.description})
+ org = model.organization.get_organization(orgname)
+ team = model.team.create_team(teamname, org, role, description)
+ log_action("org_create_team", orgname, {"team": teamname})
- if 'role' in details:
- role = Team.role.get_name(team.role_id)
- if role != details['role']:
- team = model.team.set_team_org_permission(team, details['role'],
- get_authenticated_user().username)
- log_action('org_set_team_role', orgname, {'team': teamname, 'role': details['role']})
+ if is_existing:
+ if (
+ "description" in details
+ and team.description != details["description"]
+ ):
+ team.description = details["description"]
+ team.save()
+ log_action(
+ "org_set_team_description",
+ orgname,
+ {"team": teamname, "description": team.description},
+ )
- return team_view(orgname, team, is_new_team=not is_existing), 200
+ if "role" in details:
+ role = Team.role.get_name(team.role_id)
+ if role != details["role"]:
+ team = model.team.set_team_org_permission(
+ team, details["role"], get_authenticated_user().username
+ )
+ log_action(
+ "org_set_team_role",
+ orgname,
+ {"team": teamname, "role": details["role"]},
+ )
- raise Unauthorized()
+ return team_view(orgname, team, is_new_team=not is_existing), 200
- @require_scope(scopes.ORG_ADMIN)
- @nickname('deleteOrganizationTeam')
- def delete(self, orgname, teamname):
- """ Delete the specified team. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- model.team.remove_team(orgname, teamname, get_authenticated_user().username)
- log_action('org_delete_team', orgname, {'team': teamname})
- return '', 204
+ raise Unauthorized()
- raise Unauthorized()
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("deleteOrganizationTeam")
+ def delete(self, orgname, teamname):
+ """ Delete the specified team. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ model.team.remove_team(orgname, teamname, get_authenticated_user().username)
+ log_action("org_delete_team", orgname, {"team": teamname})
+ return "", 204
+
+ raise Unauthorized()
def _syncing_setup_allowed(orgname):
- """ Returns whether syncing setup is allowed for the current user over the matching org. """
- if not features.NONSUPERUSER_TEAM_SYNCING_SETUP and not SuperUserPermission().can():
- return False
+ """ Returns whether syncing setup is allowed for the current user over the matching org. """
+ if not features.NONSUPERUSER_TEAM_SYNCING_SETUP and not SuperUserPermission().can():
+ return False
- return AdministerOrganizationPermission(orgname).can()
+ return AdministerOrganizationPermission(orgname).can()
-@resource('/v1/organization/<orgname>/team/<teamname>/syncing')
-@path_param('orgname', 'The name of the organization')
-@path_param('teamname', 'The name of the team')
+@resource("/v1/organization//team//syncing")
+@path_param("orgname", "The name of the organization")
+@path_param("teamname", "The name of the team")
@show_if(features.TEAM_SYNCING)
class OrganizationTeamSyncing(ApiResource):
- """ Resource for managing syncing of a team by a backing group. """
- @require_scope(scopes.ORG_ADMIN)
- @require_scope(scopes.SUPERUSER)
- @nickname('enableOrganizationTeamSync')
- @verify_not_prod
- @require_fresh_login
- def post(self, orgname, teamname):
- if _syncing_setup_allowed(orgname):
- try:
- team = model.team.get_organization_team(orgname, teamname)
- except model.InvalidTeamException:
- raise NotFound()
+ """ Resource for managing syncing of a team by a backing group. """
- config = request.get_json()
-
- # Ensure that the specified config points to a valid group.
- status, err = authentication.check_group_lookup_args(config)
- if not status:
- raise InvalidRequest('Could not sync to group: %s' % err)
-
- # Set the team's syncing config.
- model.team.set_team_syncing(team, authentication.federated_service, config)
-
- return team_view(orgname, team)
-
- raise Unauthorized()
-
- @require_scope(scopes.ORG_ADMIN)
- @require_scope(scopes.SUPERUSER)
- @nickname('disableOrganizationTeamSync')
- @verify_not_prod
- @require_fresh_login
- def delete(self, orgname, teamname):
- if _syncing_setup_allowed(orgname):
- try:
- team = model.team.get_organization_team(orgname, teamname)
- except model.InvalidTeamException:
- raise NotFound()
-
- model.team.remove_team_syncing(orgname, teamname)
- return team_view(orgname, team)
-
- raise Unauthorized()
-
-
-@resource('/v1/organization/<orgname>/team/<teamname>/members')
-@path_param('orgname', 'The name of the organization')
-@path_param('teamname', 'The name of the team')
-class TeamMemberList(ApiResource):
- """ Resource for managing the list of members for a team. """
- @require_scope(scopes.ORG_ADMIN)
- @parse_args()
- @query_param('includePending', 'Whether to include pending members', type=truthy_bool,
- default=False)
- @nickname('getOrganizationTeamMembers')
- def get(self, orgname, teamname, parsed_args):
- """ Retrieve the list of members for the specified team. """
- view_permission = ViewTeamPermission(orgname, teamname)
- edit_permission = AdministerOrganizationPermission(orgname)
-
- if view_permission.can():
- team = None
- try:
- team = model.team.get_organization_team(orgname, teamname)
- except model.InvalidTeamException:
- raise NotFound()
-
- members = model.organization.get_organization_team_members(team.id)
- invites = []
-
- if parsed_args['includePending'] and edit_permission.can():
- invites = model.team.get_organization_team_member_invites(team.id)
-
- data = {
- 'name': teamname,
- 'members': [member_view(m) for m in members] + [invite_view(i) for i in invites],
- 'can_edit': edit_permission.can(),
- }
-
- if features.TEAM_SYNCING and authentication.federated_service:
+ @require_scope(scopes.ORG_ADMIN)
+ @require_scope(scopes.SUPERUSER)
+ @nickname("enableOrganizationTeamSync")
+ @verify_not_prod
+ @require_fresh_login
+ def post(self, orgname, teamname):
if _syncing_setup_allowed(orgname):
- data['can_sync'] = {
- 'service': authentication.federated_service,
- }
+ try:
+ team = model.team.get_organization_team(orgname, teamname)
+ except model.InvalidTeamException:
+ raise NotFound()
- data['can_sync'].update(authentication.service_metadata())
+ config = request.get_json()
- sync_info = model.team.get_team_sync_information(orgname, teamname)
- if sync_info is not None:
- data['synced'] = {
- 'service': sync_info.service.name,
- }
+ # Ensure that the specified config points to a valid group.
+ status, err = authentication.check_group_lookup_args(config)
+ if not status:
+ raise InvalidRequest("Could not sync to group: %s" % err)
- if SuperUserPermission().can():
- data['synced'].update({
- 'last_updated': format_date(sync_info.last_updated),
- 'config': json.loads(sync_info.config),
- })
+ # Set the team's syncing config.
+ model.team.set_team_syncing(team, authentication.federated_service, config)
- return data
+ return team_view(orgname, team)
- raise Unauthorized()
+ raise Unauthorized()
+
+ @require_scope(scopes.ORG_ADMIN)
+ @require_scope(scopes.SUPERUSER)
+ @nickname("disableOrganizationTeamSync")
+ @verify_not_prod
+ @require_fresh_login
+ def delete(self, orgname, teamname):
+ if _syncing_setup_allowed(orgname):
+ try:
+ team = model.team.get_organization_team(orgname, teamname)
+ except model.InvalidTeamException:
+ raise NotFound()
+
+ model.team.remove_team_syncing(orgname, teamname)
+ return team_view(orgname, team)
+
+ raise Unauthorized()
-@resource('/v1/organization/<orgname>/team/<teamname>/members/<membername>')
-@path_param('orgname', 'The name of the organization')
-@path_param('teamname', 'The name of the team')
-@path_param('membername', 'The username of the team member')
+@resource("/v1/organization//team//members")
+@path_param("orgname", "The name of the organization")
+@path_param("teamname", "The name of the team")
+class TeamMemberList(ApiResource):
+ """ Resource for managing the list of members for a team. """
+
+ @require_scope(scopes.ORG_ADMIN)
+ @parse_args()
+ @query_param(
+ "includePending",
+ "Whether to include pending members",
+ type=truthy_bool,
+ default=False,
+ )
+ @nickname("getOrganizationTeamMembers")
+ def get(self, orgname, teamname, parsed_args):
+ """ Retrieve the list of members for the specified team. """
+ view_permission = ViewTeamPermission(orgname, teamname)
+ edit_permission = AdministerOrganizationPermission(orgname)
+
+ if view_permission.can():
+ team = None
+ try:
+ team = model.team.get_organization_team(orgname, teamname)
+ except model.InvalidTeamException:
+ raise NotFound()
+
+ members = model.organization.get_organization_team_members(team.id)
+ invites = []
+
+ if parsed_args["includePending"] and edit_permission.can():
+ invites = model.team.get_organization_team_member_invites(team.id)
+
+ data = {
+ "name": teamname,
+ "members": [member_view(m) for m in members]
+ + [invite_view(i) for i in invites],
+ "can_edit": edit_permission.can(),
+ }
+
+ if features.TEAM_SYNCING and authentication.federated_service:
+ if _syncing_setup_allowed(orgname):
+ data["can_sync"] = {"service": authentication.federated_service}
+
+ data["can_sync"].update(authentication.service_metadata())
+
+ sync_info = model.team.get_team_sync_information(orgname, teamname)
+ if sync_info is not None:
+ data["synced"] = {"service": sync_info.service.name}
+
+ if SuperUserPermission().can():
+ data["synced"].update(
+ {
+ "last_updated": format_date(sync_info.last_updated),
+ "config": json.loads(sync_info.config),
+ }
+ )
+
+ return data
+
+ raise Unauthorized()
+
+
+@resource("/v1/organization//team//members/")
+@path_param("orgname", "The name of the organization")
+@path_param("teamname", "The name of the team")
+@path_param("membername", "The username of the team member")
class TeamMember(ApiResource):
- """ Resource for managing individual members of a team. """
+ """ Resource for managing individual members of a team. """
- @require_scope(scopes.ORG_ADMIN)
- @nickname('updateOrganizationTeamMember')
- @disallow_nonrobots_for_synced_team
- def put(self, orgname, teamname, membername):
- """ Adds or invites a member to an existing team. """
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- team = None
- user = None
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("updateOrganizationTeamMember")
+ @disallow_nonrobots_for_synced_team
+ def put(self, orgname, teamname, membername):
+ """ Adds or invites a member to an existing team. """
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ team = None
+ user = None
- # Find the team.
- try:
- team = model.team.get_organization_team(orgname, teamname)
- except model.InvalidTeamException:
- raise NotFound()
+ # Find the team.
+ try:
+ team = model.team.get_organization_team(orgname, teamname)
+ except model.InvalidTeamException:
+ raise NotFound()
- # Find the user.
- user = model.user.get_user(membername)
- if not user:
- raise request_error(message='Unknown user')
+ # Find the user.
+ user = model.user.get_user(membername)
+ if not user:
+ raise request_error(message="Unknown user")
- # Add or invite the user to the team.
- inviter = get_authenticated_user()
- invite = handle_addinvite_team(inviter, team, user=user)
- if not invite:
- log_action('org_add_team_member', orgname, {'member': membername, 'team': teamname})
- return member_view(user, invited=False)
+ # Add or invite the user to the team.
+ inviter = get_authenticated_user()
+ invite = handle_addinvite_team(inviter, team, user=user)
+ if not invite:
+ log_action(
+ "org_add_team_member",
+ orgname,
+ {"member": membername, "team": teamname},
+ )
+ return member_view(user, invited=False)
- # User was invited.
- log_action('org_invite_team_member', orgname, {
- 'user': membername,
- 'member': membername,
- 'team': teamname
- })
- return member_view(user, invited=True)
+ # User was invited.
+ log_action(
+ "org_invite_team_member",
+ orgname,
+ {"user": membername, "member": membername, "team": teamname},
+ )
+ return member_view(user, invited=True)
- raise Unauthorized()
+ raise Unauthorized()
- @require_scope(scopes.ORG_ADMIN)
- @nickname('deleteOrganizationTeamMember')
- @disallow_nonrobots_for_synced_team
- def delete(self, orgname, teamname, membername):
- """ Delete a member of a team. If the user is merely invited to join
+ @require_scope(scopes.ORG_ADMIN)
+ @nickname("deleteOrganizationTeamMember")
+ @disallow_nonrobots_for_synced_team
+ def delete(self, orgname, teamname, membername):
+ """ Delete a member of a team. If the user is merely invited to join
the team, then the invite is removed instead.
"""
- permission = AdministerOrganizationPermission(orgname)
- if permission.can():
- # Remote the user from the team.
- invoking_user = get_authenticated_user().username
+ permission = AdministerOrganizationPermission(orgname)
+ if permission.can():
+ # Remove the user from the team.
+ invoking_user = get_authenticated_user().username
- # Find the team.
- try:
- team = model.team.get_organization_team(orgname, teamname)
- except model.InvalidTeamException:
- raise NotFound()
+ # Find the team.
+ try:
+ team = model.team.get_organization_team(orgname, teamname)
+ except model.InvalidTeamException:
+ raise NotFound()
- # Find the member.
- member = model.user.get_user(membername)
- if not member:
- raise NotFound()
+ # Find the member.
+ member = model.user.get_user(membername)
+ if not member:
+ raise NotFound()
- # First attempt to delete an invite for the user to this team. If none found,
- # then we try to remove the user directly.
- if model.team.delete_team_user_invite(team, member):
- log_action('org_delete_team_member_invite', orgname, {
- 'user': membername,
- 'team': teamname,
- 'member': membername
- })
- return '', 204
+ # First attempt to delete an invite for the user to this team. If none found,
+ # then we try to remove the user directly.
+ if model.team.delete_team_user_invite(team, member):
+ log_action(
+ "org_delete_team_member_invite",
+ orgname,
+ {"user": membername, "team": teamname, "member": membername},
+ )
+ return "", 204
- model.team.remove_user_from_team(orgname, teamname, membername, invoking_user)
- log_action('org_remove_team_member', orgname, {'member': membername, 'team': teamname})
- return '', 204
+ model.team.remove_user_from_team(
+ orgname, teamname, membername, invoking_user
+ )
+ log_action(
+ "org_remove_team_member",
+ orgname,
+ {"member": membername, "team": teamname},
+ )
+ return "", 204
- raise Unauthorized()
+ raise Unauthorized()
-@resource('/v1/organization//team/