initial import for Open Source 🎉

This commit is contained in:
Jimmy Zelinskie 2019-11-12 11:09:47 -05:00
parent 1898c361f3
commit 9c0dd3b722
2048 changed files with 218743 additions and 0 deletions

153
data/model/__init__.py Normal file
View file

@@ -0,0 +1,153 @@
from data.database import db, db_transaction
class DataModelException(Exception):
pass
class InvalidLabelKeyException(DataModelException):
pass
class InvalidMediaTypeException(DataModelException):
pass
class BlobDoesNotExist(DataModelException):
pass
class TorrentInfoDoesNotExist(DataModelException):
pass
class InvalidBlobUpload(DataModelException):
pass
class InvalidEmailAddressException(DataModelException):
pass
class InvalidOrganizationException(DataModelException):
pass
class InvalidPasswordException(DataModelException):
pass
class InvalidRobotException(DataModelException):
pass
class InvalidUsernameException(DataModelException):
pass
class InvalidRepositoryBuildException(DataModelException):
pass
class InvalidBuildTriggerException(DataModelException):
pass
class InvalidTokenException(DataModelException):
pass
class InvalidNotificationException(DataModelException):
pass
class InvalidImageException(DataModelException):
pass
class UserAlreadyInTeam(DataModelException):
pass
class InvalidTeamException(DataModelException):
pass
class InvalidTeamMemberException(DataModelException):
pass
class InvalidManifestException(DataModelException):
pass
class ServiceKeyDoesNotExist(DataModelException):
pass
class ServiceKeyAlreadyApproved(DataModelException):
pass
class ServiceNameInvalid(DataModelException):
pass
class TagAlreadyCreatedException(DataModelException):
pass
class StaleTagException(DataModelException):
pass
class TooManyLoginAttemptsException(Exception):
def __init__(self, message, retry_after):
super(TooManyLoginAttemptsException, self).__init__(message)
self.retry_after = retry_after
class Config(object):
def __init__(self):
self.app_config = None
self.store = None
self.image_cleanup_callbacks = []
self.repo_cleanup_callbacks = []
def register_image_cleanup_callback(self, callback):
self.image_cleanup_callbacks.append(callback)
def register_repo_cleanup_callback(self, callback):
self.repo_cleanup_callbacks.append(callback)
config = Config()
# There MUST NOT be any circular dependencies between these subsections. If there are, fix them by
# moving the minimal number of things to _basequery
from data.model import (
appspecifictoken,
blob,
build,
gc,
image,
label,
log,
message,
modelutil,
notification,
oauth,
organization,
permission,
repositoryactioncount,
repo_mirror,
release,
repository,
service_keys,
storage,
tag,
team,
token,
user,
)
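A brief usage sketch (the callback and app-config values are hypothetical, not part of this commit) showing how the module-level config object above is wired up by the application and how registered repository-cleanup callbacks are later invoked:

def log_repo_removal(namespace_name, repository_name):
    # Hypothetical callback: record that a repository was fully purged.
    print('repository %s/%s purged' % (namespace_name, repository_name))

config.app_config = {'FEATURE_USER_LAST_ACCESSED': False}
config.register_repo_cleanup_callback(log_repo_removal)

# gc.purge_repository(...) walks config.repo_cleanup_callbacks once a purge succeeds:
for callback in config.repo_cleanup_callbacks:
    callback('acme', 'webapp')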

198
data/model/_basequery.py Normal file
View file

@@ -0,0 +1,198 @@
import logging
from peewee import fn, PeeweeException
from cachetools.func import lru_cache
from datetime import datetime, timedelta
from data.model import DataModelException, config
from data.readreplica import ReadOnlyModeException
from data.database import (Repository, User, Team, TeamMember, RepositoryPermission, TeamRole,
Namespace, Visibility, ImageStorage, Image, RepositoryKind,
db_for_update)
logger = logging.getLogger(__name__)
def reduce_as_tree(queries_to_reduce):
""" This method will split a list of queries into halves recursively until we reach individual
queries, at which point it will start unioning the queries, or the already unioned subqueries.
This works around a bug in peewee SQL generation where reducing linearly generates a chain
of queries that will exceed the recursion depth limit when it has around 80 queries.
"""
mid = len(queries_to_reduce)/2
left = queries_to_reduce[:mid]
right = queries_to_reduce[mid:]
to_reduce_right = right[0]
if len(right) > 1:
to_reduce_right = reduce_as_tree(right)
if len(left) > 1:
to_reduce_left = reduce_as_tree(left)
elif len(left) == 1:
to_reduce_left = left[0]
else:
return to_reduce_right
return to_reduce_left.union_all(to_reduce_right)
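# Illustrative sketch, not part of this module: the same halving strategy applied to plain
# lists, showing why the combination tree stays roughly log2(n) deep instead of n deep
# (the peewee recursion issue described in the docstring above).
def _reduce_as_tree_demo(items, combine):
    if len(items) == 1:
        return items[0]
    mid = len(items) // 2
    return combine(_reduce_as_tree_demo(items[:mid], combine),
                   _reduce_as_tree_demo(items[mid:], combine))

# Combining 80 single-element lists linearly would nest 79 levels; the tree form nests about 7.
assert _reduce_as_tree_demo([[i] for i in range(80)], lambda l, r: l + r) == list(range(80))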
def get_existing_repository(namespace_name, repository_name, for_update=False, kind_filter=None):
query = (Repository
.select(Repository, Namespace)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.where(Namespace.username == namespace_name,
Repository.name == repository_name))
if kind_filter:
query = (query
.switch(Repository)
.join(RepositoryKind)
.where(RepositoryKind.name == kind_filter))
if for_update:
query = db_for_update(query)
return query.get()
@lru_cache(maxsize=1)
def get_public_repo_visibility():
return Visibility.get(name='public')
def _lookup_team_role(name):
return _lookup_team_roles()[name]
@lru_cache(maxsize=1)
def _lookup_team_roles():
return {role.name:role for role in TeamRole.select()}
def filter_to_repos_for_user(query, user_id=None, namespace=None, repo_kind='image',
include_public=True, start_id=None):
if not include_public and not user_id:
return Repository.select().where(Repository.id == '-1')
# Filter on the type of repository.
if repo_kind is not None:
try:
query = query.where(Repository.kind == Repository.kind.get_id(repo_kind))
except RepositoryKind.DoesNotExist:
raise DataModelException('Unknown repository kind')
# Add the start ID if necessary.
if start_id is not None:
query = query.where(Repository.id >= start_id)
# Add a namespace filter if necessary.
if namespace:
query = query.where(Namespace.username == namespace)
# Build a set of queries that, when unioned together, return the full set of visible repositories
# for the filters specified.
queries = []
if include_public:
queries.append(query.where(Repository.visibility == get_public_repo_visibility()))
if user_id is not None:
AdminTeam = Team.alias()
AdminTeamMember = TeamMember.alias()
# Add repositories in which the user has permission.
queries.append(query
.switch(RepositoryPermission)
.where(RepositoryPermission.user == user_id))
# Add repositories in which the user is a member of a team that has permission.
queries.append(query
.switch(RepositoryPermission)
.join(Team)
.join(TeamMember)
.where(TeamMember.user == user_id))
# Add repositories under namespaces in which the user is the org admin.
queries.append(query
.switch(Repository)
.join(AdminTeam, on=(Repository.namespace_user == AdminTeam.organization))
.join(AdminTeamMember, on=(AdminTeam.id == AdminTeamMember.team))
.where(AdminTeam.role == _lookup_team_role('admin'))
.where(AdminTeamMember.user == user_id))
return reduce(lambda l, r: l | r, queries)
def get_user_organizations(username):
UserAlias = User.alias()
return (User
.select()
.distinct()
.join(Team)
.join(TeamMember)
.join(UserAlias, on=(UserAlias.id == TeamMember.user))
.where(User.organization == True, UserAlias.username == username))
def calculate_image_aggregate_size(ancestors_str, image_size, parent_image):
ancestors = ancestors_str.split('/')[1:-1]
if not ancestors:
return image_size
if parent_image is None:
raise DataModelException('Could not load parent image')
ancestor_size = parent_image.aggregate_size
if ancestor_size is not None:
return ancestor_size + image_size
# Fallback to a slower path if the parent doesn't have an aggregate size saved.
# TODO: remove this code if/when we do a full backfill.
ancestor_size = (ImageStorage
.select(fn.Sum(ImageStorage.image_size))
.join(Image)
.where(Image.id << ancestors)
.scalar())
if ancestor_size is None:
return None
return ancestor_size + image_size
def update_last_accessed(token_or_user):
""" Updates the `last_accessed` field on the given token or user. If the existing field's value
is within the configured threshold, the update is skipped. """
if not config.app_config.get('FEATURE_USER_LAST_ACCESSED'):
return
threshold = timedelta(seconds=config.app_config.get('LAST_ACCESSED_UPDATE_THRESHOLD_S', 120))
if (token_or_user.last_accessed is not None and
datetime.utcnow() - token_or_user.last_accessed < threshold):
# Skip updating, as we don't want to put undue pressure on the database.
return
model_class = token_or_user.__class__
last_accessed = datetime.utcnow()
try:
(model_class
.update(last_accessed=last_accessed)
.where(model_class.id == token_or_user.id)
.execute())
token_or_user.last_accessed = last_accessed
except ReadOnlyModeException:
pass
except PeeweeException as ex:
# If there is any form of DB exception, only fail if strict logging is enabled.
strict_logging_disabled = config.app_config.get('ALLOW_PULLS_WITHOUT_STRICT_LOGGING')
if strict_logging_disabled:
data = {
'exception': ex,
'token_or_user': token_or_user.id,
'class': str(model_class),
}
logger.exception('update last_accessed for token/user failed', extra=data)
else:
raise
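A minimal sketch, using plain datetimes and no database, of the throttling rule that update_last_accessed applies before touching the row; the 120-second default mirrors the LAST_ACCESSED_UPDATE_THRESHOLD_S fallback above:

from datetime import datetime, timedelta

def needs_last_accessed_update(last_accessed, threshold_s=120):
    # True when the stored timestamp is missing or at least threshold_s seconds old.
    if last_accessed is None:
        return True
    return datetime.utcnow() - last_accessed >= timedelta(seconds=threshold_s)

assert needs_last_accessed_update(None)
assert not needs_last_accessed_update(datetime.utcnow())
assert needs_last_accessed_update(datetime.utcnow() - timedelta(minutes=5))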

172
data/model/appspecifictoken.py Normal file
View file

@@ -0,0 +1,172 @@
import logging
from datetime import datetime
from active_migration import ActiveDataMigration, ERTMigrationFlags
from data.database import AppSpecificAuthToken, User, random_string_generator
from data.model import config
from data.model._basequery import update_last_accessed
from data.fields import DecryptedValue
from util.timedeltastring import convert_to_timedelta
from util.unicode import remove_unicode
logger = logging.getLogger(__name__)
TOKEN_NAME_PREFIX_LENGTH = 60
MINIMUM_TOKEN_SUFFIX_LENGTH = 60
def _default_expiration_duration():
expiration_str = config.app_config.get('APP_SPECIFIC_TOKEN_EXPIRATION')
return convert_to_timedelta(expiration_str) if expiration_str else None
# Define a "unique" value so that callers can specify an expiration of None and *not* have it
# use the default.
_default_expiration_duration_opt = '__deo'
def create_token(user, title, expiration=_default_expiration_duration_opt):
""" Creates and returns an app specific token for the given user. If no expiration is specified
(including `None`), then the default from config is used. """
if expiration == _default_expiration_duration_opt:
duration = _default_expiration_duration()
expiration = duration + datetime.now() if duration else None
token_code = random_string_generator(TOKEN_NAME_PREFIX_LENGTH + MINIMUM_TOKEN_SUFFIX_LENGTH)()
token_name = token_code[:TOKEN_NAME_PREFIX_LENGTH]
token_secret = token_code[TOKEN_NAME_PREFIX_LENGTH:]
assert token_name
assert token_secret
# TODO(remove-unenc): Remove legacy handling.
old_token_code = (token_code
if ActiveDataMigration.has_flag(ERTMigrationFlags.WRITE_OLD_FIELDS)
else None)
return AppSpecificAuthToken.create(user=user,
title=title,
expiration=expiration,
token_name=token_name,
token_secret=DecryptedValue(token_secret),
token_code=old_token_code)
def list_tokens(user):
""" Lists all tokens for the given user. """
return AppSpecificAuthToken.select().where(AppSpecificAuthToken.user == user)
def revoke_token(token):
""" Revokes an app specific token by deleting it. """
token.delete_instance()
def revoke_token_by_uuid(uuid, owner):
""" Revokes an app specific token by deleting it. """
try:
token = AppSpecificAuthToken.get(uuid=uuid, user=owner)
except AppSpecificAuthToken.DoesNotExist:
return None
revoke_token(token)
return token
def get_expiring_tokens(user, soon):
""" Returns all tokens owned by the given user that will be expiring "soon", where soon is defined
by the soon parameter (a timedelta from now).
"""
soon_datetime = datetime.now() + soon
return (AppSpecificAuthToken
.select()
.where(AppSpecificAuthToken.user == user,
AppSpecificAuthToken.expiration <= soon_datetime,
AppSpecificAuthToken.expiration > datetime.now()))
def gc_expired_tokens(expiration_window):
""" Deletes all expired tokens outside of the expiration window. """
(AppSpecificAuthToken
.delete()
.where(AppSpecificAuthToken.expiration < (datetime.now() - expiration_window))
.execute())
def get_token_by_uuid(uuid, owner=None):
""" Looks up an unexpired app specific token with the given uuid. Returns it if found or
None if none. If owner is specified, only tokens owned by the owner user will be
returned.
"""
try:
query = (AppSpecificAuthToken
.select()
.where(AppSpecificAuthToken.uuid == uuid,
((AppSpecificAuthToken.expiration > datetime.now()) |
(AppSpecificAuthToken.expiration >> None))))
if owner is not None:
query = query.where(AppSpecificAuthToken.user == owner)
return query.get()
except AppSpecificAuthToken.DoesNotExist:
return None
def access_valid_token(token_code):
""" Looks up an unexpired app specific token with the given token code. If found, the token's
last_accessed field is set to now and the token is returned. If not found, returns None.
"""
token_code = remove_unicode(token_code)
prefix = token_code[:TOKEN_NAME_PREFIX_LENGTH]
if len(prefix) != TOKEN_NAME_PREFIX_LENGTH:
return None
suffix = token_code[TOKEN_NAME_PREFIX_LENGTH:]
# Lookup the token by its prefix.
try:
token = (AppSpecificAuthToken
.select(AppSpecificAuthToken, User)
.join(User)
.where(AppSpecificAuthToken.token_name == prefix,
((AppSpecificAuthToken.expiration > datetime.now()) |
(AppSpecificAuthToken.expiration >> None)))
.get())
if not token.token_secret.matches(suffix):
return None
assert len(prefix) == TOKEN_NAME_PREFIX_LENGTH
assert len(suffix) >= MINIMUM_TOKEN_SUFFIX_LENGTH
update_last_accessed(token)
return token
except AppSpecificAuthToken.DoesNotExist:
pass
# TODO(remove-unenc): Remove legacy handling.
if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
try:
token = (AppSpecificAuthToken
.select(AppSpecificAuthToken, User)
.join(User)
.where(AppSpecificAuthToken.token_code == token_code,
((AppSpecificAuthToken.expiration > datetime.now()) |
(AppSpecificAuthToken.expiration >> None)))
.get())
update_last_accessed(token)
return token
except AppSpecificAuthToken.DoesNotExist:
return None
return None
def get_full_token_string(token):
# TODO(remove-unenc): Remove legacy handling.
if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
if not token.token_name:
return token.token_code
assert token.token_name
return '%s%s' % (token.token_name, token.token_secret.decrypt())
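A standalone sketch of the prefix/suffix convention that create_token and access_valid_token rely on, using the same 60/60 lengths defined above; the 120-character code below is a made-up placeholder:

TOKEN_NAME_PREFIX_LENGTH = 60
MINIMUM_TOKEN_SUFFIX_LENGTH = 60

def split_token_code(token_code):
    # Split a full token string into its cleartext lookup prefix and its secret suffix.
    prefix = token_code[:TOKEN_NAME_PREFIX_LENGTH]
    suffix = token_code[TOKEN_NAME_PREFIX_LENGTH:]
    if len(prefix) != TOKEN_NAME_PREFIX_LENGTH or len(suffix) < MINIMUM_TOKEN_SUFFIX_LENGTH:
        return None
    return prefix, suffix

# The prefix (token_name) is stored in clear for lookup; the suffix (token_secret) is
# stored encrypted and verified with matches().
assert split_token_code('a' * 120) == ('a' * 60, 'a' * 60)
assert split_token_code('too-short') is None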

237
data/model/blob.py Normal file
View file

@@ -0,0 +1,237 @@
import logging
from datetime import datetime
from uuid import uuid4
from data.model import (tag, _basequery, BlobDoesNotExist, InvalidBlobUpload, db_transaction,
storage as storage_model, InvalidImageException)
from data.database import (Repository, Namespace, ImageStorage, Image, ImageStoragePlacement,
BlobUpload, ImageStorageLocation, db_random_func)
logger = logging.getLogger(__name__)
def get_repository_blob_by_digest(repository, blob_digest):
""" Find the content-addressable blob linked to the specified repository.
"""
assert blob_digest
try:
storage = (ImageStorage
.select(ImageStorage.uuid)
.join(Image)
.where(Image.repository == repository,
ImageStorage.content_checksum == blob_digest,
ImageStorage.uploading == False)
.get())
return storage_model.get_storage_by_uuid(storage.uuid)
except (ImageStorage.DoesNotExist, InvalidImageException):
raise BlobDoesNotExist('Blob does not exist with digest: {0}'.format(blob_digest))
def get_repo_blob_by_digest(namespace, repo_name, blob_digest):
""" Find the content-addressable blob linked to the specified repository.
"""
assert blob_digest
try:
storage = (ImageStorage
.select(ImageStorage.uuid)
.join(Image)
.join(Repository)
.join(Namespace, on=(Namespace.id == Repository.namespace_user))
.where(Repository.name == repo_name, Namespace.username == namespace,
ImageStorage.content_checksum == blob_digest,
ImageStorage.uploading == False)
.get())
return storage_model.get_storage_by_uuid(storage.uuid)
except (ImageStorage.DoesNotExist, InvalidImageException):
raise BlobDoesNotExist('Blob does not exist with digest: {0}'.format(blob_digest))
def store_blob_record_and_temp_link(namespace, repo_name, blob_digest, location_obj, byte_count,
link_expiration_s, uncompressed_byte_count=None):
repo = _basequery.get_existing_repository(namespace, repo_name)
assert repo
return store_blob_record_and_temp_link_in_repo(repo.id, blob_digest, location_obj, byte_count,
link_expiration_s, uncompressed_byte_count)
def store_blob_record_and_temp_link_in_repo(repository_id, blob_digest, location_obj, byte_count,
link_expiration_s, uncompressed_byte_count=None):
""" Store a record of the blob and temporarily link it to the specified repository.
"""
assert blob_digest
assert byte_count is not None
with db_transaction():
try:
storage = ImageStorage.get(content_checksum=blob_digest)
save_changes = False
if storage.image_size is None:
storage.image_size = byte_count
save_changes = True
if storage.uncompressed_size is None and uncompressed_byte_count is not None:
storage.uncompressed_size = uncompressed_byte_count
save_changes = True
if save_changes:
storage.save()
ImageStoragePlacement.get(storage=storage, location=location_obj)
except ImageStorage.DoesNotExist:
storage = ImageStorage.create(content_checksum=blob_digest, uploading=False,
image_size=byte_count,
uncompressed_size=uncompressed_byte_count)
ImageStoragePlacement.create(storage=storage, location=location_obj)
except ImageStoragePlacement.DoesNotExist:
ImageStoragePlacement.create(storage=storage, location=location_obj)
_temp_link_blob(repository_id, storage, link_expiration_s)
return storage
def temp_link_blob(repository_id, blob_digest, link_expiration_s):
""" Temporarily links to the blob record from the given repository. If the blob record is not
found, returns None.
"""
assert blob_digest
with db_transaction():
try:
storage = ImageStorage.get(content_checksum=blob_digest)
except ImageStorage.DoesNotExist:
return None
_temp_link_blob(repository_id, storage, link_expiration_s)
return storage
def _temp_link_blob(repository_id, storage, link_expiration_s):
""" Note: Should *always* be called by a parent under a transaction. """
random_image_name = str(uuid4())
# Create a temporary link into the repository, to be replaced by the v1 metadata later
# and create a temporary tag to reference it
image = Image.create(storage=storage, docker_image_id=random_image_name, repository=repository_id)
tag.create_temporary_hidden_tag(repository_id, image, link_expiration_s)
def get_stale_blob_upload(stale_timespan):
""" Returns a random blob upload which was created before the stale timespan. """
stale_threshold = datetime.now() - stale_timespan
try:
candidates = (BlobUpload
.select()
.where(BlobUpload.created <= stale_threshold)
.limit(500)
.distinct()
.alias('candidates'))
found = (BlobUpload
.select(candidates.c.id)
.from_(candidates)
.order_by(db_random_func())
.get())
if not found:
return None
return (BlobUpload
.select(BlobUpload, ImageStorageLocation)
.join(ImageStorageLocation)
.where(BlobUpload.id == found.id)
.get())
except BlobUpload.DoesNotExist:
return None
def get_blob_upload_by_uuid(upload_uuid):
""" Loads the upload with the given UUID, if any. """
try:
return (BlobUpload
.select()
.where(BlobUpload.uuid == upload_uuid)
.get())
except BlobUpload.DoesNotExist:
return None
def get_blob_upload(namespace, repo_name, upload_uuid):
""" Load the upload which is already in progress.
"""
try:
return (BlobUpload
.select(BlobUpload, ImageStorageLocation)
.join(ImageStorageLocation)
.switch(BlobUpload)
.join(Repository)
.join(Namespace, on=(Namespace.id == Repository.namespace_user))
.where(Repository.name == repo_name, Namespace.username == namespace,
BlobUpload.uuid == upload_uuid)
.get())
except BlobUpload.DoesNotExist:
raise InvalidBlobUpload()
def initiate_upload(namespace, repo_name, uuid, location_name, storage_metadata):
""" Initiates a blob upload for the repository with the given namespace and name,
in a specific location. """
repo = _basequery.get_existing_repository(namespace, repo_name)
return initiate_upload_for_repo(repo, uuid, location_name, storage_metadata)
def initiate_upload_for_repo(repo, uuid, location_name, storage_metadata):
""" Initiates a blob upload for a specific repository object, in a specific location. """
location = storage_model.get_image_location_for_name(location_name)
return BlobUpload.create(repository=repo, location=location.id, uuid=uuid,
storage_metadata=storage_metadata)
def get_shared_blob(digest):
""" Returns the ImageStorage blob with the given digest or, if not present,
returns None. This method is *only* to be used for shared blobs that are
globally accessible, such as the special empty gzipped tar layer that Docker
no longer pushes to us.
"""
assert digest
try:
return ImageStorage.get(content_checksum=digest, uploading=False)
except ImageStorage.DoesNotExist:
return None
def get_or_create_shared_blob(digest, byte_data, storage):
""" Returns the ImageStorage blob with the given digest or, if not present,
adds a row and writes the given byte data to the storage engine.
This method is *only* to be used for shared blobs that are globally
accessible, such as the special empty gzipped tar layer that Docker
no longer pushes to us.
"""
assert digest
assert byte_data is not None
assert storage
try:
return ImageStorage.get(content_checksum=digest, uploading=False)
except ImageStorage.DoesNotExist:
record = ImageStorage.create(image_size=len(byte_data), content_checksum=digest,
cas_path=True, uploading=True)
preferred = storage.preferred_locations[0]
location_obj = ImageStorageLocation.get(name=preferred)
try:
storage.put_content([preferred], storage_model.get_layer_path(record), byte_data)
ImageStoragePlacement.create(storage=record, location=location_obj)
record.uploading = False
record.save()
except:
logger.exception('Exception when trying to write special layer %s', digest)
record.delete_instance()
raise
return record
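A small sketch of how a blob's content digest could be derived from its bytes; the sha256:<hex> form is an assumption (Docker-style digests), not something this file defines:

import hashlib

def compute_blob_digest(byte_data):
    # Docker-style content digest: algorithm prefix plus lowercase hex of the payload hash.
    return 'sha256:' + hashlib.sha256(byte_data).hexdigest()

digest = compute_blob_digest(b'example-layer-bytes')
assert digest.startswith('sha256:') and len(digest) == len('sha256:') + 64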

323
data/model/build.py Normal file
View file

@@ -0,0 +1,323 @@
import json
from datetime import timedelta, datetime
from peewee import JOIN
from active_migration import ActiveDataMigration, ERTMigrationFlags
from data.database import (BuildTriggerService, RepositoryBuildTrigger, Repository, Namespace, User,
RepositoryBuild, BUILD_PHASE, db_random_func, UseThenDisconnect,
TRIGGER_DISABLE_REASON)
from data.model import (InvalidBuildTriggerException, InvalidRepositoryBuildException,
db_transaction, user as user_model, config)
from data.fields import DecryptedValue
PRESUMED_DEAD_BUILD_AGE = timedelta(days=15)
PHASES_NOT_ALLOWED_TO_CANCEL_FROM = (BUILD_PHASE.PUSHING, BUILD_PHASE.COMPLETE,
BUILD_PHASE.ERROR, BUILD_PHASE.INTERNAL_ERROR)
ARCHIVABLE_BUILD_PHASES = [BUILD_PHASE.COMPLETE, BUILD_PHASE.ERROR, BUILD_PHASE.CANCELLED]
def update_build_trigger(trigger, config, auth_token=None, write_token=None):
trigger.config = json.dumps(config or {})
# TODO(remove-unenc): Remove legacy field.
if auth_token is not None:
if ActiveDataMigration.has_flag(ERTMigrationFlags.WRITE_OLD_FIELDS):
trigger.auth_token = auth_token
trigger.secure_auth_token = auth_token
if write_token is not None:
trigger.write_token = write_token
trigger.save()
def create_build_trigger(repo, service_name, auth_token, user, pull_robot=None, config=None):
service = BuildTriggerService.get(name=service_name)
# TODO(remove-unenc): Remove legacy field.
old_auth_token = None
if ActiveDataMigration.has_flag(ERTMigrationFlags.WRITE_OLD_FIELDS):
old_auth_token = auth_token
secure_auth_token = DecryptedValue(auth_token) if auth_token else None
trigger = RepositoryBuildTrigger.create(repository=repo, service=service,
auth_token=old_auth_token,
secure_auth_token=secure_auth_token,
connected_user=user,
pull_robot=pull_robot,
config=json.dumps(config or {}))
return trigger
def get_build_trigger(trigger_uuid):
try:
return (RepositoryBuildTrigger
.select(RepositoryBuildTrigger, BuildTriggerService, Repository, Namespace)
.join(BuildTriggerService)
.switch(RepositoryBuildTrigger)
.join(Repository)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.switch(RepositoryBuildTrigger)
.join(User, on=(RepositoryBuildTrigger.connected_user == User.id))
.where(RepositoryBuildTrigger.uuid == trigger_uuid)
.get())
except RepositoryBuildTrigger.DoesNotExist:
msg = 'No build trigger with uuid: %s' % trigger_uuid
raise InvalidBuildTriggerException(msg)
def list_build_triggers(namespace_name, repository_name):
return (RepositoryBuildTrigger
.select(RepositoryBuildTrigger, BuildTriggerService, Repository)
.join(BuildTriggerService)
.switch(RepositoryBuildTrigger)
.join(Repository)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.where(Namespace.username == namespace_name, Repository.name == repository_name))
def list_trigger_builds(namespace_name, repository_name, trigger_uuid,
limit):
return (list_repository_builds(namespace_name, repository_name, limit)
.where(RepositoryBuildTrigger.uuid == trigger_uuid))
def get_repository_for_resource(resource_key):
try:
return (Repository
.select(Repository, Namespace)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.switch(Repository)
.join(RepositoryBuild)
.where(RepositoryBuild.resource_key == resource_key)
.get())
except Repository.DoesNotExist:
return None
def _get_build_base_query():
return (RepositoryBuild
.select(RepositoryBuild, RepositoryBuildTrigger, BuildTriggerService, Repository,
Namespace, User)
.join(Repository)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.switch(RepositoryBuild)
.join(User, JOIN.LEFT_OUTER)
.switch(RepositoryBuild)
.join(RepositoryBuildTrigger, JOIN.LEFT_OUTER)
.join(BuildTriggerService, JOIN.LEFT_OUTER)
.order_by(RepositoryBuild.started.desc()))
def get_repository_build(build_uuid):
try:
return _get_build_base_query().where(RepositoryBuild.uuid == build_uuid).get()
except RepositoryBuild.DoesNotExist:
msg = 'Unable to locate a build by id: %s' % build_uuid
raise InvalidRepositoryBuildException(msg)
def list_repository_builds(namespace_name, repository_name, limit,
include_inactive=True, since=None):
query = (_get_build_base_query()
.where(Repository.name == repository_name, Namespace.username == namespace_name)
.limit(limit))
if since is not None:
query = query.where(RepositoryBuild.started >= since)
if not include_inactive:
query = query.where(RepositoryBuild.phase != BUILD_PHASE.ERROR,
RepositoryBuild.phase != BUILD_PHASE.COMPLETE)
return query
def get_recent_repository_build(namespace_name, repository_name):
query = list_repository_builds(namespace_name, repository_name, 1)
try:
return query.get()
except RepositoryBuild.DoesNotExist:
return None
def create_repository_build(repo, access_token, job_config_obj, dockerfile_id,
display_name, trigger=None, pull_robot_name=None):
pull_robot = None
if pull_robot_name:
pull_robot = user_model.lookup_robot(pull_robot_name)
return RepositoryBuild.create(repository=repo, access_token=access_token,
job_config=json.dumps(job_config_obj),
display_name=display_name, trigger=trigger,
resource_key=dockerfile_id,
pull_robot=pull_robot)
def get_pull_robot_name(trigger):
if not trigger.pull_robot:
return None
return trigger.pull_robot.username
def _get_build_row(build_uuid):
return RepositoryBuild.select().where(RepositoryBuild.uuid == build_uuid).get()
def update_phase_then_close(build_uuid, phase):
""" A function to change the phase of a build """
with UseThenDisconnect(config.app_config):
try:
build = _get_build_row(build_uuid)
except RepositoryBuild.DoesNotExist:
return False
# Can't update a cancelled build
if build.phase == BUILD_PHASE.CANCELLED:
return False
updated = (RepositoryBuild
.update(phase=phase)
.where(RepositoryBuild.id == build.id, RepositoryBuild.phase == build.phase)
.execute())
return updated > 0
def create_cancel_build_in_queue(build_phase, build_queue_id, build_queue):
""" A function to cancel a build before it leaves the queue """
def cancel_build():
cancelled = False
if build_queue_id is not None:
cancelled = build_queue.cancel(build_queue_id)
if build_phase != BUILD_PHASE.WAITING:
return False
return cancelled
return cancel_build
def create_cancel_build_in_manager(build_phase, build_uuid, build_canceller):
""" A function to cancel the build before it starts to push """
def cancel_build():
if build_phase in PHASES_NOT_ALLOWED_TO_CANCEL_FROM:
return False
return build_canceller.try_cancel_build(build_uuid)
return cancel_build
def cancel_repository_build(build, build_queue):
""" This tries to cancel the build returns true if request is successful false
if it can't be cancelled """
from app import build_canceller
from buildman.jobutil.buildjob import BuildJobNotifier
cancel_builds = [create_cancel_build_in_queue(build.phase, build.queue_id, build_queue),
create_cancel_build_in_manager(build.phase, build.uuid, build_canceller), ]
for cancelled in cancel_builds:
if cancelled():
updated = update_phase_then_close(build.uuid, BUILD_PHASE.CANCELLED)
if updated:
BuildJobNotifier(build.uuid).send_notification("build_cancelled")
return updated
return False
def get_archivable_build():
presumed_dead_date = datetime.utcnow() - PRESUMED_DEAD_BUILD_AGE
candidates = (RepositoryBuild
.select(RepositoryBuild.id)
.where((RepositoryBuild.phase << ARCHIVABLE_BUILD_PHASES) |
(RepositoryBuild.started < presumed_dead_date),
RepositoryBuild.logs_archived == False)
.limit(50)
.alias('candidates'))
try:
found_id = (RepositoryBuild
.select(candidates.c.id)
.from_(candidates)
.order_by(db_random_func())
.get())
return RepositoryBuild.get(id=found_id)
except RepositoryBuild.DoesNotExist:
return None
def mark_build_archived(build_uuid):
""" Mark a build as archived, and return True if we were the ones who actually
updated the row. """
return (RepositoryBuild
.update(logs_archived=True)
.where(RepositoryBuild.uuid == build_uuid,
RepositoryBuild.logs_archived == False)
.execute()) > 0
def toggle_build_trigger(trigger, enabled, reason=TRIGGER_DISABLE_REASON.USER_TOGGLED):
""" Toggles the enabled status of a build trigger. """
trigger.enabled = enabled
if not enabled:
trigger.disabled_reason = RepositoryBuildTrigger.disabled_reason.get_id(reason)
trigger.disabled_datetime = datetime.utcnow()
trigger.save()
def update_trigger_disable_status(trigger, final_phase):
""" Updates the disable status of the given build trigger. If the build trigger had a
failure, then the counter is increased and, if we've reached the limit, the trigger is
automatically disabled. Otherwise, if the trigger succeeded, its counter is reset. This
ensures that triggers that continue to error are eventually automatically disabled.
"""
with db_transaction():
try:
trigger = RepositoryBuildTrigger.get(id=trigger.id)
except RepositoryBuildTrigger.DoesNotExist:
# Already deleted.
return
# If the build completed successfully, then reset the successive counters.
if final_phase == BUILD_PHASE.COMPLETE:
trigger.successive_failure_count = 0
trigger.successive_internal_error_count = 0
trigger.save()
return
# Otherwise, increment the counters and check for trigger disable.
if final_phase == BUILD_PHASE.ERROR:
trigger.successive_failure_count = trigger.successive_failure_count + 1
trigger.successive_internal_error_count = 0
elif final_phase == BUILD_PHASE.INTERNAL_ERROR:
trigger.successive_internal_error_count = trigger.successive_internal_error_count + 1
# Check if we need to disable the trigger.
failure_threshold = config.app_config.get('SUCCESSIVE_TRIGGER_FAILURE_DISABLE_THRESHOLD')
error_threshold = config.app_config.get('SUCCESSIVE_TRIGGER_INTERNAL_ERROR_DISABLE_THRESHOLD')
if failure_threshold and trigger.successive_failure_count >= failure_threshold:
toggle_build_trigger(trigger, False, TRIGGER_DISABLE_REASON.BUILD_FALURES)
elif (error_threshold and
trigger.successive_internal_error_count >= error_threshold):
toggle_build_trigger(trigger, False, TRIGGER_DISABLE_REASON.INTERNAL_ERRORS)
else:
# Save the trigger changes.
trigger.save()
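A minimal sketch, with plain integers standing in for trigger rows, of the disable decision made by update_trigger_disable_status; the thresholds and reason strings here are placeholders (the real values come from the SUCCESSIVE_TRIGGER_*_DISABLE_THRESHOLD config keys and TRIGGER_DISABLE_REASON):

def should_disable_trigger(successive_failures, successive_internal_errors,
                           failure_threshold, internal_error_threshold):
    # Return the disable reason, or None if neither counter has crossed its threshold.
    if failure_threshold and successive_failures >= failure_threshold:
        return 'build_failures'
    if internal_error_threshold and successive_internal_errors >= internal_error_threshold:
        return 'internal_errors'
    return None

assert should_disable_trigger(5, 0, failure_threshold=5, internal_error_threshold=5) == 'build_failures'
assert should_disable_trigger(0, 2, failure_threshold=5, internal_error_threshold=5) is None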

554
data/model/gc.py Normal file
View file

@@ -0,0 +1,554 @@
import logging
from data.model import config, db_transaction, storage, _basequery, tag as pre_oci_tag
from data.model.oci import tag as oci_tag
from data.database import Repository, db_for_update
from data.database import ApprTag
from data.database import (Tag, Manifest, ManifestBlob, ManifestChild, ManifestLegacyImage,
ManifestLabel, Label, TagManifestLabel)
from data.database import RepositoryTag, TagManifest, Image, DerivedStorageForImage
from data.database import TagManifestToManifest, TagToRepositoryTag, TagManifestLabelMap
logger = logging.getLogger(__name__)
class _GarbageCollectorContext(object):
def __init__(self, repository):
self.repository = repository
self.manifest_ids = set()
self.label_ids = set()
self.blob_ids = set()
self.legacy_image_ids = set()
def add_manifest_id(self, manifest_id):
self.manifest_ids.add(manifest_id)
def add_label_id(self, label_id):
self.label_ids.add(label_id)
def add_blob_id(self, blob_id):
self.blob_ids.add(blob_id)
def add_legacy_image_id(self, legacy_image_id):
self.legacy_image_ids.add(legacy_image_id)
def mark_label_id_removed(self, label_id):
self.label_ids.remove(label_id)
def mark_manifest_removed(self, manifest):
self.manifest_ids.remove(manifest.id)
def mark_legacy_image_removed(self, legacy_image):
self.legacy_image_ids.remove(legacy_image.id)
def mark_blob_id_removed(self, blob_id):
self.blob_ids.remove(blob_id)
def purge_repository(namespace_name, repository_name):
""" Completely delete all traces of the repository. Will return True upon
complete success, and False upon partial or total failure. Garbage
collection is incremental and repeatable, so this return value does
not need to be checked or responded to.
"""
try:
repo = _basequery.get_existing_repository(namespace_name, repository_name)
except Repository.DoesNotExist:
return False
assert repo.name == repository_name
# Delete all Appr-referenced entries in the repository.
# Note that new-model Tags must be deleted in *two* passes, as they can reference parent tags,
# and MySQL is... particular... about such relationships when deleting.
if repo.kind.name == 'application':
ApprTag.delete().where(ApprTag.repository == repo, ~(ApprTag.linked_tag >> None)).execute()
ApprTag.delete().where(ApprTag.repository == repo).execute()
else:
# GC to remove the images and storage.
_purge_repository_contents(repo)
# Ensure there are no additional tags, manifests, images or blobs in the repository.
assert ApprTag.select().where(ApprTag.repository == repo).count() == 0
assert Tag.select().where(Tag.repository == repo).count() == 0
assert RepositoryTag.select().where(RepositoryTag.repository == repo).count() == 0
assert Manifest.select().where(Manifest.repository == repo).count() == 0
assert ManifestBlob.select().where(ManifestBlob.repository == repo).count() == 0
assert Image.select().where(Image.repository == repo).count() == 0
# Delete the rest of the repository metadata.
try:
# Make sure the repository still exists.
fetched = _basequery.get_existing_repository(namespace_name, repository_name)
except Repository.DoesNotExist:
return False
fetched.delete_instance(recursive=True, delete_nullable=False)
# Run callbacks
for callback in config.repo_cleanup_callbacks:
callback(namespace_name, repository_name)
return True
def _chunk_iterate_for_deletion(query, chunk_size=10):
""" Returns an iterator that loads the rows returned by the given query in chunks. Note that
order is not guaranteed here, so this will only work (i.e. not return duplicates) if
the rows returned are being deleted between calls.
"""
while True:
results = list(query.limit(chunk_size))
if not results:
return
yield results
def _purge_repository_contents(repo):
""" Purges all the contents of a repository, removing all of its tags,
manifests and images.
"""
logger.debug('Purging repository %s', repo)
# Purge via all the tags.
while True:
found = False
for tags in _chunk_iterate_for_deletion(Tag.select().where(Tag.repository == repo)):
logger.debug('Found %s tags to GC under repository %s', len(tags), repo)
found = True
context = _GarbageCollectorContext(repo)
for tag in tags:
logger.debug('Deleting tag %s under repository %s', tag, repo)
assert tag.repository_id == repo.id
_purge_oci_tag(tag, context, allow_non_expired=True)
_run_garbage_collection(context)
if not found:
break
# TODO: remove this once we're fully on the OCI data model.
while True:
found = False
repo_tag_query = RepositoryTag.select().where(RepositoryTag.repository == repo)
for tags in _chunk_iterate_for_deletion(repo_tag_query):
logger.debug('Found %s tags to GC under repository %s', len(tags), repo)
found = True
context = _GarbageCollectorContext(repo)
for tag in tags:
logger.debug('Deleting tag %s under repository %s', tag, repo)
assert tag.repository_id == repo.id
_purge_pre_oci_tag(tag, context, allow_non_expired=True)
_run_garbage_collection(context)
if not found:
break
# Add all remaining images to a new context. We do this here to minimize the number of images
# we need to load.
while True:
found_image = False
image_context = _GarbageCollectorContext(repo)
for image in Image.select().where(Image.repository == repo):
found_image = True
logger.debug('Deleting image %s under repository %s', image, repo)
assert image.repository_id == repo.id
image_context.add_legacy_image_id(image.id)
_run_garbage_collection(image_context)
if not found_image:
break
def garbage_collect_repo(repo):
""" Performs garbage collection over the contents of a repository. """
# Purge expired tags.
had_changes = False
for tags in _chunk_iterate_for_deletion(oci_tag.lookup_unrecoverable_tags(repo)):
logger.debug('Found %s tags to GC under repository %s', len(tags), repo)
context = _GarbageCollectorContext(repo)
for tag in tags:
logger.debug('Deleting tag %s under repository %s', tag, repo)
assert tag.repository_id == repo.id
assert tag.lifetime_end_ms is not None
_purge_oci_tag(tag, context)
_run_garbage_collection(context)
had_changes = True
for tags in _chunk_iterate_for_deletion(pre_oci_tag.lookup_unrecoverable_tags(repo)):
logger.debug('Found %s tags to GC under repository %s', len(tags), repo)
context = _GarbageCollectorContext(repo)
for tag in tags:
logger.debug('Deleting tag %s under repository %s', tag, repo)
assert tag.repository_id == repo.id
assert tag.lifetime_end_ts is not None
_purge_pre_oci_tag(tag, context)
_run_garbage_collection(context)
had_changes = True
return had_changes
def _run_garbage_collection(context):
""" Runs the garbage collection loop, deleting manifests, images, labels and blobs
in an iterative fashion.
"""
has_changes = True
while has_changes:
has_changes = False
# GC all manifests encountered.
for manifest_id in list(context.manifest_ids):
if _garbage_collect_manifest(manifest_id, context):
has_changes = True
# GC all images encountered.
for image_id in list(context.legacy_image_ids):
if _garbage_collect_legacy_image(image_id, context):
has_changes = True
# GC all labels encountered.
for label_id in list(context.label_ids):
if _garbage_collect_label(label_id, context):
has_changes = True
# GC any blobs encountered.
if context.blob_ids:
storage_ids_removed = set(storage.garbage_collect_storage(context.blob_ids))
for blob_removed_id in storage_ids_removed:
context.mark_blob_id_removed(blob_removed_id)
has_changes = True
def _purge_oci_tag(tag, context, allow_non_expired=False):
assert tag.repository_id == context.repository.id
if not allow_non_expired:
assert tag.lifetime_end_ms is not None
assert tag.lifetime_end_ms <= oci_tag.get_epoch_timestamp_ms()
# Add the manifest to be GCed.
context.add_manifest_id(tag.manifest_id)
with db_transaction():
# Reload the tag and verify its lifetime_end_ms has not changed.
try:
reloaded_tag = db_for_update(Tag.select().where(Tag.id == tag.id)).get()
except Tag.DoesNotExist:
return False
assert reloaded_tag.id == tag.id
assert reloaded_tag.repository_id == context.repository.id
if reloaded_tag.lifetime_end_ms != tag.lifetime_end_ms:
return False
# Delete mapping rows.
TagToRepositoryTag.delete().where(TagToRepositoryTag.tag == tag).execute()
# Delete the tag.
tag.delete_instance()
def _purge_pre_oci_tag(tag, context, allow_non_expired=False):
assert tag.repository_id == context.repository.id
if not allow_non_expired:
assert tag.lifetime_end_ts is not None
assert tag.lifetime_end_ts <= pre_oci_tag.get_epoch_timestamp()
# If it exists, GC the tag manifest.
try:
tag_manifest = TagManifest.select().where(TagManifest.tag == tag).get()
_garbage_collect_legacy_manifest(tag_manifest.id, context)
except TagManifest.DoesNotExist:
pass
# Add the tag's legacy image to be GCed.
context.add_legacy_image_id(tag.image_id)
with db_transaction():
# Reload the tag and verify its lifetime_end_ts has not changed.
try:
reloaded_tag = db_for_update(RepositoryTag.select().where(RepositoryTag.id == tag.id)).get()
except RepositoryTag.DoesNotExist:
return False
assert reloaded_tag.id == tag.id
assert reloaded_tag.repository_id == context.repository.id
if reloaded_tag.lifetime_end_ts != tag.lifetime_end_ts:
return False
# Delete mapping rows.
TagToRepositoryTag.delete().where(TagToRepositoryTag.repository_tag == reloaded_tag).execute()
# Delete the tag.
reloaded_tag.delete_instance()
def _check_manifest_used(manifest_id):
assert manifest_id is not None
with db_transaction():
# Check if the manifest is referenced by any other tag.
try:
Tag.select().where(Tag.manifest == manifest_id).get()
return True
except Tag.DoesNotExist:
pass
# Check if the manifest is referenced as a child of another manifest.
try:
ManifestChild.select().where(ManifestChild.child_manifest == manifest_id).get()
return True
except ManifestChild.DoesNotExist:
pass
return False
def _garbage_collect_manifest(manifest_id, context):
assert manifest_id is not None
# Make sure the manifest isn't referenced.
if _check_manifest_used(manifest_id):
return False
# Add the manifest's blobs to the context to be GCed.
for manifest_blob in ManifestBlob.select().where(ManifestBlob.manifest == manifest_id):
context.add_blob_id(manifest_blob.blob_id)
# Retrieve the manifest's associated image, if any.
try:
legacy_image_id = ManifestLegacyImage.get(manifest=manifest_id).image_id
context.add_legacy_image_id(legacy_image_id)
except ManifestLegacyImage.DoesNotExist:
legacy_image_id = None
# Add child manifests to be GCed.
for connector in ManifestChild.select().where(ManifestChild.manifest == manifest_id):
context.add_manifest_id(connector.child_manifest_id)
# Add the labels to be GCed.
for manifest_label in ManifestLabel.select().where(ManifestLabel.manifest == manifest_id):
context.add_label_id(manifest_label.label_id)
# Delete the manifest.
with db_transaction():
try:
manifest = Manifest.select().where(Manifest.id == manifest_id).get()
except Manifest.DoesNotExist:
return False
assert manifest.id == manifest_id
assert manifest.repository_id == context.repository.id
if _check_manifest_used(manifest_id):
return False
# Delete any label mappings.
(TagManifestLabelMap
.delete()
.where(TagManifestLabelMap.manifest == manifest_id)
.execute())
# Delete any mapping rows for the manifest.
TagManifestToManifest.delete().where(TagManifestToManifest.manifest == manifest_id).execute()
# Delete any label rows.
ManifestLabel.delete().where(ManifestLabel.manifest == manifest_id,
ManifestLabel.repository == context.repository).execute()
# Delete any child manifest rows.
ManifestChild.delete().where(ManifestChild.manifest == manifest_id,
ManifestChild.repository == context.repository).execute()
# Delete the manifest blobs for the manifest.
ManifestBlob.delete().where(ManifestBlob.manifest == manifest_id,
ManifestBlob.repository == context.repository).execute()
# Delete the manifest legacy image row.
if legacy_image_id:
(ManifestLegacyImage
.delete()
.where(ManifestLegacyImage.manifest == manifest_id,
ManifestLegacyImage.repository == context.repository)
.execute())
# Delete the manifest.
manifest.delete_instance()
context.mark_manifest_removed(manifest)
return True
def _garbage_collect_legacy_manifest(legacy_manifest_id, context):
assert legacy_manifest_id is not None
# Add the labels to be GCed.
query = TagManifestLabel.select().where(TagManifestLabel.annotated == legacy_manifest_id)
for manifest_label in query:
context.add_label_id(manifest_label.label_id)
# Delete the tag manifest.
with db_transaction():
try:
tag_manifest = TagManifest.select().where(TagManifest.id == legacy_manifest_id).get()
except TagManifest.DoesNotExist:
return False
assert tag_manifest.id == legacy_manifest_id
assert tag_manifest.tag.repository_id == context.repository.id
# Delete any label mapping rows.
(TagManifestLabelMap
.delete()
.where(TagManifestLabelMap.tag_manifest == legacy_manifest_id)
.execute())
# Delete the label rows.
TagManifestLabel.delete().where(TagManifestLabel.annotated == legacy_manifest_id).execute()
# Delete the mapping row if it exists.
try:
tmt = (TagManifestToManifest
.select()
.where(TagManifestToManifest.tag_manifest == tag_manifest)
.get())
context.add_manifest_id(tmt.manifest_id)
tmt.delete_instance()
except TagManifestToManifest.DoesNotExist:
pass
# Delete the tag manifest.
tag_manifest.delete_instance()
return True
def _check_image_used(legacy_image_id):
assert legacy_image_id is not None
with db_transaction():
# Check if the image is referenced by a manifest.
try:
ManifestLegacyImage.select().where(ManifestLegacyImage.image == legacy_image_id).get()
return True
except ManifestLegacyImage.DoesNotExist:
pass
# Check if the image is referenced by a tag.
try:
RepositoryTag.select().where(RepositoryTag.image == legacy_image_id).get()
return True
except RepositoryTag.DoesNotExist:
pass
# Check if the image is referenced by another image.
try:
Image.select().where(Image.parent == legacy_image_id).get()
return True
except Image.DoesNotExist:
pass
return False
def _garbage_collect_legacy_image(legacy_image_id, context):
assert legacy_image_id is not None
# Check if the image is referenced.
if _check_image_used(legacy_image_id):
return False
# We have an unreferenced image. We can now delete it.
# Grab any derived storage for the image.
for derived in (DerivedStorageForImage
.select()
.where(DerivedStorageForImage.source_image == legacy_image_id)):
context.add_blob_id(derived.derivative_id)
try:
image = Image.select().where(Image.id == legacy_image_id).get()
except Image.DoesNotExist:
return False
assert image.repository_id == context.repository.id
# Add the image's blob to be GCed.
context.add_blob_id(image.storage_id)
# If the image has a parent ID, add the parent for GC.
if image.parent_id is not None:
context.add_legacy_image_id(image.parent_id)
# Delete the image.
with db_transaction():
if _check_image_used(legacy_image_id):
return False
try:
image = Image.select().where(Image.id == legacy_image_id).get()
except Image.DoesNotExist:
return False
assert image.id == legacy_image_id
assert image.repository_id == context.repository.id
# Delete any derived storage for the image.
(DerivedStorageForImage
.delete()
.where(DerivedStorageForImage.source_image == legacy_image_id)
.execute())
# Delete the image itself.
image.delete_instance()
context.mark_legacy_image_removed(image)
if config.image_cleanup_callbacks:
for callback in config.image_cleanup_callbacks:
callback([image])
return True
def _check_label_used(label_id):
assert label_id is not None
with db_transaction():
# Check if the label is referenced by another manifest or tag manifest.
try:
ManifestLabel.select().where(ManifestLabel.label == label_id).get()
return True
except ManifestLabel.DoesNotExist:
pass
try:
TagManifestLabel.select().where(TagManifestLabel.label == label_id).get()
return True
except TagManifestLabel.DoesNotExist:
pass
return False
def _garbage_collect_label(label_id, context):
assert label_id is not None
# We can now delete the label.
with db_transaction():
if _check_label_used(label_id):
return False
result = Label.delete().where(Label.id == label_id).execute() == 1
if result:
context.mark_label_id_removed(label_id)
return result
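A toy sketch of the deletion pattern that _chunk_iterate_for_deletion expects: the same query is re-run each pass, so the caller must delete the rows it was handed or the loop never terminates. Plain lists stand in for queries here:

def chunk_iterate_demo(pending, chunk_size=10):
    # Yield up to chunk_size items per pass until the "query" comes back empty.
    while True:
        chunk = pending[:chunk_size]
        if not chunk:
            return
        yield chunk

items = list(range(25))
deleted = []
for chunk in chunk_iterate_demo(items):
    deleted.extend(chunk)
    del items[:len(chunk)]  # mirrors rows being deleted between passes
assert deleted == list(range(25)) and not items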

22
data/model/health.py Normal file
View file

@@ -0,0 +1,22 @@
import logging
from data.database import TeamRole, validate_database_url
logger = logging.getLogger(__name__)
def check_health(app_config):
# Attempt to connect to the database first. If the DB is not responding,
# using the validate_database_url will timeout quickly, as opposed to
# making a normal connect which will just hang (thus breaking the health
# check).
try:
validate_database_url(app_config['DB_URI'], {}, connect_timeout=3)
except Exception as ex:
return (False, 'Could not connect to the database: %s' % ex.message)
# We will connect to the db, check that it contains some team role kinds
try:
okay = bool(list(TeamRole.select().limit(1)))
return (okay, 'Could not connect to the database' if not okay else None)
except Exception as ex:
return (False, 'Could not connect to the database: %s' % ex.message)
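A brief usage sketch (the endpoint shape is hypothetical, not part of this commit) showing how a caller might turn the (ok, message) tuple from check_health into an HTTP-style response:

def health_endpoint(app_config):
    ok, message = check_health(app_config)
    body = {'status': 'healthy' if ok else 'unhealthy'}
    if message:
        body['detail'] = message
    return body, 200 if ok else 503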

516
data/model/image.py Normal file
View file

@@ -0,0 +1,516 @@
import logging
import hashlib
import json
from collections import defaultdict
from datetime import datetime
import dateutil.parser
from peewee import JOIN, IntegrityError, fn
from data.model import (DataModelException, db_transaction, _basequery, storage,
InvalidImageException)
from data.database import (Image, Repository, ImageStoragePlacement, Namespace, ImageStorage,
ImageStorageLocation, RepositoryPermission, DerivedStorageForImage,
ImageStorageTransformation, User)
from util.canonicaljson import canonicalize
logger = logging.getLogger(__name__)
def _namespace_id_for_username(username):
try:
return User.get(username=username).id
except User.DoesNotExist:
return None
def get_image_with_storage(docker_image_id, storage_uuid):
""" Returns the image with the given docker image ID and storage uuid or None if none.
"""
try:
return (Image
.select(Image, ImageStorage)
.join(ImageStorage)
.where(Image.docker_image_id == docker_image_id,
ImageStorage.uuid == storage_uuid)
.get())
except Image.DoesNotExist:
return None
def get_parent_images(namespace_name, repository_name, image_obj):
""" Returns a list of parent Image objects starting with the most recent parent
and ending with the base layer. The images in this query will include the storage.
"""
parents = image_obj.ancestors
# Ancestors are in the format /<root>/<intermediate>/.../<parent>/, with each path section
# containing the database Id of the image row.
parent_db_ids = parents.strip('/').split('/')
if parent_db_ids == ['']:
return []
def filter_to_parents(query):
return query.where(Image.id << parent_db_ids)
parents = _get_repository_images_and_storages(namespace_name, repository_name,
filter_to_parents)
id_to_image = {unicode(image.id): image for image in parents}
try:
return [id_to_image[parent_id] for parent_id in reversed(parent_db_ids)]
except KeyError as ke:
logger.exception('Could not find an expected parent image for image %s', image_obj.id)
raise DataModelException('Unknown parent image')
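# Illustrative sketch, not part of this module: how an ancestors string such as '/3/7/12/'
# is interpreted (3 is the base layer, 12 the direct parent). The leading and trailing '/'
# are dropped before splitting, and reversing yields most-recent-parent-first order, matching
# what get_parent_images returns above.
_example_ancestors = '/3/7/12/'
_example_parent_db_ids = _example_ancestors.strip('/').split('/')
assert _example_parent_db_ids == ['3', '7', '12']
assert list(reversed(_example_parent_db_ids)) == ['12', '7', '3']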
def get_placements_for_images(images):
""" Returns the placements for the given images, as a map from image storage ID to placements. """
if not images:
return {}
query = (ImageStoragePlacement
.select(ImageStoragePlacement, ImageStorageLocation, ImageStorage)
.join(ImageStorageLocation)
.switch(ImageStoragePlacement)
.join(ImageStorage)
.where(ImageStorage.id << [image.storage_id for image in images]))
placement_map = defaultdict(list)
for placement in query:
placement_map[placement.storage.id].append(placement)
return dict(placement_map)
def get_image_and_placements(namespace_name, repo_name, docker_image_id):
""" Returns the repo image (with a storage object) and storage placements for the image
or (None, None) if none is found.
"""
repo_image = get_repo_image_and_storage(namespace_name, repo_name, docker_image_id)
if repo_image is None:
return (None, None)
query = (ImageStoragePlacement
.select(ImageStoragePlacement, ImageStorageLocation)
.join(ImageStorageLocation)
.switch(ImageStoragePlacement)
.join(ImageStorage)
.where(ImageStorage.id == repo_image.storage_id))
return repo_image, list(query)
def get_repo_image(namespace_name, repository_name, docker_image_id):
""" Returns the repository image with the given Docker image ID or None if none.
Does not include the storage object.
"""
def limit_to_image_id(query):
return query.where(Image.docker_image_id == docker_image_id).limit(1)
query = _get_repository_images(namespace_name, repository_name, limit_to_image_id)
try:
return query.get()
except Image.DoesNotExist:
return None
def get_repo_image_and_storage(namespace_name, repository_name, docker_image_id):
""" Returns the repository image with the given Docker image ID or None if none.
Includes the storage object.
"""
def limit_to_image_id(query):
return query.where(Image.docker_image_id == docker_image_id)
images = _get_repository_images_and_storages(namespace_name, repository_name, limit_to_image_id)
if not images:
return None
return images[0]
def get_image_by_id(namespace_name, repository_name, docker_image_id):
""" Returns the repository image with the given Docker image ID or raises if not found.
Includes the storage object.
"""
image = get_repo_image_and_storage(namespace_name, repository_name, docker_image_id)
if not image:
raise InvalidImageException('Unable to find image \'%s\' for repo \'%s/%s\'' %
(docker_image_id, namespace_name, repository_name))
return image
def _get_repository_images_and_storages(namespace_name, repository_name, query_modifier):
query = (Image
.select(Image, ImageStorage)
.join(ImageStorage)
.switch(Image)
.join(Repository)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.where(Repository.name == repository_name, Namespace.username == namespace_name))
query = query_modifier(query)
return query
def _get_repository_images(namespace_name, repository_name, query_modifier):
query = (Image
.select()
.join(Repository)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.where(Repository.name == repository_name, Namespace.username == namespace_name))
query = query_modifier(query)
return query
def lookup_repository_images(repo, docker_image_ids):
return (Image
.select(Image, ImageStorage)
.join(ImageStorage)
.where(Image.repository == repo, Image.docker_image_id << docker_image_ids))
def get_repository_images_without_placements(repo_obj, with_ancestor=None):
query = (Image
.select(Image, ImageStorage)
.join(ImageStorage)
.where(Image.repository == repo_obj))
if with_ancestor:
ancestors_string = '%s%s/' % (with_ancestor.ancestors, with_ancestor.id)
query = query.where((Image.ancestors ** (ancestors_string + '%')) |
(Image.id == with_ancestor.id))
return query
def get_repository_images(namespace_name, repository_name):
""" Returns all the repository images in the repository. Does not include storage objects. """
return _get_repository_images(namespace_name, repository_name, lambda q: q)
def __translate_ancestry(old_ancestry, translations, repo_obj, username, preferred_location):
if old_ancestry == '/':
return '/'
def translate_id(old_id, docker_image_id):
logger.debug('Translating id: %s', old_id)
if old_id not in translations:
image_in_repo = find_create_or_link_image(docker_image_id, repo_obj, username, translations,
preferred_location)
translations[old_id] = image_in_repo.id
return translations[old_id]
# Select all the ancestor Docker IDs in a single query.
old_ids = [int(id_str) for id_str in old_ancestry.split('/')[1:-1]]
query = Image.select(Image.id, Image.docker_image_id).where(Image.id << old_ids)
old_images = {i.id: i.docker_image_id for i in query}
# Translate the old images into new ones.
new_ids = [str(translate_id(old_id, old_images[old_id])) for old_id in old_ids]
return '/%s/' % '/'.join(new_ids)
def _find_or_link_image(existing_image, repo_obj, username, translations, preferred_location):
with db_transaction():
# Check for an existing image, under the transaction, to make sure it doesn't already exist.
repo_image = get_repo_image(repo_obj.namespace_user.username, repo_obj.name,
existing_image.docker_image_id)
if repo_image:
return repo_image
# Make sure the existing base image still exists.
try:
to_copy = Image.select().join(ImageStorage).where(Image.id == existing_image.id).get()
msg = 'Linking image to existing storage with docker id: %s and uuid: %s'
logger.debug(msg, existing_image.docker_image_id, to_copy.storage.uuid)
new_image_ancestry = __translate_ancestry(to_copy.ancestors, translations, repo_obj,
username, preferred_location)
copied_storage = to_copy.storage
translated_parent_id = None
if new_image_ancestry != '/':
translated_parent_id = int(new_image_ancestry.split('/')[-2])
new_image = Image.create(docker_image_id=existing_image.docker_image_id,
repository=repo_obj,
storage=copied_storage,
ancestors=new_image_ancestry,
command=existing_image.command,
created=existing_image.created,
comment=existing_image.comment,
v1_json_metadata=existing_image.v1_json_metadata,
aggregate_size=existing_image.aggregate_size,
parent=translated_parent_id,
v1_checksum=existing_image.v1_checksum)
logger.debug('Storing translation %s -> %s', existing_image.id, new_image.id)
translations[existing_image.id] = new_image.id
return new_image
except Image.DoesNotExist:
return None
def find_create_or_link_image(docker_image_id, repo_obj, username, translations,
preferred_location):
# First check for the image existing in the repository. If found, we simply return it.
repo_image = get_repo_image(repo_obj.namespace_user.username, repo_obj.name,
docker_image_id)
if repo_image:
return repo_image
# We next check to see if there is an existing storage the new image can link to.
existing_image_query = (Image
.select(Image, ImageStorage)
.distinct()
.join(ImageStorage)
.switch(Image)
.join(Repository)
.join(RepositoryPermission, JOIN.LEFT_OUTER)
.switch(Repository)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.where(ImageStorage.uploading == False,
Image.docker_image_id == docker_image_id))
existing_image_query = _basequery.filter_to_repos_for_user(existing_image_query,
_namespace_id_for_username(username))
# If there is an existing image, we try to translate its ancestry and copy its storage.
new_image = None
try:
logger.debug('Looking up existing image for ID: %s', docker_image_id)
existing_image = existing_image_query.get()
logger.debug('Existing image %s found for ID: %s', existing_image.id, docker_image_id)
new_image = _find_or_link_image(existing_image, repo_obj, username, translations,
preferred_location)
if new_image:
return new_image
except Image.DoesNotExist:
logger.debug('No existing image found for ID: %s', docker_image_id)
# Otherwise, create a new storage directly.
with db_transaction():
# Final check for an existing image, under the transaction.
repo_image = get_repo_image(repo_obj.namespace_user.username, repo_obj.name,
docker_image_id)
if repo_image:
return repo_image
logger.debug('Creating new storage for docker id: %s', docker_image_id)
new_storage = storage.create_v1_storage(preferred_location)
return Image.create(docker_image_id=docker_image_id,
repository=repo_obj, storage=new_storage,
ancestors='/')
def set_image_metadata(docker_image_id, namespace_name, repository_name, created_date_str, comment,
command, v1_json_metadata, parent=None):
""" Sets metadata that is specific to how a binary piece of storage fits into the layer tree.
"""
with db_transaction():
try:
fetched = (Image
.select(Image, ImageStorage)
.join(Repository)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.switch(Image)
.join(ImageStorage)
.where(Repository.name == repository_name, Namespace.username == namespace_name,
Image.docker_image_id == docker_image_id)
.get())
except Image.DoesNotExist:
raise DataModelException('No image with specified id and repository')
fetched.created = datetime.now()
if created_date_str is not None:
try:
fetched.created = dateutil.parser.parse(created_date_str).replace(tzinfo=None)
except:
# parse raises different exceptions, so we cannot use a specific kind of handler here.
pass
# We clean up any old checksum in case this is a retry after a failure.
fetched.v1_checksum = None
fetched.comment = comment
fetched.command = command
fetched.v1_json_metadata = v1_json_metadata
if parent:
fetched.ancestors = '%s%s/' % (parent.ancestors, parent.id)
fetched.parent = parent
fetched.save()
return fetched
def get_image(repo, docker_image_id):
try:
return (Image
.select(Image, ImageStorage)
.join(ImageStorage)
.where(Image.docker_image_id == docker_image_id, Image.repository == repo)
.get())
except Image.DoesNotExist:
return None
def get_image_by_db_id(id):
try:
return Image.get(id=id)
except Image.DoesNotExist:
return None
def synthesize_v1_image(repo, image_storage_id, storage_image_size, docker_image_id,
created_date_str, comment, command, v1_json_metadata, parent_image=None):
""" Find an existing image with this docker image id, and if none exists, write one with the
specified metadata.
"""
ancestors = '/'
if parent_image is not None:
ancestors = '{0}{1}/'.format(parent_image.ancestors, parent_image.id)
created = None
if created_date_str is not None:
try:
created = dateutil.parser.parse(created_date_str).replace(tzinfo=None)
except:
# parse raises different exceptions, so we cannot use a specific kind of handler here.
pass
# Get the aggregate size for the image.
aggregate_size = _basequery.calculate_image_aggregate_size(ancestors, storage_image_size,
parent_image)
try:
return Image.create(docker_image_id=docker_image_id, ancestors=ancestors, comment=comment,
command=command, v1_json_metadata=v1_json_metadata, created=created,
storage=image_storage_id, repository=repo, parent=parent_image,
aggregate_size=aggregate_size)
except IntegrityError:
return Image.get(docker_image_id=docker_image_id, repository=repo)
def ensure_image_locations(*names):
with db_transaction():
locations = ImageStorageLocation.select().where(ImageStorageLocation.name << names)
insert_names = list(names)
for location in locations:
insert_names.remove(location.name)
if not insert_names:
return
data = [{'name': name} for name in insert_names]
ImageStorageLocation.insert_many(data).execute()
def get_max_id_for_sec_scan():
""" Gets the maximum id for a clair sec scan """
return Image.select(fn.Max(Image.id)).scalar()
def get_min_id_for_sec_scan(version):
""" Gets the minimum id for a clair sec scan """
return (Image
.select(fn.Min(Image.id))
.where(Image.security_indexed_engine < version)
.scalar())
def total_image_count():
""" Returns the total number of images in DB """
return Image.select().count()
def get_image_pk_field():
""" Returns the primary key for Image DB model """
return Image.id
def get_images_eligible_for_scan(clair_version):
""" Returns a query that gives all images eligible for a clair scan """
return (get_image_with_storage_and_parent_base()
.where(Image.security_indexed_engine < clair_version)
.where(ImageStorage.uploading == False))
def get_image_with_storage_and_parent_base():
Parent = Image.alias()
ParentImageStorage = ImageStorage.alias()
return (Image
.select(Image, ImageStorage, Parent, ParentImageStorage)
.join(ImageStorage)
.switch(Image)
.join(Parent, JOIN.LEFT_OUTER, on=(Image.parent == Parent.id))
.join(ParentImageStorage, JOIN.LEFT_OUTER, on=(ParentImageStorage.id == Parent.storage)))
def set_secscan_status(image, indexed, version):
return (Image
.update(security_indexed=indexed, security_indexed_engine=version)
.where(Image.id == image.id)
.where((Image.security_indexed_engine != version) | (Image.security_indexed != indexed))
.execute()) != 0
def _get_uniqueness_hash(varying_metadata):
if not varying_metadata:
return None
return hashlib.sha256(json.dumps(canonicalize(varying_metadata))).hexdigest()
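# Editorial illustration (not part of the original module): the uniqueness hash allows several
# derived storages to coexist for the same source image and transformation as long as their
# varying metadata differs, and collapses to None when there is no varying metadata. Assuming
# `canonicalize` yields a deterministic, key-order-independent structure, logically equal
# dictionaries produce the same hash:
#
#   _get_uniqueness_hash({'os': 'linux', 'arch': 'amd64'}) == \
#       _get_uniqueness_hash({'arch': 'amd64', 'os': 'linux'})   # expected to be True
#   _get_uniqueness_hash(None)                                   # -> None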
def find_or_create_derived_storage(source_image, transformation_name, preferred_location,
varying_metadata=None):
existing = find_derived_storage_for_image(source_image, transformation_name, varying_metadata)
if existing is not None:
return existing
uniqueness_hash = _get_uniqueness_hash(varying_metadata)
trans = ImageStorageTransformation.get(name=transformation_name)
new_storage = storage.create_v1_storage(preferred_location)
try:
derived = DerivedStorageForImage.create(source_image=source_image, derivative=new_storage,
transformation=trans, uniqueness_hash=uniqueness_hash)
except IntegrityError:
# Storage was created while this method executed. Just return the existing.
ImageStoragePlacement.delete().where(ImageStoragePlacement.storage == new_storage).execute()
new_storage.delete_instance()
return find_derived_storage_for_image(source_image, transformation_name, varying_metadata)
return derived
def find_derived_storage_for_image(source_image, transformation_name, varying_metadata=None):
uniqueness_hash = _get_uniqueness_hash(varying_metadata)
try:
found = (DerivedStorageForImage
.select(ImageStorage, DerivedStorageForImage)
.join(ImageStorage)
.switch(DerivedStorageForImage)
.join(ImageStorageTransformation)
.where(DerivedStorageForImage.source_image == source_image,
ImageStorageTransformation.name == transformation_name,
DerivedStorageForImage.uniqueness_hash == uniqueness_hash)
.get())
return found
except DerivedStorageForImage.DoesNotExist:
return None
def delete_derived_storage(derived_storage):
derived_storage.derivative.delete_instance(recursive=True)

143
data/model/label.py Normal file
View file

@ -0,0 +1,143 @@
import logging
from cachetools.func import lru_cache
from data.database import (Label, TagManifestLabel, MediaType, LabelSourceType, db_transaction,
ManifestLabel, TagManifestLabelMap, TagManifestToManifest)
from data.model import InvalidLabelKeyException, InvalidMediaTypeException, DataModelException
from data.text import prefix_search
from util.validation import validate_label_key
from util.validation import is_json
logger = logging.getLogger(__name__)
@lru_cache(maxsize=1)
def get_label_source_types():
source_type_map = {}
for kind in LabelSourceType.select():
source_type_map[kind.id] = kind.name
source_type_map[kind.name] = kind.id
return source_type_map
@lru_cache(maxsize=1)
def get_media_types():
media_type_map = {}
for kind in MediaType.select():
media_type_map[kind.id] = kind.name
media_type_map[kind.name] = kind.id
return media_type_map
def _get_label_source_type_id(name):
kinds = get_label_source_types()
return kinds[name]
def _get_media_type_id(name):
kinds = get_media_types()
return kinds[name]
def create_manifest_label(tag_manifest, key, value, source_type_name, media_type_name=None):
""" Creates a new manifest label on a specific tag manifest. """
if not key:
raise InvalidLabelKeyException()
# Note that we don't prevent invalid label keys coming from the manifest from being stored, as
# Docker does not currently prevent them from being put into said manifests.
if not validate_label_key(key) and source_type_name != 'manifest':
raise InvalidLabelKeyException()
# Find the matching media type. If none specified, we infer.
if media_type_name is None:
media_type_name = 'text/plain'
if is_json(value):
media_type_name = 'application/json'
media_type_id = _get_media_type_id(media_type_name)
if media_type_id is None:
raise InvalidMediaTypeException()
source_type_id = _get_label_source_type_id(source_type_name)
with db_transaction():
label = Label.create(key=key, value=value, source_type=source_type_id, media_type=media_type_id)
tag_manifest_label = TagManifestLabel.create(annotated=tag_manifest, label=label,
repository=tag_manifest.tag.repository)
try:
mapping_row = TagManifestToManifest.get(tag_manifest=tag_manifest)
if mapping_row.manifest:
manifest_label = ManifestLabel.create(manifest=mapping_row.manifest, label=label,
repository=tag_manifest.tag.repository)
TagManifestLabelMap.create(manifest_label=manifest_label,
tag_manifest_label=tag_manifest_label,
label=label,
manifest=mapping_row.manifest,
tag_manifest=tag_manifest)
except TagManifestToManifest.DoesNotExist:
pass
return label
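# Editorial illustration (not part of the original module): when no media type is supplied, the
# value's format decides it, so JSON values are stored as 'application/json' and everything else
# as 'text/plain'. A usage sketch, assuming 'api' is a registered label source type:
#
#   create_manifest_label(tag_manifest, 'maintainer', 'alice@example.com', 'api')
#       # -> stored with media type 'text/plain'
#   create_manifest_label(tag_manifest, 'annotations', '{"tier": "gold"}', 'api')
#       # -> stored with media type 'application/json'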
def list_manifest_labels(tag_manifest, prefix_filter=None):
""" Lists all labels found on the given tag manifest. """
query = (Label.select(Label, MediaType)
.join(MediaType)
.switch(Label)
.join(LabelSourceType)
.switch(Label)
.join(TagManifestLabel)
.where(TagManifestLabel.annotated == tag_manifest))
if prefix_filter is not None:
query = query.where(prefix_search(Label.key, prefix_filter))
return query
def get_manifest_label(label_uuid, tag_manifest):
""" Retrieves the manifest label on the tag manifest with the given ID. """
try:
return (Label.select(Label, LabelSourceType)
.join(LabelSourceType)
.where(Label.uuid == label_uuid)
.switch(Label)
.join(TagManifestLabel)
.where(TagManifestLabel.annotated == tag_manifest)
.get())
except Label.DoesNotExist:
return None
def delete_manifest_label(label_uuid, tag_manifest):
""" Deletes the manifest label on the tag manifest with the given ID. """
# Find the label itself.
label = get_manifest_label(label_uuid, tag_manifest)
if label is None:
return None
if not label.source_type.mutable:
raise DataModelException('Cannot delete immutable label')
# Delete the mapping records and label.
(TagManifestLabelMap
.delete()
.where(TagManifestLabelMap.label == label)
.execute())
deleted_count = TagManifestLabel.delete().where(TagManifestLabel.label == label).execute()
if deleted_count != 1:
logger.warning('Unexpected number of labels deleted for matching label %s', label_uuid)
deleted_count = ManifestLabel.delete().where(ManifestLabel.label == label).execute()
if deleted_count != 1:
logger.warning('Unexpected number of labels deleted for matching label %s', label_uuid)
label.delete_instance(recursive=False)
return label

299
data/model/log.py Normal file
View file

@ -0,0 +1,299 @@
import json
import logging
from datetime import datetime, timedelta
from calendar import timegm
from cachetools.func import lru_cache
from peewee import JOIN, fn, PeeweeException
from data.database import LogEntryKind, User, RepositoryActionCount, db, LogEntry3
from data.model import config, user, DataModelException
logger = logging.getLogger(__name__)
ACTIONS_ALLOWED_WITHOUT_AUDIT_LOGGING = ['pull_repo']
def _logs_query(selections, start_time=None, end_time=None, performer=None, repository=None,
namespace=None, ignore=None, model=LogEntry3, id_range=None):
""" Returns a query for selecting logs from the table, with various options and filters. """
assert (start_time is not None and end_time is not None) or (id_range is not None)
joined = (model.select(*selections).switch(model))
if id_range is not None:
joined = joined.where(model.id >= id_range[0], model.id <= id_range[1])
else:
joined = joined.where(model.datetime >= start_time, model.datetime < end_time)
if repository:
joined = joined.where(model.repository == repository)
if performer:
joined = joined.where(model.performer == performer)
if namespace and not repository:
namespace_user = user.get_user_or_org(namespace)
if namespace_user is None:
raise DataModelException('Invalid namespace requested')
joined = joined.where(model.account == namespace_user.id)
if ignore:
kind_map = get_log_entry_kinds()
ignore_ids = [kind_map[kind_name] for kind_name in ignore]
joined = joined.where(~(model.kind << ignore_ids))
return joined
def _latest_logs_query(selections, performer=None, repository=None, namespace=None, ignore=None,
model=LogEntry3, size=None):
""" Returns a query for selecting the latest logs from the table, with various options and
filters. """
query = (model.select(*selections).switch(model))
if repository:
query = query.where(model.repository == repository)
if performer:
query = query.where(model.performer == performer)
if namespace and not repository:
namespace_user = user.get_user_or_org(namespace)
if namespace_user is None:
raise DataModelException('Invalid namespace requested')
query = query.where(model.account == namespace_user.id)
if ignore:
kind_map = get_log_entry_kinds()
ignore_ids = [kind_map[kind_name] for kind_name in ignore]
query = query.where(~(model.kind << ignore_ids))
query = query.order_by(model.datetime.desc(), model.id)
if size:
query = query.limit(size)
return query
@lru_cache(maxsize=1)
def get_log_entry_kinds():
kind_map = {}
for kind in LogEntryKind.select():
kind_map[kind.id] = kind.name
kind_map[kind.name] = kind.id
return kind_map
def _get_log_entry_kind(name):
kinds = get_log_entry_kinds()
return kinds[name]
def get_aggregated_logs(start_time, end_time, performer=None, repository=None, namespace=None,
ignore=None, model=LogEntry3):
""" Returns the count of logs, by kind and day, for the logs matching the given filters. """
date = db.extract_date('day', model.datetime)
selections = [model.kind, date.alias('day'), fn.Count(model.id).alias('count')]
query = _logs_query(selections, start_time, end_time, performer, repository, namespace, ignore,
model=model)
return query.group_by(date, model.kind)
def get_logs_query(start_time=None, end_time=None, performer=None, repository=None, namespace=None,
ignore=None, model=LogEntry3, id_range=None):
""" Returns the logs matching the given filters. """
Performer = User.alias()
Account = User.alias()
selections = [model, Performer]
if namespace is None and repository is None:
selections.append(Account)
query = _logs_query(selections, start_time, end_time, performer, repository, namespace, ignore,
model=model, id_range=id_range)
query = (query.switch(model).join(Performer, JOIN.LEFT_OUTER,
on=(model.performer == Performer.id).alias('performer')))
if namespace is None and repository is None:
query = (query.switch(model).join(Account, JOIN.LEFT_OUTER,
on=(model.account == Account.id).alias('account')))
return query
def get_latest_logs_query(performer=None, repository=None, namespace=None, ignore=None,
model=LogEntry3, size=None):
""" Returns the latest logs matching the given filters. """
Performer = User.alias()
Account = User.alias()
selections = [model, Performer]
if namespace is None and repository is None:
selections.append(Account)
query = _latest_logs_query(selections, performer, repository, namespace, ignore, model=model,
size=size)
query = (query.switch(model).join(Performer, JOIN.LEFT_OUTER,
on=(model.performer == Performer.id).alias('performer')))
if namespace is None and repository is None:
query = (query.switch(model).join(Account, JOIN.LEFT_OUTER,
on=(model.account == Account.id).alias('account')))
return query
def _json_serialize(obj):
if isinstance(obj, datetime):
return timegm(obj.utctimetuple())
return obj
def log_action(kind_name, user_or_organization_name, performer=None, repository=None, ip=None,
metadata={}, timestamp=None):
""" Logs an entry in the LogEntry table. """
if not timestamp:
timestamp = datetime.today()
account = None
if user_or_organization_name is not None:
account = User.get(User.username == user_or_organization_name).id
else:
account = config.app_config.get('SERVICE_LOG_ACCOUNT_ID')
if account is None:
account = user.get_minimum_user_id()
if performer is not None:
performer = performer.id
if repository is not None:
repository = repository.id
kind = _get_log_entry_kind(kind_name)
metadata_json = json.dumps(metadata, default=_json_serialize)
log_data = {
'kind': kind,
'account': account,
'performer': performer,
'repository': repository,
'ip': ip,
'metadata_json': metadata_json,
'datetime': timestamp
}
try:
LogEntry3.create(**log_data)
except PeeweeException as ex:
strict_logging_disabled = config.app_config.get('ALLOW_PULLS_WITHOUT_STRICT_LOGGING')
if strict_logging_disabled and kind_name in ACTIONS_ALLOWED_WITHOUT_AUDIT_LOGGING:
logger.exception('log_action failed', extra=dict(log_data, exception=ex))
else:
raise
def get_stale_logs_start_id(model):
""" Gets the oldest log entry. """
try:
return (model.select(fn.Min(model.id)).tuples())[0][0]
except IndexError:
return None
def get_stale_logs(start_id, end_id, model, cutoff_date):
""" Returns all the logs with IDs between start_id and end_id inclusively. """
return model.select().where((model.id >= start_id),
(model.id <= end_id),
model.datetime <= cutoff_date)
def delete_stale_logs(start_id, end_id, model):
""" Deletes all the logs with IDs between start_id and end_id. """
model.delete().where((model.id >= start_id), (model.id <= end_id)).execute()
def get_repository_action_counts(repo, start_date):
""" Returns the daily aggregated action counts for the given repository, starting at the given
start date.
"""
return RepositoryActionCount.select().where(RepositoryActionCount.repository == repo,
RepositoryActionCount.date >= start_date)
def get_repositories_action_sums(repository_ids):
""" Returns a map from repository ID to total actions within that repository in the last week. """
if not repository_ids:
return {}
# Filter the join to recent entries only.
last_week = datetime.now() - timedelta(weeks=1)
tuples = (RepositoryActionCount.select(RepositoryActionCount.repository,
fn.Sum(RepositoryActionCount.count))
.where(RepositoryActionCount.repository << repository_ids)
.where(RepositoryActionCount.date >= last_week)
.group_by(RepositoryActionCount.repository).tuples())
action_count_map = {}
for record in tuples:
action_count_map[record[0]] = record[1]
return action_count_map
def get_minimum_id_for_logs(start_time, repository_id=None, namespace_id=None, model=LogEntry3):
""" Returns the minimum ID for logs matching the given repository or namespace in
the logs table, starting at the given start time.
"""
# First try bounded by a day. Most repositories will meet this criteria, and therefore
# can make a much faster query.
day_after = start_time + timedelta(days=1)
result = _get_bounded_id(fn.Min, model.datetime >= start_time,
repository_id, namespace_id, model.datetime < day_after, model=model)
if result is not None:
return result
return _get_bounded_id(fn.Min, model.datetime >= start_time, repository_id, namespace_id,
model=model)
def get_maximum_id_for_logs(end_time, repository_id=None, namespace_id=None, model=LogEntry3):
""" Returns the maximum ID for logs matching the given repository or namespace in
the logs table, ending at the given end time.
"""
# First try bounded by a day. Most repositories will meet this criteria, and therefore
# can make a much faster query.
day_before = end_time - timedelta(days=1)
result = _get_bounded_id(fn.Max, model.datetime <= end_time,
repository_id, namespace_id, model.datetime > day_before, model=model)
if result is not None:
return result
return _get_bounded_id(fn.Max, model.datetime <= end_time, repository_id, namespace_id,
model=model)
def _get_bounded_id(fn, filter_clause, repository_id, namespace_id, reduction_clause=None,
model=LogEntry3):
assert (namespace_id is not None) or (repository_id is not None)
query = (model
.select(fn(model.id))
.where(filter_clause))
if reduction_clause is not None:
query = query.where(reduction_clause)
if repository_id is not None:
query = query.where(model.repository == repository_id)
else:
query = query.where(model.account == namespace_id)
row = query.tuples()[0]
if not row:
return None
return row[0]
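# Editorial illustration (not part of the original module): the min/max id helpers above convert
# a datetime window into an id range that _logs_query can filter on directly, which avoids
# scanning the large logs table by datetime. A minimal usage sketch, assuming `repo` is a
# Repository row and `start_time`/`end_time` bound the window of interest:
#
#   start_id = get_minimum_id_for_logs(start_time, repository_id=repo.id)
#   end_id = get_maximum_id_for_logs(end_time, repository_id=repo.id)
#   if start_id is not None and end_id is not None:
#       logs = list(get_logs_query(repository=repo, id_range=(start_id, end_id)))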

24
data/model/message.py Normal file
View file

@ -0,0 +1,24 @@
from data.database import Messages, MediaType
def get_messages():
"""Query the data base for messages and returns a container of database message objects"""
return Messages.select(Messages, MediaType).join(MediaType)
def create(messages):
"""Insert messages into the database."""
inserted = []
for message in messages:
severity = message['severity']
media_type_name = message['media_type']
media_type = MediaType.get(name=media_type_name)
inserted.append(Messages.create(content=message['content'], media_type=media_type,
severity=severity))
return inserted
def delete_message(uuids):
"""Delete message from the database"""
if not uuids:
return
Messages.delete().where(Messages.uuid << uuids).execute()

77
data/model/modelutil.py Normal file
View file

@ -0,0 +1,77 @@
import dateutil.parser
from datetime import datetime
from peewee import SQL
def paginate(query, model, descending=False, page_token=None, limit=50, sort_field_alias=None,
max_page=None, sort_field_name=None):
""" Paginates the given query using an field range, starting at the optional page_token.
Returns a *list* of matching results along with an unencrypted page_token for the
next page, if any. If descending is set to True, orders by the field descending rather
than ascending.
"""
# Note: We use the sort_field_alias for the order_by, but not the where below. The alias is
# necessary for certain queries that use unions in MySQL, as it gets confused on which field
# to order by. The where clause, on the other hand, cannot use the alias because Postgres does
# not allow aliases in where clauses.
sort_field_name = sort_field_name or 'id'
sort_field = getattr(model, sort_field_name)
if sort_field_alias is not None:
sort_field_name = sort_field_alias
sort_field = SQL(sort_field_alias)
if descending:
query = query.order_by(sort_field.desc())
else:
query = query.order_by(sort_field)
start_index = pagination_start(page_token)
if start_index is not None:
if descending:
query = query.where(sort_field <= start_index)
else:
query = query.where(sort_field >= start_index)
query = query.limit(limit + 1)
page_number = (page_token.get('page_number') or None) if page_token else None
if page_number is not None and max_page is not None and page_number > max_page:
return [], None
return paginate_query(query, limit=limit, sort_field_name=sort_field_name,
page_number=page_number)
def pagination_start(page_token=None):
""" Returns the start index for pagination for the given page token. Will return None if None. """
if page_token is not None:
start_index = page_token.get('start_index')
if page_token.get('is_datetime'):
start_index = dateutil.parser.parse(start_index)
return start_index
return None
def paginate_query(query, limit=50, sort_field_name=None, page_number=None):
""" Executes the given query and returns a page's worth of results, as well as the page token
for the next page (if any).
"""
results = list(query)
page_token = None
if len(results) > limit:
start_index = getattr(results[limit], sort_field_name or 'id')
is_datetime = False
if isinstance(start_index, datetime):
start_index = start_index.isoformat() + "Z"
is_datetime = True
page_token = {
'start_index': start_index,
'page_number': page_number + 1 if page_number else 1,
'is_datetime': is_datetime,
}
return results[0:limit], page_token
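# Editorial illustration (not part of the original module): paginate() returns one page of rows
# plus a page token describing where the next page starts; callers pass the token back in to
# fetch the following page. A minimal usage sketch, assuming Repository is the peewee model
# being paged:
#
#   results, token = paginate(Repository.select(), Repository, limit=25)
#   while token is not None:
#       results, token = paginate(Repository.select(), Repository, limit=25, page_token=token)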

220
data/model/notification.py Normal file
View file

@ -0,0 +1,220 @@
import json
from peewee import SQL
from data.database import (Notification, NotificationKind, User, Team, TeamMember, TeamRole,
RepositoryNotification, ExternalNotificationEvent, Repository,
ExternalNotificationMethod, Namespace, db_for_update)
from data.model import InvalidNotificationException, db_transaction
def create_notification(kind_name, target, metadata={}, lookup_path=None):
kind_ref = NotificationKind.get(name=kind_name)
notification = Notification.create(kind=kind_ref, target=target,
metadata_json=json.dumps(metadata),
lookup_path=lookup_path)
return notification
def create_unique_notification(kind_name, target, metadata={}):
with db_transaction():
if list_notifications(target, kind_name).count() == 0:
create_notification(kind_name, target, metadata)
def lookup_notification(user, uuid):
results = list(list_notifications(user, id_filter=uuid, include_dismissed=True, limit=1))
if not results:
return None
return results[0]
def lookup_notifications_by_path_prefix(prefix):
return list((Notification
.select()
.where(Notification.lookup_path % prefix)))
def list_notifications(user, kind_name=None, id_filter=None, include_dismissed=False,
page=None, limit=None):
base_query = (Notification
.select(Notification.id,
Notification.uuid,
Notification.kind,
Notification.metadata_json,
Notification.dismissed,
Notification.lookup_path,
Notification.created,
Notification.created.alias('cd'),
Notification.target)
.join(NotificationKind))
if kind_name is not None:
base_query = base_query.where(NotificationKind.name == kind_name)
if id_filter is not None:
base_query = base_query.where(Notification.uuid == id_filter)
if not include_dismissed:
base_query = base_query.where(Notification.dismissed == False)
# Lookup directly for the user.
user_direct = base_query.clone().where(Notification.target == user)
# Lookup via organizations admined by the user.
Org = User.alias()
AdminTeam = Team.alias()
AdminTeamMember = TeamMember.alias()
AdminUser = User.alias()
via_orgs = (base_query.clone()
.join(Org, on=(Org.id == Notification.target))
.join(AdminTeam, on=(Org.id == AdminTeam.organization))
.join(TeamRole, on=(AdminTeam.role == TeamRole.id))
.switch(AdminTeam)
.join(AdminTeamMember, on=(AdminTeam.id == AdminTeamMember.team))
.join(AdminUser, on=(AdminTeamMember.user == AdminUser.id))
.where((AdminUser.id == user) & (TeamRole.name == 'admin')))
query = user_direct | via_orgs
if page:
query = query.paginate(page, limit)
elif limit:
query = query.limit(limit)
return query.order_by(SQL('cd desc'))
def delete_all_notifications_by_path_prefix(prefix):
(Notification
.delete()
.where(Notification.lookup_path ** (prefix + '%'))
.execute())
def delete_all_notifications_by_kind(kind_name):
kind_ref = NotificationKind.get(name=kind_name)
(Notification
.delete()
.where(Notification.kind == kind_ref)
.execute())
def delete_notifications_by_kind(target, kind_name):
kind_ref = NotificationKind.get(name=kind_name)
Notification.delete().where(Notification.target == target,
Notification.kind == kind_ref).execute()
def delete_matching_notifications(target, kind_name, **kwargs):
kind_ref = NotificationKind.get(name=kind_name)
# Load all notifications for the user with the given kind.
notifications = (Notification
.select()
.where(Notification.target == target,
Notification.kind == kind_ref))
# For each, match the metadata to the specified values.
for notification in notifications:
matches = True
try:
metadata = json.loads(notification.metadata_json)
except:
continue
for (key, value) in kwargs.iteritems():
if key not in metadata or metadata[key] != value:
matches = False
break
if not matches:
continue
notification.delete_instance()
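# Editorial illustration (not part of the original module): only notifications whose stored
# metadata contains every supplied key/value pair are removed. A usage sketch, assuming
# 'expiring_license' is a (hypothetical) registered notification kind whose metadata carries a
# 'license_id' field:
#
#   delete_matching_notifications(user, 'expiring_license', license_id='abc123')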
def increment_notification_failure_count(uuid):
""" This increments the number of failures by one """
(RepositoryNotification
.update(number_of_failures=RepositoryNotification.number_of_failures + 1)
.where(RepositoryNotification.uuid == uuid)
.execute())
def reset_notification_number_of_failures(namespace_name, repository_name, uuid):
""" This resets the number of failures for a repo notification to 0 """
try:
notification = RepositoryNotification.select().where(RepositoryNotification.uuid == uuid).get()
if (notification.repository.namespace_user.username != namespace_name or
notification.repository.name != repository_name):
raise InvalidNotificationException('No repository notification found with uuid: %s' % uuid)
reset_number_of_failures_to_zero(notification.id)
return notification
except RepositoryNotification.DoesNotExist:
return None
def reset_number_of_failures_to_zero(notification_id):
""" This resets the number of failures for a repo notification to 0 """
RepositoryNotification.update(number_of_failures=0).where(RepositoryNotification.id == notification_id).execute()
def create_repo_notification(repo, event_name, method_name, method_config, event_config, title=None):
event = ExternalNotificationEvent.get(ExternalNotificationEvent.name == event_name)
method = ExternalNotificationMethod.get(ExternalNotificationMethod.name == method_name)
return RepositoryNotification.create(repository=repo, event=event, method=method,
config_json=json.dumps(method_config), title=title,
event_config_json=json.dumps(event_config))
def _base_get_notification(uuid):
""" This is a base query for get statements """
return (RepositoryNotification
.select(RepositoryNotification, Repository, Namespace)
.join(Repository)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.where(RepositoryNotification.uuid == uuid))
def get_enabled_notification(uuid):
""" This returns a notification with less than 3 failures """
try:
return _base_get_notification(uuid).where(RepositoryNotification.number_of_failures < 3).get()
except RepositoryNotification.DoesNotExist:
raise InvalidNotificationException('No repository notification found with uuid: %s' % uuid)
def get_repo_notification(uuid):
try:
return _base_get_notification(uuid).get()
except RepositoryNotification.DoesNotExist:
raise InvalidNotificationException('No repository notification found with uuid: %s' % uuid)
def delete_repo_notification(namespace_name, repository_name, uuid):
found = get_repo_notification(uuid)
if found.repository.namespace_user.username != namespace_name or found.repository.name != repository_name:
raise InvalidNotificationException('No repository notification found with uuid: %s' % uuid)
found.delete_instance()
return found
def list_repo_notifications(namespace_name, repository_name, event_name=None):
query = (RepositoryNotification
.select(RepositoryNotification, Repository, Namespace)
.join(Repository)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.where(Namespace.username == namespace_name, Repository.name == repository_name))
if event_name:
query = (query
.switch(RepositoryNotification)
.join(ExternalNotificationEvent)
.where(ExternalNotificationEvent.name == event_name))
return query

434
data/model/oauth.py Normal file
View file

@ -0,0 +1,434 @@
import logging
import json
from flask import url_for
from datetime import datetime, timedelta
from oauth2lib.provider import AuthorizationProvider
from oauth2lib import utils
from active_migration import ActiveDataMigration, ERTMigrationFlags
from data.database import (OAuthApplication, OAuthAuthorizationCode, OAuthAccessToken, User,
random_string_generator)
from data.fields import DecryptedValue, Credential
from data.model import user, config
from auth import scopes
from util import get_app_url
logger = logging.getLogger(__name__)
ACCESS_TOKEN_PREFIX_LENGTH = 20
ACCESS_TOKEN_MINIMUM_CODE_LENGTH = 20
AUTHORIZATION_CODE_PREFIX_LENGTH = 20
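# Editorial illustration (not part of the original module): access tokens and authorization codes
# are split into a short public prefix, stored verbatim for lookup, and a private remainder that
# is persisted only as an encrypted/hashed credential. For a generated token string `tok`:
#
#   token_name = tok[:ACCESS_TOKEN_PREFIX_LENGTH]    # indexed and stored in the clear
#   token_code = tok[ACCESS_TOKEN_PREFIX_LENGTH:]    # stored via Credential.from_string()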
class DatabaseAuthorizationProvider(AuthorizationProvider):
def get_authorized_user(self):
raise NotImplementedError('Subclasses must fill in the ability to get the authorized_user.')
def _generate_data_string(self):
return json.dumps({'username': self.get_authorized_user().username})
@property
def token_expires_in(self):
"""Property method to get the token expiration time in seconds.
"""
return int(60*60*24*365.25*10) # 10 Years
def validate_client_id(self, client_id):
return self.get_application_for_client_id(client_id) is not None
def get_application_for_client_id(self, client_id):
try:
return OAuthApplication.get(client_id=client_id)
except OAuthApplication.DoesNotExist:
return None
def validate_client_secret(self, client_id, client_secret):
try:
application = OAuthApplication.get(client_id=client_id)
# TODO(remove-unenc): Remove legacy check.
if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
if application.secure_client_secret is None:
return application.client_secret == client_secret
assert application.secure_client_secret is not None
return application.secure_client_secret.matches(client_secret)
except OAuthApplication.DoesNotExist:
return False
def validate_redirect_uri(self, client_id, redirect_uri):
internal_redirect_url = '%s%s' % (get_app_url(config.app_config),
url_for('web.oauth_local_handler'))
if redirect_uri == internal_redirect_url:
return True
try:
oauth_app = OAuthApplication.get(client_id=client_id)
if (oauth_app.redirect_uri and redirect_uri and
redirect_uri.startswith(oauth_app.redirect_uri)):
return True
return False
except OAuthApplication.DoesNotExist:
return False
def validate_scope(self, client_id, scopes_string):
return scopes.validate_scope_string(scopes_string)
def validate_access(self):
return self.get_authorized_user() is not None
def load_authorized_scope_string(self, client_id, username):
found = (OAuthAccessToken
.select()
.join(OAuthApplication)
.switch(OAuthAccessToken)
.join(User)
.where(OAuthApplication.client_id == client_id, User.username == username,
OAuthAccessToken.expires_at > datetime.utcnow()))
found = list(found)
logger.debug('Found %s matching tokens.', len(found))
long_scope_string = ','.join([token.scope for token in found])
logger.debug('Computed long scope string: %s', long_scope_string)
return long_scope_string
def validate_has_scopes(self, client_id, username, scope):
long_scope_string = self.load_authorized_scope_string(client_id, username)
# Make sure the token contains the given scopes (at least).
return scopes.is_subset_string(long_scope_string, scope)
def from_authorization_code(self, client_id, full_code, scope):
code_name = full_code[:AUTHORIZATION_CODE_PREFIX_LENGTH]
code_credential = full_code[AUTHORIZATION_CODE_PREFIX_LENGTH:]
try:
found = (OAuthAuthorizationCode
.select()
.join(OAuthApplication)
.where(OAuthApplication.client_id == client_id,
OAuthAuthorizationCode.code_name == code_name,
OAuthAuthorizationCode.scope == scope)
.get())
if not found.code_credential.matches(code_credential):
return None
logger.debug('Returning data: %s', found.data)
return found.data
except OAuthAuthorizationCode.DoesNotExist:
# Fallback to the legacy lookup of the full code.
# TODO(remove-unenc): Remove legacy fallback.
if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
try:
found = (OAuthAuthorizationCode
.select()
.join(OAuthApplication)
.where(OAuthApplication.client_id == client_id,
OAuthAuthorizationCode.code == full_code,
OAuthAuthorizationCode.scope == scope)
.get())
logger.debug('Returning data: %s', found.data)
return found.data
except OAuthAuthorizationCode.DoesNotExist:
return None
else:
return None
def persist_authorization_code(self, client_id, full_code, scope):
oauth_app = OAuthApplication.get(client_id=client_id)
data = self._generate_data_string()
assert len(full_code) >= (AUTHORIZATION_CODE_PREFIX_LENGTH * 2)
code_name = full_code[:AUTHORIZATION_CODE_PREFIX_LENGTH]
code_credential = full_code[AUTHORIZATION_CODE_PREFIX_LENGTH:]
# TODO(remove-unenc): Remove legacy fallback.
full_code = None
if ActiveDataMigration.has_flag(ERTMigrationFlags.WRITE_OLD_FIELDS):
full_code = code_name + code_credential
OAuthAuthorizationCode.create(application=oauth_app,
code=full_code,
scope=scope,
code_name=code_name,
code_credential=Credential.from_string(code_credential),
data=data)
def persist_token_information(self, client_id, scope, access_token, token_type,
expires_in, refresh_token, data):
assert not refresh_token
found = user.get_user(json.loads(data)['username'])
if not found:
raise RuntimeError('Username must be in the data field')
token_name = access_token[:ACCESS_TOKEN_PREFIX_LENGTH]
token_code = access_token[ACCESS_TOKEN_PREFIX_LENGTH:]
assert token_name
assert token_code
assert len(token_name) == ACCESS_TOKEN_PREFIX_LENGTH
assert len(token_code) >= ACCESS_TOKEN_MINIMUM_CODE_LENGTH
oauth_app = OAuthApplication.get(client_id=client_id)
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
OAuthAccessToken.create(application=oauth_app,
authorized_user=found,
scope=scope,
token_name=token_name,
token_code=Credential.from_string(token_code),
access_token='',
token_type=token_type,
expires_at=expires_at,
data=data)
def get_auth_denied_response(self, response_type, client_id, redirect_uri, **params):
# Ensure proper response_type
if response_type != 'token':
err = 'unsupported_response_type'
return self._make_redirect_error_response(redirect_uri, err)
# Check redirect URI
is_valid_redirect_uri = self.validate_redirect_uri(client_id, redirect_uri)
if not is_valid_redirect_uri:
return self._invalid_redirect_uri_response()
return self._make_redirect_error_response(redirect_uri, 'authorization_denied')
def get_token_response(self, response_type, client_id, redirect_uri, **params):
# Ensure proper response_type
if response_type != 'token':
err = 'unsupported_response_type'
return self._make_redirect_error_response(redirect_uri, err)
# Check for a valid client ID.
is_valid_client_id = self.validate_client_id(client_id)
if not is_valid_client_id:
err = 'unauthorized_client'
return self._make_redirect_error_response(redirect_uri, err)
# Check for a valid redirect URI.
is_valid_redirect_uri = self.validate_redirect_uri(client_id, redirect_uri)
if not is_valid_redirect_uri:
return self._invalid_redirect_uri_response()
# Check conditions
is_valid_access = self.validate_access()
scope = params.get('scope', '')
are_valid_scopes = self.validate_scope(client_id, scope)
# Return proper error responses on invalid conditions
if not is_valid_access:
err = 'access_denied'
return self._make_redirect_error_response(redirect_uri, err)
if not are_valid_scopes:
err = 'invalid_scope'
return self._make_redirect_error_response(redirect_uri, err)
# Make sure we have enough random data in the token to have a public
# prefix and a private encrypted suffix.
access_token = str(self.generate_access_token())
assert len(access_token) - ACCESS_TOKEN_PREFIX_LENGTH >= 20
token_type = self.token_type
expires_in = self.token_expires_in
data = self._generate_data_string()
self.persist_token_information(client_id=client_id,
scope=scope,
access_token=access_token,
token_type=token_type,
expires_in=expires_in,
refresh_token=None,
data=data)
url = utils.build_url(redirect_uri, params)
url += '#access_token=%s&token_type=%s&expires_in=%s' % (access_token, token_type, expires_in)
return self._make_response(headers={'Location': url}, status_code=302)
def from_refresh_token(self, client_id, refresh_token, scope):
raise NotImplementedError()
def discard_authorization_code(self, client_id, full_code):
code_name = full_code[:AUTHORIZATION_CODE_PREFIX_LENGTH]
try:
found = (OAuthAuthorizationCode
.select()
.join(OAuthApplication)
.where(OAuthApplication.client_id == client_id,
OAuthAuthorizationCode.code_name == code_name)
.get())
found.delete_instance()
return
except OAuthAuthorizationCode.DoesNotExist:
pass
# Legacy: full code.
# TODO(remove-unenc): Remove legacy fallback.
if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
try:
found = (OAuthAuthorizationCode
.select()
.join(OAuthApplication)
.where(OAuthApplication.client_id == client_id,
OAuthAuthorizationCode.code == full_code)
.get())
found.delete_instance()
except OAuthAuthorizationCode.DoesNotExist:
pass
def discard_refresh_token(self, client_id, refresh_token):
raise NotImplementedError()
def create_application(org, name, application_uri, redirect_uri, **kwargs):
client_secret = kwargs.pop('client_secret', random_string_generator(length=40)())
# TODO(remove-unenc): Remove legacy field.
old_client_secret = None
if ActiveDataMigration.has_flag(ERTMigrationFlags.WRITE_OLD_FIELDS):
old_client_secret = client_secret
return OAuthApplication.create(organization=org,
name=name,
application_uri=application_uri,
redirect_uri=redirect_uri,
client_secret=old_client_secret,
secure_client_secret=DecryptedValue(client_secret),
**kwargs)
def validate_access_token(access_token):
assert isinstance(access_token, basestring)
token_name = access_token[:ACCESS_TOKEN_PREFIX_LENGTH]
if not token_name:
return None
token_code = access_token[ACCESS_TOKEN_PREFIX_LENGTH:]
if not token_code:
return None
try:
found = (OAuthAccessToken
.select(OAuthAccessToken, User)
.join(User)
.where(OAuthAccessToken.token_name == token_name)
.get())
if found.token_code is None or not found.token_code.matches(token_code):
return None
return found
except OAuthAccessToken.DoesNotExist:
pass
# Legacy lookup.
# TODO(remove-unenc): Remove this once migrated.
if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
try:
assert access_token
found = (OAuthAccessToken
.select(OAuthAccessToken, User)
.join(User)
.where(OAuthAccessToken.access_token == access_token)
.get())
return found
except OAuthAccessToken.DoesNotExist:
return None
return None
def get_application_for_client_id(client_id):
try:
return OAuthApplication.get(client_id=client_id)
except OAuthApplication.DoesNotExist:
return None
def reset_client_secret(application):
client_secret = random_string_generator(length=40)()
# TODO(remove-unenc): Remove legacy field.
if ActiveDataMigration.has_flag(ERTMigrationFlags.WRITE_OLD_FIELDS):
application.client_secret = client_secret
application.secure_client_secret = DecryptedValue(client_secret)
application.save()
return application
def lookup_application(org, client_id):
try:
return OAuthApplication.get(organization=org, client_id=client_id)
except OAuthApplication.DoesNotExist:
return None
def delete_application(org, client_id):
application = lookup_application(org, client_id)
if not application:
return
application.delete_instance(recursive=True, delete_nullable=True)
return application
def lookup_access_token_by_uuid(token_uuid):
try:
return OAuthAccessToken.get(OAuthAccessToken.uuid == token_uuid)
except OAuthAccessToken.DoesNotExist:
return None
def lookup_access_token_for_user(user_obj, token_uuid):
try:
return OAuthAccessToken.get(OAuthAccessToken.authorized_user == user_obj,
OAuthAccessToken.uuid == token_uuid)
except OAuthAccessToken.DoesNotExist:
return None
def list_access_tokens_for_user(user_obj):
query = (OAuthAccessToken
.select()
.join(OAuthApplication)
.switch(OAuthAccessToken)
.join(User)
.where(OAuthAccessToken.authorized_user == user_obj))
return query
def list_applications_for_org(org):
query = (OAuthApplication
.select()
.join(User)
.where(OAuthApplication.organization == org))
return query
def create_access_token_for_testing(user_obj, client_id, scope, access_token=None, expires_in=9000):
access_token = access_token or random_string_generator(length=40)()
token_name = access_token[:ACCESS_TOKEN_PREFIX_LENGTH]
token_code = access_token[ACCESS_TOKEN_PREFIX_LENGTH:]
assert len(token_name) == ACCESS_TOKEN_PREFIX_LENGTH
assert len(token_code) >= ACCESS_TOKEN_MINIMUM_CODE_LENGTH
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
application = get_application_for_client_id(client_id)
created = OAuthAccessToken.create(application=application,
authorized_user=user_obj,
scope=scope,
token_type='token',
access_token='',
token_code=Credential.from_string(token_code),
token_name=token_name,
expires_at=expires_at,
data='')
return created, access_token

9
data/model/oci/__init__.py Normal file
View file

@ -0,0 +1,9 @@
# There MUST NOT be any circular dependencies between these subsections. If there are fix it by
# moving the minimal number of things to shared
from data.model.oci import (
blob,
label,
manifest,
shared,
tag,
)

26
data/model/oci/blob.py Normal file
View file

@ -0,0 +1,26 @@
from data.database import ImageStorage, ManifestBlob
from data.model import BlobDoesNotExist
from data.model.storage import get_storage_by_uuid, InvalidImageException
from data.model.blob import get_repository_blob_by_digest as legacy_get
def get_repository_blob_by_digest(repository, blob_digest):
""" Find the content-addressable blob linked to the specified repository and
returns it or None if none.
"""
try:
storage = (ImageStorage
.select(ImageStorage.uuid)
.join(ManifestBlob)
.where(ManifestBlob.repository == repository,
ImageStorage.content_checksum == blob_digest,
ImageStorage.uploading == False)
.get())
return get_storage_by_uuid(storage.uuid)
except (ImageStorage.DoesNotExist, InvalidImageException):
# TODO: Remove once we are no longer using the legacy tables.
# Try the legacy call.
try:
return legacy_get(repository, blob_digest)
except BlobDoesNotExist:
return None

142
data/model/oci/label.py Normal file
View file

@ -0,0 +1,142 @@
import logging
from data.model import InvalidLabelKeyException, InvalidMediaTypeException, DataModelException
from data.database import (Label, Manifest, TagManifestLabel, MediaType, LabelSourceType,
db_transaction, ManifestLabel, TagManifestLabelMap,
TagManifestToManifest, Repository, TagManifest)
from data.text import prefix_search
from util.validation import validate_label_key
from util.validation import is_json
logger = logging.getLogger(__name__)
def list_manifest_labels(manifest_id, prefix_filter=None):
""" Lists all labels found on the given manifest, with an optional filter by key prefix. """
query = (Label
.select(Label, MediaType)
.join(MediaType)
.switch(Label)
.join(LabelSourceType)
.switch(Label)
.join(ManifestLabel)
.where(ManifestLabel.manifest == manifest_id))
if prefix_filter is not None:
query = query.where(prefix_search(Label.key, prefix_filter))
return query
def get_manifest_label(label_uuid, manifest):
""" Retrieves the manifest label on the manifest with the given UUID or None if none. """
try:
return (Label
.select(Label, LabelSourceType)
.join(LabelSourceType)
.where(Label.uuid == label_uuid)
.switch(Label)
.join(ManifestLabel)
.where(ManifestLabel.manifest == manifest)
.get())
except Label.DoesNotExist:
return None
def create_manifest_label(manifest_id, key, value, source_type_name, media_type_name=None,
adjust_old_model=True):
""" Creates a new manifest label on a specific tag manifest. """
if not key:
raise InvalidLabelKeyException()
# Note that we don't prevent invalid label keys coming from the manifest from being stored, as
# Docker does not currently prevent them from being put into said manifests.
if not validate_label_key(key) and source_type_name != 'manifest':
raise InvalidLabelKeyException('Key `%s` is invalid' % key)
# Find the matching media type. If none specified, we infer.
if media_type_name is None:
media_type_name = 'text/plain'
if is_json(value):
media_type_name = 'application/json'
try:
media_type_id = Label.media_type.get_id(media_type_name)
except MediaType.DoesNotExist:
raise InvalidMediaTypeException()
source_type_id = Label.source_type.get_id(source_type_name)
# Ensure the manifest exists.
try:
manifest = (Manifest
.select(Manifest, Repository)
.join(Repository)
.where(Manifest.id == manifest_id)
.get())
except Manifest.DoesNotExist:
return None
repository = manifest.repository
# TODO: Remove this code once the TagManifest table is gone.
tag_manifest = None
if adjust_old_model:
try:
mapping_row = (TagManifestToManifest
.select(TagManifestToManifest, TagManifest)
.join(TagManifest)
.where(TagManifestToManifest.manifest == manifest)
.get())
tag_manifest = mapping_row.tag_manifest
except TagManifestToManifest.DoesNotExist:
tag_manifest = None
with db_transaction():
label = Label.create(key=key, value=value, source_type=source_type_id, media_type=media_type_id)
manifest_label = ManifestLabel.create(manifest=manifest_id, label=label, repository=repository)
# If there exists a mapping to a TagManifest, add the old-style label.
# TODO: Remove this code once the TagManifest table is gone.
if tag_manifest:
tag_manifest_label = TagManifestLabel.create(annotated=tag_manifest, label=label,
repository=repository)
TagManifestLabelMap.create(manifest_label=manifest_label,
tag_manifest_label=tag_manifest_label,
label=label,
manifest=manifest,
tag_manifest=tag_manifest)
return label
def delete_manifest_label(label_uuid, manifest):
""" Deletes the manifest label on the tag manifest with the given ID. Returns the label deleted
or None if none.
"""
# Find the label itself.
label = get_manifest_label(label_uuid, manifest)
if label is None:
return None
if not label.source_type.mutable:
raise DataModelException('Cannot delete immutable label')
# Delete the mapping records and label.
# TODO: Remove this code once the TagManifest table is gone.
with db_transaction():
(TagManifestLabelMap
.delete()
.where(TagManifestLabelMap.label == label)
.execute())
deleted_count = TagManifestLabel.delete().where(TagManifestLabel.label == label).execute()
if deleted_count != 1:
logger.warning('Unexpected number of labels deleted for matching label %s', label_uuid)
deleted_count = ManifestLabel.delete().where(ManifestLabel.label == label).execute()
if deleted_count != 1:
logger.warning('Unexpected number of labels deleted for matching label %s', label_uuid)
label.delete_instance(recursive=False)
return label

321
data/model/oci/manifest.py Normal file
View file

@ -0,0 +1,321 @@
import logging
from collections import namedtuple
from peewee import IntegrityError
from data.database import (Tag, Manifest, ManifestBlob, ManifestLegacyImage, ManifestChild,
db_transaction)
from data.model import BlobDoesNotExist
from data.model.blob import get_or_create_shared_blob, get_shared_blob
from data.model.oci.tag import filter_to_alive_tags, create_temporary_tag_if_necessary
from data.model.oci.label import create_manifest_label
from data.model.oci.retriever import RepositoryContentRetriever
from data.model.storage import lookup_repo_storages_by_content_checksum
from data.model.image import lookup_repository_images, get_image, synthesize_v1_image
from image.docker.schema2 import EMPTY_LAYER_BLOB_DIGEST, EMPTY_LAYER_BYTES
from image.docker.schema1 import ManifestException
from image.docker.schema2.list import MalformedSchema2ManifestList
from util.validation import is_json
TEMP_TAG_EXPIRATION_SEC = 300 # 5 minutes
logger = logging.getLogger(__name__)
CreatedManifest = namedtuple('CreatedManifest', ['manifest', 'newly_created', 'labels_to_apply'])
class CreateManifestException(Exception):
""" Exception raised when creating a manifest fails and explicit exception
raising is requested. """
def lookup_manifest(repository_id, manifest_digest, allow_dead=False, require_available=False,
temp_tag_expiration_sec=TEMP_TAG_EXPIRATION_SEC):
""" Returns the manifest with the specified digest under the specified repository
or None if none. If allow_dead is True, then manifests referenced by only
dead tags will also be returned. If require_available is True, the manifest
will be marked with a temporary tag to ensure it remains available.
"""
if not require_available:
return _lookup_manifest(repository_id, manifest_digest, allow_dead=allow_dead)
with db_transaction():
found = _lookup_manifest(repository_id, manifest_digest, allow_dead=allow_dead)
if found is None:
return None
create_temporary_tag_if_necessary(found, temp_tag_expiration_sec)
return found
def _lookup_manifest(repository_id, manifest_digest, allow_dead=False):
query = (Manifest
.select()
.where(Manifest.repository == repository_id)
.where(Manifest.digest == manifest_digest))
if allow_dead:
try:
return query.get()
except Manifest.DoesNotExist:
return None
# First, try to filter to those manifests referenced by an alive tag.
try:
return filter_to_alive_tags(query.join(Tag)).get()
except Manifest.DoesNotExist:
pass
# Try referenced as the child of a manifest that has an alive tag.
query = (query
.join(ManifestChild, on=(ManifestChild.child_manifest == Manifest.id))
.join(Tag, on=(Tag.manifest == ManifestChild.manifest)))
query = filter_to_alive_tags(query)
try:
return query.get()
except Manifest.DoesNotExist:
return None
def get_or_create_manifest(repository_id, manifest_interface_instance, storage,
temp_tag_expiration_sec=TEMP_TAG_EXPIRATION_SEC,
for_tagging=False, raise_on_error=False):
""" Returns a CreatedManifest for the manifest in the specified repository with the matching
digest (if it already exists) or, if not yet created, creates and returns the manifest.
Returns None if there was an error creating the manifest, unless raise_on_error is specified,
in which case a CreateManifestException exception will be raised instead to provide more
context to the error.
Note that *all* blobs referenced by the manifest must exist already in the repository or this
method will fail and return None.
"""
existing = lookup_manifest(repository_id, manifest_interface_instance.digest, allow_dead=True,
require_available=True,
temp_tag_expiration_sec=temp_tag_expiration_sec)
if existing is not None:
return CreatedManifest(manifest=existing, newly_created=False, labels_to_apply=None)
return _create_manifest(repository_id, manifest_interface_instance, storage,
temp_tag_expiration_sec, for_tagging=for_tagging,
raise_on_error=raise_on_error)
def _create_manifest(repository_id, manifest_interface_instance, storage,
temp_tag_expiration_sec=TEMP_TAG_EXPIRATION_SEC,
for_tagging=False, raise_on_error=False):
# Validate the manifest.
retriever = RepositoryContentRetriever.for_repository(repository_id, storage)
try:
manifest_interface_instance.validate(retriever)
except (ManifestException, MalformedSchema2ManifestList, BlobDoesNotExist, IOError) as ex:
logger.exception('Could not validate manifest `%s`', manifest_interface_instance.digest)
if raise_on_error:
raise CreateManifestException(ex)
return None
# Load, parse and get/create the child manifests, if any.
child_manifest_refs = manifest_interface_instance.child_manifests(retriever)
child_manifest_rows = {}
child_manifest_label_dicts = []
if child_manifest_refs is not None:
for child_manifest_ref in child_manifest_refs:
# Load and parse the child manifest.
try:
child_manifest = child_manifest_ref.manifest_obj
except (ManifestException, MalformedSchema2ManifestList, BlobDoesNotExist, IOError) as ex:
logger.exception('Could not load manifest list for manifest `%s`',
manifest_interface_instance.digest)
if raise_on_error:
raise CreateManifestException(ex)
return None
# Retrieve its labels.
labels = child_manifest.get_manifest_labels(retriever)
if labels is None:
logger.exception('Could not load manifest labels for child manifest')
return None
# Get/create the child manifest in the database.
child_manifest_info = get_or_create_manifest(repository_id, child_manifest, storage,
raise_on_error=raise_on_error)
if child_manifest_info is None:
logger.error('Could not get/create child manifest')
return None
child_manifest_rows[child_manifest_info.manifest.digest] = child_manifest_info.manifest
child_manifest_label_dicts.append(labels)
# Ensure all the blobs in the manifest exist.
digests = set(manifest_interface_instance.local_blob_digests)
blob_map = {}
# If the special empty layer is required, simply load it directly. This is much faster
# than trying to load it on a per repository basis, and that is unnecessary anyway since
# this layer is predefined.
if EMPTY_LAYER_BLOB_DIGEST in digests:
digests.remove(EMPTY_LAYER_BLOB_DIGEST)
blob_map[EMPTY_LAYER_BLOB_DIGEST] = get_shared_blob(EMPTY_LAYER_BLOB_DIGEST)
if not blob_map[EMPTY_LAYER_BLOB_DIGEST]:
logger.warning('Could not find the special empty blob in storage')
return None
if digests:
query = lookup_repo_storages_by_content_checksum(repository_id, digests)
blob_map.update({s.content_checksum: s for s in query})
for digest_str in digests:
if digest_str not in blob_map:
logger.warning('Unknown blob `%s` under manifest `%s` for repository `%s`', digest_str,
manifest_interface_instance.digest, repository_id)
if raise_on_error:
raise CreateManifestException('Unknown blob `%s`' % digest_str)
return None
# Special check: If the empty layer blob is needed for this manifest, add it to the
# blob map. This is necessary because Docker decided to elide sending of this special
# empty layer in schema version 2, but we need to have it referenced for GC and schema version 1.
if EMPTY_LAYER_BLOB_DIGEST not in blob_map:
if manifest_interface_instance.get_requires_empty_layer_blob(retriever):
shared_blob = get_or_create_shared_blob(EMPTY_LAYER_BLOB_DIGEST, EMPTY_LAYER_BYTES, storage)
assert not shared_blob.uploading
assert shared_blob.content_checksum == EMPTY_LAYER_BLOB_DIGEST
blob_map[EMPTY_LAYER_BLOB_DIGEST] = shared_blob
# Determine and populate the legacy image if necessary. Manifest lists will not have a legacy
# image.
legacy_image = None
if manifest_interface_instance.has_legacy_image:
legacy_image_id = _populate_legacy_image(repository_id, manifest_interface_instance, blob_map,
retriever)
if legacy_image_id is None:
return None
legacy_image = get_image(repository_id, legacy_image_id)
if legacy_image is None:
return None
# Create the manifest and its blobs.
media_type = Manifest.media_type.get_id(manifest_interface_instance.media_type)
storage_ids = {storage.id for storage in blob_map.values()}
with db_transaction():
# Check for the manifest. This is necessary because Postgres doesn't handle IntegrityErrors
# well under transactions.
try:
manifest = Manifest.get(repository=repository_id, digest=manifest_interface_instance.digest)
return CreatedManifest(manifest=manifest, newly_created=False, labels_to_apply=None)
except Manifest.DoesNotExist:
pass
# Create the manifest.
try:
manifest = Manifest.create(repository=repository_id,
digest=manifest_interface_instance.digest,
media_type=media_type,
manifest_bytes=manifest_interface_instance.bytes.as_encoded_str())
except IntegrityError:
manifest = Manifest.get(repository=repository_id, digest=manifest_interface_instance.digest)
return CreatedManifest(manifest=manifest, newly_created=False, labels_to_apply=None)
# Insert the blobs.
blobs_to_insert = [dict(manifest=manifest, repository=repository_id,
blob=storage_id) for storage_id in storage_ids]
if blobs_to_insert:
ManifestBlob.insert_many(blobs_to_insert).execute()
# Set the legacy image (if applicable).
if legacy_image is not None:
ManifestLegacyImage.create(repository=repository_id, image=legacy_image, manifest=manifest)
# Insert the manifest child rows (if applicable).
if child_manifest_rows:
children_to_insert = [dict(manifest=manifest, child_manifest=child_manifest,
repository=repository_id)
for child_manifest in child_manifest_rows.values()]
ManifestChild.insert_many(children_to_insert).execute()
# If this manifest is being created not for immediate tagging, add a temporary tag to the
# manifest to ensure it isn't being GCed. If the manifest *is* for tagging, then since we're
# creating a new one here, it cannot be GCed (since it isn't referenced by anything yet), so
    # it's safe to elide the temp tag operation. If we ever change GC code to collect *all* manifests
# in a repository for GC, then we will have to reevaluate this optimization at that time.
if not for_tagging:
create_temporary_tag_if_necessary(manifest, temp_tag_expiration_sec)
# Define the labels for the manifest (if any).
labels = manifest_interface_instance.get_manifest_labels(retriever)
if labels:
for key, value in labels.iteritems():
media_type = 'application/json' if is_json(value) else 'text/plain'
create_manifest_label(manifest, key, value, 'manifest', media_type)
# Return the dictionary of labels to apply (i.e. those labels that cause an action to be taken
  # on the manifest or its resulting tags). We return only those labels that are either defined on
  # the manifest itself or shared amongst *all* of its child manifests; we intersect across the
  # child manifests to ensure that any action performed is defined in every one of them.
labels_to_apply = labels or {}
if child_manifest_label_dicts:
labels_to_apply = child_manifest_label_dicts[0].viewitems()
for child_manifest_label_dict in child_manifest_label_dicts[1:]:
# Intersect the key+values of the labels to ensure we get the exact same result
# for all the child manifests.
labels_to_apply = labels_to_apply & child_manifest_label_dict.viewitems()
labels_to_apply = dict(labels_to_apply)
return CreatedManifest(manifest=manifest, newly_created=True, labels_to_apply=labels_to_apply)
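# Worked example of the label intersection above (illustrative values): if the manifest itself
# defines no labels and its two child manifests carry {'a': '1', 'b': '2'} and {'a': '1', 'b': '3'},
# only {'a': '1'} survives the key+value intersection and is returned as labels_to_apply.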
def _populate_legacy_image(repository_id, manifest_interface_instance, blob_map, retriever):
# Lookup all the images and their parent images (if any) inside the manifest.
# This will let us know which v1 images we need to synthesize and which ones are invalid.
docker_image_ids = list(manifest_interface_instance.get_legacy_image_ids(retriever))
images_query = lookup_repository_images(repository_id, docker_image_ids)
image_storage_map = {i.docker_image_id: i.storage for i in images_query}
# Rewrite any v1 image IDs that do not match the checksum in the database.
try:
rewritten_images = manifest_interface_instance.generate_legacy_layers(image_storage_map,
retriever)
rewritten_images = list(rewritten_images)
parent_image_map = {}
for rewritten_image in rewritten_images:
      if rewritten_image.image_id not in image_storage_map:
parent_image = None
if rewritten_image.parent_image_id:
parent_image = parent_image_map.get(rewritten_image.parent_image_id)
if parent_image is None:
parent_image = get_image(repository_id, rewritten_image.parent_image_id)
if parent_image is None:
return None
storage_reference = blob_map[rewritten_image.content_checksum]
synthesized = synthesize_v1_image(
repository_id,
storage_reference.id,
storage_reference.image_size,
rewritten_image.image_id,
rewritten_image.created,
rewritten_image.comment,
rewritten_image.command,
rewritten_image.compat_json,
parent_image,
)
parent_image_map[rewritten_image.image_id] = synthesized
except ManifestException:
logger.exception("exception when rewriting v1 metadata")
return None
return rewritten_images[-1].image_id

37
data/model/oci/retriever.py Normal file
View file

@ -0,0 +1,37 @@
from image.docker.interfaces import ContentRetriever
from data.database import Manifest
from data.model.oci.blob import get_repository_blob_by_digest
from data.model.storage import get_layer_path
class RepositoryContentRetriever(ContentRetriever):
""" Implementation of the ContentRetriever interface for manifests that retrieves
config blobs and child manifests for the specified repository.
"""
def __init__(self, repository_id, storage):
self.repository_id = repository_id
self.storage = storage
@classmethod
def for_repository(cls, repository_id, storage):
return RepositoryContentRetriever(repository_id, storage)
def get_manifest_bytes_with_digest(self, digest):
""" Returns the bytes of the manifest with the given digest or None if none found. """
query = (Manifest
.select()
.where(Manifest.repository == self.repository_id)
.where(Manifest.digest == digest))
try:
return query.get().manifest_bytes
except Manifest.DoesNotExist:
return None
def get_blob_bytes_with_digest(self, digest):
""" Returns the bytes of the blob with the given digest or None if none found. """
blob = get_repository_blob_by_digest(self.repository_id, digest)
if blob is None:
return None
assert blob.locations is not None
return self.storage.get_content(blob.locations, get_layer_path(blob))
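# Illustrative sketch (not part of the original module; `repository_id`, `storage` and
# `config_digest` are assumed to be supplied by the caller): the retriever is typically handed to
# manifest-parsing code so that config blobs and child manifests can be fetched lazily by digest.
def _example_retriever_usage(repository_id, storage, config_digest):
  retriever = RepositoryContentRetriever.for_repository(repository_id, storage)
  return retriever.get_blob_bytes_with_digest(config_digest)  # None if the blob is absent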

24
data/model/oci/shared.py Normal file
View file

@ -0,0 +1,24 @@
from data.database import Manifest, ManifestLegacyImage, Image
def get_legacy_image_for_manifest(manifest_id):
""" Returns the legacy image associated with the given manifest, if any, or None if none. """
try:
query = (ManifestLegacyImage
.select(ManifestLegacyImage, Image)
.join(Image)
.where(ManifestLegacyImage.manifest == manifest_id))
return query.get().image
except ManifestLegacyImage.DoesNotExist:
return None
def get_manifest_for_legacy_image(image_id):
""" Returns a manifest that is associated with the given image, if any, or None if none. """
try:
query = (ManifestLegacyImage
.select(ManifestLegacyImage, Manifest)
.join(Manifest)
.where(ManifestLegacyImage.image == image_id))
return query.get().manifest
except ManifestLegacyImage.DoesNotExist:
return None
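# Illustrative sketch: the two lookups above are inverses over the ManifestLegacyImage linkage,
# so where both rows exist a round trip returns the manifest we started from.
def _example_roundtrip(manifest_id):
  image = get_legacy_image_for_manifest(manifest_id)
  if image is None:
    return None
  return get_manifest_for_legacy_image(image.id)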

505
data/model/oci/tag.py Normal file
View file

@ -0,0 +1,505 @@
import uuid
import logging
from calendar import timegm
from peewee import fn
from data.database import (Tag, Manifest, ManifestLegacyImage, Image, ImageStorage,
MediaType, RepositoryTag, TagManifest, TagManifestToManifest,
get_epoch_timestamp_ms, db_transaction, Repository,
TagToRepositoryTag, Namespace, RepositoryNotification,
ExternalNotificationEvent)
from data.model.oci.shared import get_legacy_image_for_manifest
from data.model import config
from image.docker.schema1 import (DOCKER_SCHEMA1_CONTENT_TYPES, DockerSchema1Manifest,
MalformedSchema1Manifest)
from util.bytes import Bytes
from util.timedeltastring import convert_to_timedelta
logger = logging.getLogger(__name__)
def get_tag_by_id(tag_id):
""" Returns the tag with the given ID, joined with its manifest or None if none. """
try:
return Tag.select(Tag, Manifest).join(Manifest).where(Tag.id == tag_id).get()
except Tag.DoesNotExist:
return None
def get_tag(repository_id, tag_name):
""" Returns the alive, non-hidden tag with the given name under the specified repository or
None if none. The tag is returned joined with its manifest.
"""
query = (Tag
.select(Tag, Manifest)
.join(Manifest)
.where(Tag.repository == repository_id)
.where(Tag.name == tag_name))
query = filter_to_alive_tags(query)
try:
found = query.get()
assert not found.hidden
return found
except Tag.DoesNotExist:
return None
def lookup_alive_tags_shallow(repository_id, start_pagination_id=None, limit=None):
""" Returns a list of the tags alive in the specified repository. Note that the tags returned
*only* contain their ID and name. Also note that the Tags are returned ordered by ID.
"""
query = (Tag
.select(Tag.id, Tag.name)
.where(Tag.repository == repository_id)
.order_by(Tag.id))
if start_pagination_id is not None:
query = query.where(Tag.id >= start_pagination_id)
if limit is not None:
query = query.limit(limit)
return filter_to_alive_tags(query)
def list_alive_tags(repository_id):
""" Returns a list of all the tags alive in the specified repository.
      Tags returned are joined with their manifests.
"""
query = (Tag
.select(Tag, Manifest)
.join(Manifest)
.where(Tag.repository == repository_id))
return filter_to_alive_tags(query)
def list_repository_tag_history(repository_id, page, page_size, specific_tag_name=None,
active_tags_only=False, since_time_ms=None):
""" Returns a tuple of the full set of tags found in the specified repository, including those
that are no longer alive (unless active_tags_only is True), and whether additional tags exist.
      If specific_tag_name is given, the tags are further filtered by name. If since_time_ms is given, tags
are further filtered to newer than that date.
Note that the returned Manifest will not contain the manifest contents.
"""
query = (Tag
.select(Tag, Manifest.id, Manifest.digest, Manifest.media_type)
.join(Manifest)
.where(Tag.repository == repository_id)
.order_by(Tag.lifetime_start_ms.desc(), Tag.name)
.limit(page_size + 1)
.offset(page_size * (page - 1)))
if specific_tag_name is not None:
query = query.where(Tag.name == specific_tag_name)
if since_time_ms is not None:
query = query.where((Tag.lifetime_start_ms > since_time_ms) | (Tag.lifetime_end_ms > since_time_ms))
if active_tags_only:
query = filter_to_alive_tags(query)
query = filter_to_visible_tags(query)
results = list(query)
return results[0:page_size], len(results) > page_size
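# Illustrative pagination sketch (repository_id and page_size are assumed inputs): the query above
# fetches page_size + 1 rows and reports whether more pages remain, so callers can loop on the flag.
def _example_iterate_tag_history(repository_id, page_size=100):
  page = 1
  while True:
    results, has_more = list_repository_tag_history(repository_id, page, page_size)
    for tag in results:
      yield tag
    if not has_more:
      return
    page += 1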
def get_legacy_images_for_tags(tags):
""" Returns a map from tag ID to the legacy image for the tag. """
if not tags:
return {}
query = (ManifestLegacyImage
.select(ManifestLegacyImage, Image, ImageStorage)
.join(Image)
.join(ImageStorage)
.where(ManifestLegacyImage.manifest << [tag.manifest_id for tag in tags]))
by_manifest = {mli.manifest_id: mli.image for mli in query}
return {tag.id: by_manifest[tag.manifest_id] for tag in tags if tag.manifest_id in by_manifest}
def find_matching_tag(repository_id, tag_names, tag_kinds=None):
""" Finds an alive tag in the specified repository with one of the specified tag names and
      returns it or None if none. Tags returned are joined with their manifests.
"""
assert repository_id
assert tag_names
query = (Tag
.select(Tag, Manifest)
.join(Manifest)
.where(Tag.repository == repository_id)
.where(Tag.name << tag_names))
if tag_kinds:
query = query.where(Tag.tag_kind << tag_kinds)
try:
found = filter_to_alive_tags(query).get()
assert not found.hidden
return found
except Tag.DoesNotExist:
return None
def get_most_recent_tag_lifetime_start(repository_ids):
""" Returns a map from repo ID to the timestamp of the most recently pushed alive tag
      for each of the specified repositories. Repositories with no alive tags are omitted.
"""
assert len(repository_ids) > 0 and None not in repository_ids
query = (Tag.select(Tag.repository, fn.Max(Tag.lifetime_start_ms))
.where(Tag.repository << [repo_id for repo_id in repository_ids])
.group_by(Tag.repository))
tuples = filter_to_alive_tags(query).tuples()
return {repo_id: timestamp for repo_id, timestamp in tuples}
def get_most_recent_tag(repository_id):
""" Returns the most recently pushed alive tag in the specified repository or None if none.
The Tag returned is joined with its manifest.
"""
assert repository_id
query = (Tag
.select(Tag, Manifest)
.join(Manifest)
.where(Tag.repository == repository_id)
.order_by(Tag.lifetime_start_ms.desc()))
try:
found = filter_to_alive_tags(query).get()
assert not found.hidden
return found
except Tag.DoesNotExist:
return None
def get_expired_tag(repository_id, tag_name):
""" Returns a tag with the given name that is expired in the repository or None if none.
"""
try:
return (Tag
.select()
.where(Tag.name == tag_name, Tag.repository == repository_id)
.where(~(Tag.lifetime_end_ms >> None))
.where(Tag.lifetime_end_ms <= get_epoch_timestamp_ms())
.get())
except Tag.DoesNotExist:
return None
def create_temporary_tag_if_necessary(manifest, expiration_sec):
""" Creates a temporary tag pointing to the given manifest, with the given expiration in seconds,
unless there is an existing tag that will keep the manifest around.
"""
tag_name = '$temp-%s' % str(uuid.uuid4())
now_ms = get_epoch_timestamp_ms()
end_ms = now_ms + (expiration_sec * 1000)
# Check if there is an existing tag on the manifest that won't expire within the
# timeframe. If so, no need for a temporary tag.
with db_transaction():
try:
(Tag
.select()
.where(Tag.manifest == manifest,
(Tag.lifetime_end_ms >> None) | (Tag.lifetime_end_ms >= end_ms))
.get())
return None
except Tag.DoesNotExist:
pass
return Tag.create(name=tag_name,
repository=manifest.repository_id,
lifetime_start_ms=now_ms,
lifetime_end_ms=end_ms,
reversion=False,
hidden=True,
manifest=manifest,
tag_kind=Tag.tag_kind.get_id('tag'))
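# Worked example of the expiration arithmetic above (illustrative values): with expiration_sec = 300
# and now_ms = 1500000000000, the hidden tag receives lifetime_end_ms = 1500000300000
# (now_ms + 300 * 1000), unless an existing tag already keeps the manifest alive at least that long.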
def retarget_tag(tag_name, manifest_id, is_reversion=False, now_ms=None, adjust_old_model=True):
""" Creates or updates a tag with the specified name to point to the given manifest under
its repository. If this action is a reversion to a previous manifest, is_reversion
should be set to True. Returns the newly created tag row or None on error.
"""
try:
manifest = (Manifest
.select(Manifest, MediaType)
.join(MediaType)
.where(Manifest.id == manifest_id)
.get())
except Manifest.DoesNotExist:
return None
# CHECK: Make sure that we are not mistargeting a schema 1 manifest to a tag with a different
# name.
if manifest.media_type.name in DOCKER_SCHEMA1_CONTENT_TYPES:
try:
parsed = DockerSchema1Manifest(Bytes.for_string_or_unicode(manifest.manifest_bytes),
validate=False)
if parsed.tag != tag_name:
        logger.error('Tried to re-target schema1 manifest with tag `%s` to tag `%s`', parsed.tag,
tag_name)
return None
except MalformedSchema1Manifest:
logger.exception('Could not parse schema1 manifest')
return None
legacy_image = get_legacy_image_for_manifest(manifest)
now_ms = now_ms or get_epoch_timestamp_ms()
now_ts = int(now_ms / 1000)
with db_transaction():
# Lookup an existing tag in the repository with the same name and, if present, mark it
# as expired.
existing_tag = get_tag(manifest.repository_id, tag_name)
if existing_tag is not None:
_, okay = set_tag_end_ms(existing_tag, now_ms)
# TODO: should we retry here and/or use a for-update?
if not okay:
return None
# Create a new tag pointing to the manifest with a lifetime start of now.
created = Tag.create(name=tag_name, repository=manifest.repository_id, lifetime_start_ms=now_ms,
reversion=is_reversion, manifest=manifest,
tag_kind=Tag.tag_kind.get_id('tag'))
# TODO: Remove the linkage code once RepositoryTag is gone.
# If this is a schema 1 manifest, then add a TagManifest linkage to it. Otherwise, it will only
# be pullable via the new OCI model.
if adjust_old_model:
if manifest.media_type.name in DOCKER_SCHEMA1_CONTENT_TYPES and legacy_image is not None:
old_style_tag = RepositoryTag.create(repository=manifest.repository_id, image=legacy_image,
name=tag_name, lifetime_start_ts=now_ts,
reversion=is_reversion)
TagToRepositoryTag.create(tag=created, repository_tag=old_style_tag,
repository=manifest.repository_id)
tag_manifest = TagManifest.create(tag=old_style_tag, digest=manifest.digest,
json_data=manifest.manifest_bytes)
TagManifestToManifest.create(tag_manifest=tag_manifest, manifest=manifest,
repository=manifest.repository_id)
return created
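# Illustrative reversion sketch (previous_manifest_id is an assumed input, e.g. taken from
# list_repository_tag_history): pointing `latest` back at an older manifest records a reversion.
def _example_revert_latest(previous_manifest_id):
  reverted = retarget_tag('latest', previous_manifest_id, is_reversion=True)
  return reverted  # None if the manifest is missing or a schema1 tag name mismatches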
def delete_tag(repository_id, tag_name):
""" Deletes the alive tag with the given name in the specified repository and returns the deleted
tag. If the tag did not exist, returns None.
"""
tag = get_tag(repository_id, tag_name)
if tag is None:
return None
return _delete_tag(tag, get_epoch_timestamp_ms())
def _delete_tag(tag, now_ms):
""" Deletes the given tag by marking it as expired. """
now_ts = int(now_ms / 1000)
with db_transaction():
updated = (Tag
.update(lifetime_end_ms=now_ms)
.where(Tag.id == tag.id, Tag.lifetime_end_ms == tag.lifetime_end_ms)
.execute())
if updated != 1:
return None
# TODO: Remove the linkage code once RepositoryTag is gone.
try:
old_style_tag = (TagToRepositoryTag
.select(TagToRepositoryTag, RepositoryTag)
.join(RepositoryTag)
.where(TagToRepositoryTag.tag == tag)
.get()).repository_tag
old_style_tag.lifetime_end_ts = now_ts
old_style_tag.save()
except TagToRepositoryTag.DoesNotExist:
pass
return tag
def delete_tags_for_manifest(manifest):
""" Deletes all tags pointing to the given manifest. Returns the list of tags
deleted.
"""
query = Tag.select().where(Tag.manifest == manifest)
query = filter_to_alive_tags(query)
query = filter_to_visible_tags(query)
tags = list(query)
now_ms = get_epoch_timestamp_ms()
with db_transaction():
for tag in tags:
_delete_tag(tag, now_ms)
return tags
def filter_to_visible_tags(query):
""" Adjusts the specified Tag query to only return those tags that are visible.
"""
return query.where(Tag.hidden == False)
def filter_to_alive_tags(query, now_ms=None, model=Tag):
""" Adjusts the specified Tag query to only return those tags alive. If now_ms is specified,
      the given timestamp (in MS) is used in place of the current timestamp for determining whether
a tag is alive.
"""
if now_ms is None:
now_ms = get_epoch_timestamp_ms()
return (query.where((model.lifetime_end_ms >> None) | (model.lifetime_end_ms > now_ms))
.where(model.hidden == False))
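# Illustrative sketch of the liveness predicate applied above: a tag is alive when it is not
# hidden and either has no end timestamp or ends strictly after the reference time.
def _example_is_alive(tag, now_ms):
  return not tag.hidden and (tag.lifetime_end_ms is None or tag.lifetime_end_ms > now_ms)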
def set_tag_expiration_sec_for_manifest(manifest_id, expiration_seconds):
""" Sets the tag expiration for any tags that point to the given manifest ID. """
query = Tag.select().where(Tag.manifest == manifest_id)
query = filter_to_alive_tags(query)
tags = list(query)
for tag in tags:
assert not tag.hidden
set_tag_end_ms(tag, tag.lifetime_start_ms + (expiration_seconds * 1000))
return tags
def set_tag_expiration_for_manifest(manifest_id, expiration_datetime):
""" Sets the tag expiration for any tags that point to the given manifest ID. """
query = Tag.select().where(Tag.manifest == manifest_id)
query = filter_to_alive_tags(query)
tags = list(query)
for tag in tags:
assert not tag.hidden
change_tag_expiration(tag, expiration_datetime)
return tags
def change_tag_expiration(tag_id, expiration_datetime):
""" Changes the expiration of the specified tag to the given expiration datetime. If
the expiration datetime is None, then the tag is marked as not expiring. Returns
      a tuple of the previous expiration timestamp in milliseconds (if any), and whether the
operation succeeded.
"""
try:
tag = Tag.get(id=tag_id)
except Tag.DoesNotExist:
return (None, False)
new_end_ms = None
min_expire_sec = convert_to_timedelta(config.app_config.get('LABELED_EXPIRATION_MINIMUM', '1h'))
max_expire_sec = convert_to_timedelta(config.app_config.get('LABELED_EXPIRATION_MAXIMUM', '104w'))
if expiration_datetime is not None:
lifetime_start_ts = int(tag.lifetime_start_ms / 1000)
offset = timegm(expiration_datetime.utctimetuple()) - lifetime_start_ts
offset = min(max(offset, min_expire_sec.total_seconds()), max_expire_sec.total_seconds())
new_end_ms = tag.lifetime_start_ms + (offset * 1000)
if new_end_ms == tag.lifetime_end_ms:
return (None, True)
return set_tag_end_ms(tag, new_end_ms)
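# Worked example of the clamping above (using the default '1h'/'104w' bounds): a requested
# expiration one second after lifetime_start is raised to a 3600-second offset, while a request
# 300 weeks out is lowered to a 104-week offset; either way
# new_end_ms = lifetime_start_ms + offset * 1000.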
def lookup_unrecoverable_tags(repo):
""" Returns the tags in a repository that are expired and past their time machine recovery
period. """
expired_clause = get_epoch_timestamp_ms() - (Namespace.removed_tag_expiration_s * 1000)
return (Tag
.select()
.join(Repository)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.where(Tag.repository == repo)
.where(~(Tag.lifetime_end_ms >> None), Tag.lifetime_end_ms <= expired_clause))
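# Worked example (illustrative values): with Namespace.removed_tag_expiration_s = 1209600
# (two weeks), a tag whose lifetime_end_ms is older than now_ms - 1209600 * 1000 falls outside
# the time machine window and is reported as unrecoverable.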
def set_tag_end_ms(tag, end_ms):
""" Sets the end timestamp for a tag. Should only be called by change_tag_expiration
or tests.
"""
with db_transaction():
updated = (Tag
.update(lifetime_end_ms=end_ms)
.where(Tag.id == tag)
.where(Tag.lifetime_end_ms == tag.lifetime_end_ms)
.execute())
if updated != 1:
return (None, False)
# TODO: Remove the linkage code once RepositoryTag is gone.
try:
old_style_tag = (TagToRepositoryTag
.select(TagToRepositoryTag, RepositoryTag)
.join(RepositoryTag)
.where(TagToRepositoryTag.tag == tag)
.get()).repository_tag
old_style_tag.lifetime_end_ts = end_ms / 1000 if end_ms is not None else None
old_style_tag.save()
except TagToRepositoryTag.DoesNotExist:
pass
return (tag.lifetime_end_ms, True)
def tags_containing_legacy_image(image):
""" Yields all alive Tags containing the given image as a legacy image, somewhere in its
legacy image hierarchy.
"""
ancestors_str = '%s%s/%%' % (image.ancestors, image.id)
tags = (Tag
.select()
.join(Repository)
.switch(Tag)
.join(Manifest)
.join(ManifestLegacyImage)
.join(Image)
.where(Tag.repository == image.repository_id)
.where(Image.repository == image.repository_id)
.where((Image.id == image.id) |
(Image.ancestors ** ancestors_str)))
return filter_to_alive_tags(tags)
def lookup_notifiable_tags_for_legacy_image(docker_image_id, storage_uuid, event_name):
""" Yields any alive Tags found in repositories with an event with the given name registered
and whose legacy Image has the given docker image ID and storage UUID.
"""
event = ExternalNotificationEvent.get(name=event_name)
images = (Image
.select()
.join(ImageStorage)
.where(Image.docker_image_id == docker_image_id,
ImageStorage.uuid == storage_uuid))
for image in list(images):
# Ensure the image is under a repository that supports the event.
try:
RepositoryNotification.get(repository=image.repository_id, event=event)
except RepositoryNotification.DoesNotExist:
continue
# If found in a repository with the valid event, yield the tag(s) that contains the image.
for tag in tags_containing_legacy_image(image):
yield tag


View file

@ -0,0 +1,87 @@
import pytest
from playhouse.test_utils import assert_query_count
from data.database import Manifest, ManifestLabel
from data.model.oci.label import (create_manifest_label, list_manifest_labels, get_manifest_label,
delete_manifest_label, DataModelException)
from test.fixtures import *
@pytest.mark.parametrize('key, value, source_type, expected_error', [
('foo', 'bar', 'manifest', None),
pytest.param('..foo', 'bar', 'manifest', None, id='invalid key on manifest'),
pytest.param('..foo', 'bar', 'api', 'is invalid', id='invalid key on api'),
])
def test_create_manifest_label(key, value, source_type, expected_error, initialized_db):
manifest = Manifest.get()
if expected_error:
with pytest.raises(DataModelException) as ex:
create_manifest_label(manifest, key, value, source_type)
assert ex.match(expected_error)
return
label = create_manifest_label(manifest, key, value, source_type)
labels = [ml.label_id for ml in ManifestLabel.select().where(ManifestLabel.manifest == manifest)]
assert label.id in labels
with assert_query_count(1):
assert label in list_manifest_labels(manifest)
assert label not in list_manifest_labels(manifest, 'someprefix')
assert label in list_manifest_labels(manifest, key[0:2])
with assert_query_count(1):
assert get_manifest_label(label.uuid, manifest) == label
def test_list_manifest_labels(initialized_db):
manifest = Manifest.get()
label1 = create_manifest_label(manifest, 'foo', '1', 'manifest')
label2 = create_manifest_label(manifest, 'bar', '2', 'api')
label3 = create_manifest_label(manifest, 'baz', '3', 'internal')
assert label1 in list_manifest_labels(manifest)
assert label2 in list_manifest_labels(manifest)
assert label3 in list_manifest_labels(manifest)
other_manifest = Manifest.select().where(Manifest.id != manifest.id).get()
assert label1 not in list_manifest_labels(other_manifest)
assert label2 not in list_manifest_labels(other_manifest)
assert label3 not in list_manifest_labels(other_manifest)
def test_get_manifest_label(initialized_db):
found = False
for manifest_label in ManifestLabel.select():
assert (get_manifest_label(manifest_label.label.uuid, manifest_label.manifest) ==
manifest_label.label)
assert manifest_label.label in list_manifest_labels(manifest_label.manifest)
found = True
assert found
def test_delete_manifest_label(initialized_db):
found = False
for manifest_label in list(ManifestLabel.select()):
assert (get_manifest_label(manifest_label.label.uuid, manifest_label.manifest) ==
manifest_label.label)
assert manifest_label.label in list_manifest_labels(manifest_label.manifest)
if manifest_label.label.source_type.mutable:
assert delete_manifest_label(manifest_label.label.uuid, manifest_label.manifest)
assert manifest_label.label not in list_manifest_labels(manifest_label.manifest)
assert get_manifest_label(manifest_label.label.uuid, manifest_label.manifest) is None
else:
with pytest.raises(DataModelException):
delete_manifest_label(manifest_label.label.uuid, manifest_label.manifest)
found = True
assert found

View file

@ -0,0 +1,560 @@
import json
import pytest
from playhouse.test_utils import assert_query_count
from app import docker_v2_signing_key, storage
from digest.digest_tools import sha256_digest
from data.database import (Tag, ManifestBlob, ImageStorageLocation, ManifestChild,
ImageStorage, Image, RepositoryTag, get_epoch_timestamp_ms)
from data.model.oci.manifest import lookup_manifest, get_or_create_manifest
from data.model.oci.tag import filter_to_alive_tags, get_tag
from data.model.oci.shared import get_legacy_image_for_manifest
from data.model.oci.label import list_manifest_labels
from data.model.oci.retriever import RepositoryContentRetriever
from data.model.repository import get_repository, create_repository
from data.model.image import find_create_or_link_image
from data.model.blob import store_blob_record_and_temp_link
from data.model.storage import get_layer_path
from image.docker.schema1 import DockerSchema1ManifestBuilder, DockerSchema1Manifest
from image.docker.schema2.manifest import DockerSchema2ManifestBuilder
from image.docker.schema2.list import DockerSchema2ManifestListBuilder
from util.bytes import Bytes
from test.fixtures import *
def test_lookup_manifest(initialized_db):
found = False
for tag in filter_to_alive_tags(Tag.select()):
found = True
repo = tag.repository
digest = tag.manifest.digest
with assert_query_count(1):
assert lookup_manifest(repo, digest) == tag.manifest
assert found
for tag in Tag.select():
repo = tag.repository
digest = tag.manifest.digest
with assert_query_count(1):
assert lookup_manifest(repo, digest, allow_dead=True) == tag.manifest
def test_lookup_manifest_dead_tag(initialized_db):
dead_tag = Tag.select().where(Tag.lifetime_end_ms <= get_epoch_timestamp_ms()).get()
assert dead_tag.lifetime_end_ms <= get_epoch_timestamp_ms()
assert lookup_manifest(dead_tag.repository, dead_tag.manifest.digest) is None
assert (lookup_manifest(dead_tag.repository, dead_tag.manifest.digest, allow_dead=True) ==
dead_tag.manifest)
def create_manifest_for_testing(repository, differentiation_field='1'):
# Populate a manifest.
layer_json = json.dumps({
'config': {},
"rootfs": {
"type": "layers",
"diff_ids": []
},
"history": [],
})
# Add a blob containing the config.
_, config_digest = _populate_blob(layer_json)
remote_digest = sha256_digest('something')
builder = DockerSchema2ManifestBuilder()
builder.set_config_digest(config_digest, len(layer_json))
builder.add_layer(remote_digest, 1234, urls=['http://hello/world' + differentiation_field])
manifest = builder.build()
created = get_or_create_manifest(repository, manifest, storage)
assert created
return created.manifest, manifest
def test_lookup_manifest_child_tag(initialized_db):
repository = create_repository('devtable', 'newrepo', None)
manifest, manifest_impl = create_manifest_for_testing(repository)
# Mark the hidden tag as dead.
hidden_tag = Tag.get(manifest=manifest, hidden=True)
hidden_tag.lifetime_end_ms = hidden_tag.lifetime_start_ms
hidden_tag.save()
# Ensure the manifest cannot currently be looked up, as it is not pointed to by an alive tag.
assert lookup_manifest(repository, manifest.digest) is None
assert lookup_manifest(repository, manifest.digest, allow_dead=True) is not None
# Populate a manifest list.
list_builder = DockerSchema2ManifestListBuilder()
list_builder.add_manifest(manifest_impl, 'amd64', 'linux')
manifest_list = list_builder.build()
# Write the manifest list, which should also write the manifests themselves.
created_tuple = get_or_create_manifest(repository, manifest_list, storage)
assert created_tuple is not None
# Since the manifests are not yet referenced by a tag, they cannot be found.
assert lookup_manifest(repository, manifest.digest) is None
assert lookup_manifest(repository, manifest_list.digest) is None
# Unless we ask for "dead" manifests.
assert lookup_manifest(repository, manifest.digest, allow_dead=True) is not None
assert lookup_manifest(repository, manifest_list.digest, allow_dead=True) is not None
def _populate_blob(content):
digest = str(sha256_digest(content))
location = ImageStorageLocation.get(name='local_us')
blob = store_blob_record_and_temp_link('devtable', 'newrepo', digest, location,
len(content), 120)
storage.put_content(['local_us'], get_layer_path(blob), content)
return blob, digest
@pytest.mark.parametrize('schema_version', [
1,
2,
])
def test_get_or_create_manifest(schema_version, initialized_db):
repository = create_repository('devtable', 'newrepo', None)
expected_labels = {
'Foo': 'Bar',
'Baz': 'Meh',
}
layer_json = json.dumps({
'id': 'somelegacyid',
'config': {
'Labels': expected_labels,
},
"rootfs": {
"type": "layers",
"diff_ids": []
},
"history": [
{
"created": "2018-04-03T18:37:09.284840891Z",
"created_by": "do something",
},
],
})
# Create a legacy image.
find_create_or_link_image('somelegacyid', repository, 'devtable', {}, 'local_us')
# Add a blob containing the config.
_, config_digest = _populate_blob(layer_json)
# Add a blob of random data.
random_data = 'hello world'
_, random_digest = _populate_blob(random_data)
# Build the manifest.
if schema_version == 1:
builder = DockerSchema1ManifestBuilder('devtable', 'simple', 'anothertag')
builder.add_layer(random_digest, layer_json)
sample_manifest_instance = builder.build(docker_v2_signing_key)
elif schema_version == 2:
builder = DockerSchema2ManifestBuilder()
builder.set_config_digest(config_digest, len(layer_json))
builder.add_layer(random_digest, len(random_data))
sample_manifest_instance = builder.build()
# Create a new manifest.
created_manifest = get_or_create_manifest(repository, sample_manifest_instance, storage)
created = created_manifest.manifest
newly_created = created_manifest.newly_created
assert newly_created
assert created is not None
assert created.media_type.name == sample_manifest_instance.media_type
assert created.digest == sample_manifest_instance.digest
assert created.manifest_bytes == sample_manifest_instance.bytes.as_encoded_str()
assert created_manifest.labels_to_apply == expected_labels
# Verify it has a temporary tag pointing to it.
assert Tag.get(manifest=created, hidden=True).lifetime_end_ms
# Verify the legacy image.
legacy_image = get_legacy_image_for_manifest(created)
assert legacy_image is not None
assert legacy_image.storage.content_checksum == random_digest
# Verify the linked blobs.
blob_digests = [mb.blob.content_checksum for mb
in ManifestBlob.select().where(ManifestBlob.manifest == created)]
assert random_digest in blob_digests
if schema_version == 2:
assert config_digest in blob_digests
# Retrieve it again and ensure it is the same manifest.
created_manifest2 = get_or_create_manifest(repository, sample_manifest_instance, storage)
created2 = created_manifest2.manifest
newly_created2 = created_manifest2.newly_created
assert not newly_created2
assert created2 == created
# Ensure it again has a temporary tag.
assert Tag.get(manifest=created2, hidden=True).lifetime_end_ms
# Ensure the labels were added.
labels = list(list_manifest_labels(created))
assert len(labels) == 2
labels_dict = {label.key: label.value for label in labels}
assert labels_dict == expected_labels
def test_get_or_create_manifest_invalid_image(initialized_db):
repository = get_repository('devtable', 'simple')
latest_tag = get_tag(repository, 'latest')
parsed = DockerSchema1Manifest(Bytes.for_string_or_unicode(latest_tag.manifest.manifest_bytes),
validate=False)
builder = DockerSchema1ManifestBuilder('devtable', 'simple', 'anothertag')
builder.add_layer(parsed.blob_digests[0], '{"id": "foo", "parent": "someinvalidimageid"}')
sample_manifest_instance = builder.build(docker_v2_signing_key)
created_manifest = get_or_create_manifest(repository, sample_manifest_instance, storage)
assert created_manifest is None
def test_get_or_create_manifest_list(initialized_db):
repository = create_repository('devtable', 'newrepo', None)
expected_labels = {
'Foo': 'Bar',
'Baz': 'Meh',
}
layer_json = json.dumps({
'id': 'somelegacyid',
'config': {
'Labels': expected_labels,
},
"rootfs": {
"type": "layers",
"diff_ids": []
},
"history": [
{
"created": "2018-04-03T18:37:09.284840891Z",
"created_by": "do something",
},
],
})
# Create a legacy image.
find_create_or_link_image('somelegacyid', repository, 'devtable', {}, 'local_us')
# Add a blob containing the config.
_, config_digest = _populate_blob(layer_json)
# Add a blob of random data.
random_data = 'hello world'
_, random_digest = _populate_blob(random_data)
# Build the manifests.
v1_builder = DockerSchema1ManifestBuilder('devtable', 'simple', 'anothertag')
v1_builder.add_layer(random_digest, layer_json)
v1_manifest = v1_builder.build(docker_v2_signing_key).unsigned()
v2_builder = DockerSchema2ManifestBuilder()
v2_builder.set_config_digest(config_digest, len(layer_json))
v2_builder.add_layer(random_digest, len(random_data))
v2_manifest = v2_builder.build()
# Write the manifests.
v1_created = get_or_create_manifest(repository, v1_manifest, storage)
assert v1_created
assert v1_created.manifest.digest == v1_manifest.digest
v2_created = get_or_create_manifest(repository, v2_manifest, storage)
assert v2_created
assert v2_created.manifest.digest == v2_manifest.digest
# Build the manifest list.
list_builder = DockerSchema2ManifestListBuilder()
list_builder.add_manifest(v1_manifest, 'amd64', 'linux')
list_builder.add_manifest(v2_manifest, 'amd32', 'linux')
manifest_list = list_builder.build()
# Write the manifest list, which should also write the manifests themselves.
created_tuple = get_or_create_manifest(repository, manifest_list, storage)
assert created_tuple is not None
created_list = created_tuple.manifest
assert created_list
assert created_list.media_type.name == manifest_list.media_type
assert created_list.digest == manifest_list.digest
# Ensure the child manifest links exist.
child_manifests = {cm.child_manifest.digest: cm.child_manifest
for cm in ManifestChild.select().where(ManifestChild.manifest == created_list)}
assert len(child_manifests) == 2
assert v1_manifest.digest in child_manifests
assert v2_manifest.digest in child_manifests
assert child_manifests[v1_manifest.digest].media_type.name == v1_manifest.media_type
assert child_manifests[v2_manifest.digest].media_type.name == v2_manifest.media_type
def test_get_or_create_manifest_list_duplicate_child_manifest(initialized_db):
repository = create_repository('devtable', 'newrepo', None)
expected_labels = {
'Foo': 'Bar',
'Baz': 'Meh',
}
layer_json = json.dumps({
'id': 'somelegacyid',
'config': {
'Labels': expected_labels,
},
"rootfs": {
"type": "layers",
"diff_ids": []
},
"history": [
{
"created": "2018-04-03T18:37:09.284840891Z",
"created_by": "do something",
},
],
})
# Create a legacy image.
find_create_or_link_image('somelegacyid', repository, 'devtable', {}, 'local_us')
# Add a blob containing the config.
_, config_digest = _populate_blob(layer_json)
# Add a blob of random data.
random_data = 'hello world'
_, random_digest = _populate_blob(random_data)
# Build the manifest.
v2_builder = DockerSchema2ManifestBuilder()
v2_builder.set_config_digest(config_digest, len(layer_json))
v2_builder.add_layer(random_digest, len(random_data))
v2_manifest = v2_builder.build()
# Write the manifest.
v2_created = get_or_create_manifest(repository, v2_manifest, storage)
assert v2_created
assert v2_created.manifest.digest == v2_manifest.digest
# Build the manifest list, with the child manifest repeated.
list_builder = DockerSchema2ManifestListBuilder()
list_builder.add_manifest(v2_manifest, 'amd64', 'linux')
list_builder.add_manifest(v2_manifest, 'amd32', 'linux')
manifest_list = list_builder.build()
# Write the manifest list, which should also write the manifests themselves.
created_tuple = get_or_create_manifest(repository, manifest_list, storage)
assert created_tuple is not None
created_list = created_tuple.manifest
assert created_list
assert created_list.media_type.name == manifest_list.media_type
assert created_list.digest == manifest_list.digest
# Ensure the child manifest links exist.
child_manifests = {cm.child_manifest.digest: cm.child_manifest
for cm in ManifestChild.select().where(ManifestChild.manifest == created_list)}
assert len(child_manifests) == 1
assert v2_manifest.digest in child_manifests
assert child_manifests[v2_manifest.digest].media_type.name == v2_manifest.media_type
# Try to create again and ensure we get back the same manifest list.
created2_tuple = get_or_create_manifest(repository, manifest_list, storage)
assert created2_tuple is not None
assert created2_tuple.manifest == created_list
def test_get_or_create_manifest_with_remote_layers(initialized_db):
repository = create_repository('devtable', 'newrepo', None)
layer_json = json.dumps({
'config': {},
"rootfs": {
"type": "layers",
"diff_ids": []
},
"history": [
{
"created": "2018-04-03T18:37:09.284840891Z",
"created_by": "do something",
},
{
"created": "2018-04-03T18:37:09.284840891Z",
"created_by": "do something",
},
],
})
# Add a blob containing the config.
_, config_digest = _populate_blob(layer_json)
# Add a blob of random data.
random_data = 'hello world'
_, random_digest = _populate_blob(random_data)
remote_digest = sha256_digest('something')
builder = DockerSchema2ManifestBuilder()
builder.set_config_digest(config_digest, len(layer_json))
builder.add_layer(remote_digest, 1234, urls=['http://hello/world'])
builder.add_layer(random_digest, len(random_data))
manifest = builder.build()
assert remote_digest in manifest.blob_digests
assert remote_digest not in manifest.local_blob_digests
assert manifest.has_remote_layer
assert not manifest.has_legacy_image
assert manifest.get_schema1_manifest('foo', 'bar', 'baz', None) is None
# Write the manifest.
created_tuple = get_or_create_manifest(repository, manifest, storage)
assert created_tuple is not None
created_manifest = created_tuple.manifest
assert created_manifest
assert created_manifest.media_type.name == manifest.media_type
assert created_manifest.digest == manifest.digest
# Verify the legacy image.
legacy_image = get_legacy_image_for_manifest(created_manifest)
assert legacy_image is None
# Verify the linked blobs.
blob_digests = {mb.blob.content_checksum for mb
in ManifestBlob.select().where(ManifestBlob.manifest == created_manifest)}
assert random_digest in blob_digests
assert config_digest in blob_digests
assert remote_digest not in blob_digests
def create_manifest_for_testing(repository, differentiation_field='1', include_shared_blob=False):
# Populate a manifest.
layer_json = json.dumps({
'config': {},
"rootfs": {
"type": "layers",
"diff_ids": []
},
"history": [],
})
# Add a blob containing the config.
_, config_digest = _populate_blob(layer_json)
remote_digest = sha256_digest('something')
builder = DockerSchema2ManifestBuilder()
builder.set_config_digest(config_digest, len(layer_json))
builder.add_layer(remote_digest, 1234, urls=['http://hello/world' + differentiation_field])
if include_shared_blob:
_, blob_digest = _populate_blob('some data here')
builder.add_layer(blob_digest, 4567)
manifest = builder.build()
created = get_or_create_manifest(repository, manifest, storage)
assert created
return created.manifest, manifest
def test_retriever(initialized_db):
repository = create_repository('devtable', 'newrepo', None)
layer_json = json.dumps({
'config': {},
"rootfs": {
"type": "layers",
"diff_ids": []
},
"history": [
{
"created": "2018-04-03T18:37:09.284840891Z",
"created_by": "do something",
},
{
"created": "2018-04-03T18:37:09.284840891Z",
"created_by": "do something",
},
],
})
# Add a blob containing the config.
_, config_digest = _populate_blob(layer_json)
# Add a blob of random data.
random_data = 'hello world'
_, random_digest = _populate_blob(random_data)
# Add another blob of random data.
other_random_data = 'hi place'
_, other_random_digest = _populate_blob(other_random_data)
remote_digest = sha256_digest('something')
builder = DockerSchema2ManifestBuilder()
builder.set_config_digest(config_digest, len(layer_json))
builder.add_layer(other_random_digest, len(other_random_data))
builder.add_layer(random_digest, len(random_data))
manifest = builder.build()
assert config_digest in manifest.blob_digests
assert random_digest in manifest.blob_digests
assert other_random_digest in manifest.blob_digests
assert config_digest in manifest.local_blob_digests
assert random_digest in manifest.local_blob_digests
assert other_random_digest in manifest.local_blob_digests
# Write the manifest.
created_tuple = get_or_create_manifest(repository, manifest, storage)
assert created_tuple is not None
created_manifest = created_tuple.manifest
assert created_manifest
assert created_manifest.media_type.name == manifest.media_type
assert created_manifest.digest == manifest.digest
# Verify the linked blobs.
blob_digests = {mb.blob.content_checksum for mb
in ManifestBlob.select().where(ManifestBlob.manifest == created_manifest)}
assert random_digest in blob_digests
assert other_random_digest in blob_digests
assert config_digest in blob_digests
# Delete any Image rows linking to the blobs from temp tags.
for blob_digest in blob_digests:
storage_row = ImageStorage.get(content_checksum=blob_digest)
for image in list(Image.select().where(Image.storage == storage_row)):
all_temp = all([rt.hidden for rt
in RepositoryTag.select().where(RepositoryTag.image == image)])
if all_temp:
RepositoryTag.delete().where(RepositoryTag.image == image).execute()
image.delete_instance(recursive=True)
# Verify the blobs in the retriever.
retriever = RepositoryContentRetriever(repository, storage)
assert (retriever.get_manifest_bytes_with_digest(created_manifest.digest) ==
manifest.bytes.as_encoded_str())
for blob_digest in blob_digests:
assert retriever.get_blob_bytes_with_digest(blob_digest) is not None

View file

@ -0,0 +1,378 @@
from calendar import timegm
from datetime import timedelta, datetime
import pytest
from playhouse.test_utils import assert_query_count
from data.database import (Tag, ManifestLegacyImage, TagToRepositoryTag, TagManifestToManifest,
TagManifest, Manifest, Repository)
from data.model.oci.test.test_oci_manifest import create_manifest_for_testing
from data.model.oci.tag import (find_matching_tag, get_most_recent_tag,
get_most_recent_tag_lifetime_start, list_alive_tags,
get_legacy_images_for_tags, filter_to_alive_tags,
filter_to_visible_tags, list_repository_tag_history,
get_expired_tag, get_tag, delete_tag,
delete_tags_for_manifest, change_tag_expiration,
set_tag_expiration_for_manifest, retarget_tag,
create_temporary_tag_if_necessary,
lookup_alive_tags_shallow,
lookup_unrecoverable_tags,
get_epoch_timestamp_ms)
from data.model.repository import get_repository, create_repository
from test.fixtures import *
@pytest.mark.parametrize('namespace_name, repo_name, tag_names, expected', [
('devtable', 'simple', ['latest'], 'latest'),
('devtable', 'simple', ['unknown', 'latest'], 'latest'),
('devtable', 'simple', ['unknown'], None),
])
def test_find_matching_tag(namespace_name, repo_name, tag_names, expected, initialized_db):
repo = get_repository(namespace_name, repo_name)
if expected is not None:
with assert_query_count(1):
found = find_matching_tag(repo, tag_names)
assert found is not None
assert found.name == expected
assert not found.lifetime_end_ms
else:
with assert_query_count(1):
assert find_matching_tag(repo, tag_names) is None
def test_get_most_recent_tag_lifetime_start(initialized_db):
repo = get_repository('devtable', 'simple')
tag = get_most_recent_tag(repo)
with assert_query_count(1):
tags = get_most_recent_tag_lifetime_start([repo])
assert tags[repo.id] == tag.lifetime_start_ms
def test_get_most_recent_tag(initialized_db):
repo = get_repository('outsideorg', 'coolrepo')
with assert_query_count(1):
assert get_most_recent_tag(repo).name == 'latest'
def test_get_most_recent_tag_empty_repo(initialized_db):
empty_repo = create_repository('devtable', 'empty', None)
with assert_query_count(1):
assert get_most_recent_tag(empty_repo) is None
def test_list_alive_tags(initialized_db):
found = False
for tag in filter_to_visible_tags(filter_to_alive_tags(Tag.select())):
tags = list_alive_tags(tag.repository)
assert tag in tags
with assert_query_count(1):
legacy_images = get_legacy_images_for_tags(tags)
for tag in tags:
assert ManifestLegacyImage.get(manifest=tag.manifest).image == legacy_images[tag.id]
found = True
assert found
# Ensure hidden tags cannot be listed.
tag = Tag.get()
tag.hidden = True
tag.save()
tags = list_alive_tags(tag.repository)
assert tag not in tags
def test_lookup_alive_tags_shallow(initialized_db):
found = False
for tag in filter_to_visible_tags(filter_to_alive_tags(Tag.select())):
tags = lookup_alive_tags_shallow(tag.repository)
found = True
assert tag in tags
assert found
# Ensure hidden tags cannot be listed.
tag = Tag.get()
tag.hidden = True
tag.save()
tags = lookup_alive_tags_shallow(tag.repository)
assert tag not in tags
def test_get_tag(initialized_db):
found = False
for tag in filter_to_visible_tags(filter_to_alive_tags(Tag.select())):
repo = tag.repository
with assert_query_count(1):
assert get_tag(repo, tag.name) == tag
found = True
assert found
@pytest.mark.parametrize('namespace_name, repo_name', [
('devtable', 'simple'),
('devtable', 'complex'),
])
def test_list_repository_tag_history(namespace_name, repo_name, initialized_db):
repo = get_repository(namespace_name, repo_name)
with assert_query_count(1):
results, has_more = list_repository_tag_history(repo, 1, 100)
assert results
assert not has_more
def test_list_repository_tag_history_with_history(initialized_db):
repo = get_repository('devtable', 'history')
with assert_query_count(1):
results, _ = list_repository_tag_history(repo, 1, 100)
assert len(results) == 2
assert results[0].lifetime_end_ms is None
assert results[1].lifetime_end_ms is not None
with assert_query_count(1):
results, _ = list_repository_tag_history(repo, 1, 100, specific_tag_name='latest')
assert len(results) == 2
assert results[0].lifetime_end_ms is None
assert results[1].lifetime_end_ms is not None
with assert_query_count(1):
results, _ = list_repository_tag_history(repo, 1, 100, specific_tag_name='foobar')
assert len(results) == 0
def test_list_repository_tag_history_all_tags(initialized_db):
for tag in Tag.select():
repo = tag.repository
with assert_query_count(1):
results, _ = list_repository_tag_history(repo, 1, 1000)
assert (tag in results) == (not tag.hidden)
@pytest.mark.parametrize('namespace_name, repo_name, tag_name, expected', [
('devtable', 'simple', 'latest', False),
('devtable', 'simple', 'unknown', False),
('devtable', 'complex', 'latest', False),
('devtable', 'history', 'latest', True),
])
def test_get_expired_tag(namespace_name, repo_name, tag_name, expected, initialized_db):
repo = get_repository(namespace_name, repo_name)
with assert_query_count(1):
assert bool(get_expired_tag(repo, tag_name)) == expected
def test_delete_tag(initialized_db):
found = False
for tag in list(filter_to_visible_tags(filter_to_alive_tags(Tag.select()))):
repo = tag.repository
assert get_tag(repo, tag.name) == tag
assert tag.lifetime_end_ms is None
with assert_query_count(4):
assert delete_tag(repo, tag.name) == tag
assert get_tag(repo, tag.name) is None
found = True
assert found
def test_delete_tags_for_manifest(initialized_db):
for tag in list(filter_to_visible_tags(filter_to_alive_tags(Tag.select()))):
repo = tag.repository
assert get_tag(repo, tag.name) == tag
with assert_query_count(5):
assert delete_tags_for_manifest(tag.manifest) == [tag]
assert get_tag(repo, tag.name) is None
def test_delete_tags_for_manifest_same_manifest(initialized_db):
  new_repo = create_repository('devtable', 'newrepo', None)
manifest_1, _ = create_manifest_for_testing(new_repo, '1')
manifest_2, _ = create_manifest_for_testing(new_repo, '2')
assert manifest_1.digest != manifest_2.digest
# Add some tag history, moving a tag back and forth between two manifests.
retarget_tag('latest', manifest_1)
retarget_tag('latest', manifest_2)
retarget_tag('latest', manifest_1)
retarget_tag('latest', manifest_2)
retarget_tag('another1', manifest_1)
retarget_tag('another2', manifest_2)
# Delete all tags pointing to the first manifest.
delete_tags_for_manifest(manifest_1)
assert get_tag(new_repo, 'latest').manifest == manifest_2
assert get_tag(new_repo, 'another1') is None
assert get_tag(new_repo, 'another2').manifest == manifest_2
# Delete all tags pointing to the second manifest, which should actually delete the `latest`
# tag now.
delete_tags_for_manifest(manifest_2)
assert get_tag(new_repo, 'latest') is None
assert get_tag(new_repo, 'another1') is None
assert get_tag(new_repo, 'another2') is None
@pytest.mark.parametrize('timedelta, expected_timedelta', [
pytest.param(timedelta(seconds=1), timedelta(hours=1), id='less than minimum'),
  pytest.param(timedelta(weeks=300), timedelta(weeks=104), id='more than maximum'),
pytest.param(timedelta(weeks=1), timedelta(weeks=1), id='within range'),
])
def test_change_tag_expiration(timedelta, expected_timedelta, initialized_db):
now = datetime.utcnow()
now_ms = timegm(now.utctimetuple()) * 1000
tag = Tag.get()
tag.lifetime_start_ms = now_ms
tag.save()
original_end_ms, okay = change_tag_expiration(tag, now + timedelta)
assert okay
assert original_end_ms == tag.lifetime_end_ms
updated_tag = Tag.get(id=tag.id)
offset = expected_timedelta.total_seconds() * 1000
expected_ms = (updated_tag.lifetime_start_ms + offset)
assert updated_tag.lifetime_end_ms == expected_ms
original_end_ms, okay = change_tag_expiration(tag, None)
assert okay
assert original_end_ms == expected_ms
updated_tag = Tag.get(id=tag.id)
assert updated_tag.lifetime_end_ms is None
def test_set_tag_expiration_for_manifest(initialized_db):
tag = Tag.get()
manifest = tag.manifest
assert manifest is not None
set_tag_expiration_for_manifest(manifest, datetime.utcnow() + timedelta(weeks=1))
updated_tag = Tag.get(id=tag.id)
assert updated_tag.lifetime_end_ms is not None
def test_create_temporary_tag_if_necessary(initialized_db):
tag = Tag.get()
manifest = tag.manifest
assert manifest is not None
# Ensure no tag is created, since an existing one is present.
created = create_temporary_tag_if_necessary(manifest, 60)
assert created is None
# Mark the tag as deleted.
tag.lifetime_end_ms = 1
tag.save()
# Now create a temp tag.
created = create_temporary_tag_if_necessary(manifest, 60)
assert created is not None
assert created.hidden
assert created.name.startswith('$temp-')
assert created.manifest == manifest
assert created.lifetime_end_ms is not None
assert created.lifetime_end_ms == (created.lifetime_start_ms + 60000)
# Try again and ensure it is not created.
created = create_temporary_tag_if_necessary(manifest, 30)
assert created is None
def test_retarget_tag(initialized_db):
repo = get_repository('devtable', 'history')
results, _ = list_repository_tag_history(repo, 1, 100, specific_tag_name='latest')
assert len(results) == 2
assert results[0].lifetime_end_ms is None
assert results[1].lifetime_end_ms is not None
# Revert back to the original manifest.
created = retarget_tag('latest', results[0].manifest, is_reversion=True,
now_ms=results[1].lifetime_end_ms + 10000)
assert created.lifetime_end_ms is None
assert created.reversion
assert created.name == 'latest'
assert created.manifest == results[0].manifest
# Verify in the history.
results, _ = list_repository_tag_history(repo, 1, 100, specific_tag_name='latest')
assert len(results) == 3
assert results[0].lifetime_end_ms is None
assert results[1].lifetime_end_ms is not None
assert results[2].lifetime_end_ms is not None
assert results[0] == created
# Verify old-style tables.
repository_tag = TagToRepositoryTag.get(tag=created).repository_tag
assert repository_tag.lifetime_start_ts == int(created.lifetime_start_ms / 1000)
tag_manifest = TagManifest.get(tag=repository_tag)
assert TagManifestToManifest.get(tag_manifest=tag_manifest).manifest == created.manifest
def test_retarget_tag_wrong_name(initialized_db):
repo = get_repository('devtable', 'history')
results, _ = list_repository_tag_history(repo, 1, 100, specific_tag_name='latest')
assert len(results) == 2
created = retarget_tag('someothername', results[1].manifest, is_reversion=True)
assert created is None
results, _ = list_repository_tag_history(repo, 1, 100, specific_tag_name='latest')
assert len(results) == 2
def test_lookup_unrecoverable_tags(initialized_db):
# Ensure no existing tags are found.
for repo in Repository.select():
assert not list(lookup_unrecoverable_tags(repo))
# Mark a tag as outside the expiration window and ensure it is found.
repo = get_repository('devtable', 'history')
results, _ = list_repository_tag_history(repo, 1, 100, specific_tag_name='latest')
assert len(results) == 2
results[1].lifetime_end_ms = 1
results[1].save()
# Ensure the tag is now found.
found = list(lookup_unrecoverable_tags(repo))
assert found
assert len(found) == 1
assert found[0] == results[1]
# Mark the tag as expiring in the future and ensure it is no longer found.
results[1].lifetime_end_ms = get_epoch_timestamp_ms() + 1000000
results[1].save()
found = list(lookup_unrecoverable_tags(repo))
assert not found

167
data/model/organization.py Normal file
View file

@ -0,0 +1,167 @@
from data.database import (User, FederatedLogin, TeamMember, Team, TeamRole, RepositoryPermission,
Repository, Namespace, DeletedNamespace)
from data.model import (user, team, DataModelException, InvalidOrganizationException,
InvalidUsernameException, db_transaction, _basequery)
def create_organization(name, email, creating_user, email_required=True, is_possible_abuser=False):
with db_transaction():
try:
# Create the org
new_org = user.create_user_noverify(name, email, email_required=email_required,
is_possible_abuser=is_possible_abuser)
new_org.organization = True
new_org.save()
# Create a team for the owners
owners_team = team.create_team('owners', new_org, 'admin')
# Add the user who created the org to the owners team
team.add_user_to_team(creating_user, owners_team)
return new_org
except InvalidUsernameException as iue:
raise InvalidOrganizationException(iue.message)
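# Illustrative usage sketch (the org name and email below are hypothetical; creating_user is an
# existing User row): the creating user is automatically added to the new org's `owners` team.
def _example_create_org(creating_user):
  return create_organization('buynlarge', 'owner@example.com', creating_user)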
def get_organization(name):
try:
return User.get(username=name, organization=True)
except User.DoesNotExist:
raise InvalidOrganizationException('Organization does not exist: %s' %
name)
def convert_user_to_organization(user_obj, admin_user):
if user_obj.robot:
raise DataModelException('Cannot convert a robot into an organization')
with db_transaction():
# Change the user to an organization and disable this account for login.
user_obj.organization = True
user_obj.password_hash = None
user_obj.save()
# Clear any federated auth pointing to this user.
FederatedLogin.delete().where(FederatedLogin.user == user_obj).execute()
# Delete any user-specific permissions on repositories.
(RepositoryPermission.delete()
.where(RepositoryPermission.user == user_obj)
.execute())
# Create a team for the owners
owners_team = team.create_team('owners', user_obj, 'admin')
# Add the user who will admin the org to the owners team
team.add_user_to_team(admin_user, owners_team)
return user_obj
def get_user_organizations(username):
return _basequery.get_user_organizations(username)
def get_organization_team_members(teamid):
joined = User.select().join(TeamMember).join(Team)
query = joined.where(Team.id == teamid)
return query
def __get_org_admin_users(org):
return (User
.select()
.join(TeamMember)
.join(Team)
.join(TeamRole)
.where(Team.organization == org, TeamRole.name == 'admin', User.robot == False)
.distinct())
def get_admin_users(org):
""" Returns the owner users for the organization. """
return __get_org_admin_users(org)
def remove_organization_member(org, user_obj):
org_admins = [u.username for u in __get_org_admin_users(org)]
if len(org_admins) == 1 and user_obj.username in org_admins:
raise DataModelException('Cannot remove user as they are the only organization admin')
with db_transaction():
# Find and remove the user from any repositories under the org.
permissions = list(RepositoryPermission
.select(RepositoryPermission.id)
.join(Repository)
.where(Repository.namespace_user == org,
RepositoryPermission.user == user_obj))
if permissions:
RepositoryPermission.delete().where(RepositoryPermission.id << permissions).execute()
# Find and remove the user from any teams under the org.
members = list(TeamMember
.select(TeamMember.id)
.join(Team)
.where(Team.organization == org, TeamMember.user == user_obj))
if members:
TeamMember.delete().where(TeamMember.id << members).execute()
def get_organization_member_set(org, include_robots=False, users_filter=None):
""" Returns the set of all member usernames under the given organization, with optional
filtering by robots and/or by a specific set of User objects.
"""
Org = User.alias()
org_users = (User
.select(User.username)
.join(TeamMember)
.join(Team)
.where(Team.organization == org)
.distinct())
if not include_robots:
org_users = org_users.where(User.robot == False)
if users_filter is not None:
ids_list = [u.id for u in users_filter if u is not None]
if not ids_list:
return set()
org_users = org_users.where(User.id << ids_list)
return {user.username for user in org_users}
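# Illustrative usage sketch (not part of the original module): shows how the
# member-set helper above could be combined with get_organization to narrow a
# set of users down to organization members. The organization name and the
# candidate_users argument are hypothetical placeholders.
def _example_filter_members(org_name, candidate_users):
  org = get_organization(org_name)  # raises InvalidOrganizationException if missing
  # Passing users_filter avoids loading the entire membership set.
  usernames = get_organization_member_set(org, include_robots=False,
                                          users_filter=candidate_users)
  return [u for u in candidate_users if u.username in usernames]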
def get_all_repo_users_transitive_via_teams(namespace_name, repository_name):
return (User
.select()
.distinct()
.join(TeamMember)
.join(Team)
.join(RepositoryPermission)
.join(Repository)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.where(Namespace.username == namespace_name, Repository.name == repository_name))
def get_organizations(deleted=False):
query = User.select().where(User.organization == True, User.robot == False)
if not deleted:
query = query.where(User.id.not_in(DeletedNamespace.select(DeletedNamespace.namespace)))
return query
def get_active_org_count():
return get_organizations().count()
def add_user_as_admin(user_obj, org_obj):
try:
admin_role = TeamRole.get(name='admin')
admin_team = Team.select().where(Team.role == admin_role, Team.organization == org_obj).get()
team.add_user_to_team(user_obj, admin_team)
except team.UserAlreadyInTeam:
pass

322
data/model/permission.py Normal file
View file

@ -0,0 +1,322 @@
from peewee import JOIN
from data.database import (RepositoryPermission, User, Repository, Visibility, Role, TeamMember,
PermissionPrototype, Team, TeamRole, Namespace)
from data.model import DataModelException, _basequery
from util.names import parse_robot_username
def list_team_permissions(team):
return (RepositoryPermission
.select(RepositoryPermission)
.join(Repository)
.join(Visibility)
.switch(RepositoryPermission)
.join(Role)
.switch(RepositoryPermission)
.where(RepositoryPermission.team == team))
def list_robot_permissions(robot_name):
return (RepositoryPermission
.select(RepositoryPermission, User, Repository)
.join(Repository)
.join(Visibility)
.switch(RepositoryPermission)
.join(Role)
.switch(RepositoryPermission)
.join(User)
.where(User.username == robot_name, User.robot == True))
def list_organization_member_permissions(organization, limit_to_user=None):
query = (RepositoryPermission
.select(RepositoryPermission, Repository, User)
.join(Repository)
.switch(RepositoryPermission)
.join(User)
.where(Repository.namespace_user == organization))
if limit_to_user is not None:
query = query.where(RepositoryPermission.user == limit_to_user)
else:
query = query.where(User.robot == False)
return query
def get_all_user_repository_permissions(user):
return _get_user_repo_permissions(user)
def get_user_repo_permissions(user, repo):
return _get_user_repo_permissions(user, limit_to_repository_obj=repo)
def get_user_repository_permissions(user, namespace, repo_name):
return _get_user_repo_permissions(user, limit_namespace=namespace, limit_repo_name=repo_name)
def _get_user_repo_permissions(user, limit_to_repository_obj=None, limit_namespace=None,
limit_repo_name=None):
UserThroughTeam = User.alias()
base_query = (RepositoryPermission
.select(RepositoryPermission, Role, Repository, Namespace)
.join(Role)
.switch(RepositoryPermission)
.join(Repository)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.switch(RepositoryPermission))
if limit_to_repository_obj is not None:
base_query = base_query.where(RepositoryPermission.repository == limit_to_repository_obj)
elif limit_namespace and limit_repo_name:
base_query = base_query.where(Repository.name == limit_repo_name,
Namespace.username == limit_namespace)
direct = (base_query
.clone()
.join(User)
.where(User.id == user))
team = (base_query
.clone()
.join(Team)
.join(TeamMember)
.join(UserThroughTeam, on=(UserThroughTeam.id == TeamMember.user))
.where(UserThroughTeam.id == user))
return direct | team
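# Illustrative usage sketch (not part of the original module): the public
# wrappers above all funnel into _get_user_repo_permissions, which unions the
# permissions granted directly to a user with those granted via team
# membership; each row in the result carries the joined Role and Repository.
# The helper name below is hypothetical.
def _example_describe_permissions(user_obj):
  return ['%s on %s' % (perm.role.name, perm.repository.name)
          for perm in get_all_user_repository_permissions(user_obj)]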
def delete_prototype_permission(org, uid):
found = get_prototype_permission(org, uid)
if not found:
return None
found.delete_instance()
return found
def get_prototype_permission(org, uid):
try:
return PermissionPrototype.get(PermissionPrototype.org == org,
PermissionPrototype.uuid == uid)
except PermissionPrototype.DoesNotExist:
return None
def get_prototype_permissions(org):
ActivatingUser = User.alias()
DelegateUser = User.alias()
query = (PermissionPrototype
.select()
.where(PermissionPrototype.org == org)
.join(ActivatingUser, JOIN.LEFT_OUTER,
on=(ActivatingUser.id == PermissionPrototype.activating_user))
.join(DelegateUser, JOIN.LEFT_OUTER,
on=(DelegateUser.id == PermissionPrototype.delegate_user))
.join(Team, JOIN.LEFT_OUTER,
on=(Team.id == PermissionPrototype.delegate_team))
.join(Role, JOIN.LEFT_OUTER, on=(Role.id == PermissionPrototype.role)))
return query
def update_prototype_permission(org, uid, role_name):
found = get_prototype_permission(org, uid)
if not found:
return None
new_role = Role.get(Role.name == role_name)
found.role = new_role
found.save()
return found
def add_prototype_permission(org, role_name, activating_user,
delegate_user=None, delegate_team=None):
new_role = Role.get(Role.name == role_name)
return PermissionPrototype.create(org=org, role=new_role, activating_user=activating_user,
delegate_user=delegate_user, delegate_team=delegate_team)
def get_org_wide_permissions(user, org_filter=None):
Org = User.alias()
team_with_role = Team.select(Team, Org, TeamRole).join(TeamRole)
with_org = team_with_role.switch(Team).join(Org, on=(Team.organization ==
Org.id))
with_user = with_org.switch(Team).join(TeamMember).join(User)
if org_filter:
    with_user = with_user.where(Org.username == org_filter)
return with_user.where(User.id == user, Org.organization == True)
def get_all_repo_teams(namespace_name, repository_name):
return (RepositoryPermission
.select(Team.name, Role.name, RepositoryPermission)
.join(Team)
.switch(RepositoryPermission)
.join(Role)
.switch(RepositoryPermission)
.join(Repository)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.where(Namespace.username == namespace_name, Repository.name == repository_name))
def apply_default_permissions(repo_obj, creating_user_obj):
org = repo_obj.namespace_user
user_clause = ((PermissionPrototype.activating_user == creating_user_obj) |
(PermissionPrototype.activating_user >> None))
team_protos = (PermissionPrototype
.select()
.where(PermissionPrototype.org == org, user_clause,
PermissionPrototype.delegate_user >> None))
def create_team_permission(team, repo, role):
RepositoryPermission.create(team=team, repository=repo, role=role)
__apply_permission_list(repo_obj, team_protos, 'name', create_team_permission)
user_protos = (PermissionPrototype
.select()
.where(PermissionPrototype.org == org, user_clause,
PermissionPrototype.delegate_team >> None))
def create_user_permission(user, repo, role):
# The creating user always gets admin anyway
if user.username == creating_user_obj.username:
return
RepositoryPermission.create(user=user, repository=repo, role=role)
__apply_permission_list(repo_obj, user_protos, 'username', create_user_permission)
def __apply_permission_list(repo, proto_query, name_property, create_permission_func):
final_protos = {}
for proto in proto_query:
applies_to = proto.delegate_team or proto.delegate_user
name = getattr(applies_to, name_property)
# We will skip the proto if it is pre-empted by a more important proto
if name in final_protos and proto.activating_user is None:
continue
# By this point, it is either a user specific proto, or there is no
# proto yet, so we can safely assume it applies
final_protos[name] = (applies_to, proto.role)
for delegate, role in final_protos.values():
create_permission_func(delegate, repo, role)
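# Worked example (illustrative, not part of the original module) of the
# precedence rule in __apply_permission_list: a prototype tied to a specific
# activating user wins over an org-wide prototype (activating_user is None)
# that targets the same delegate name, since the org-wide one is skipped once
# the name is already in final_protos. Names and roles below are hypothetical,
# with role objects shown as their names for brevity.
#
#   protos (in iteration order):
#     1. delegate_user='alice', activating_user=<creator>, role='write'
#     2. delegate_user='alice', activating_user=None,      role='read'
#
#   final_protos after the loop: {'alice': (<alice>, 'write')}
#   -> create_permission_func(<alice>, repo, 'write') is called exactly once.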
def __entity_permission_repo_query(entity_id, entity_table, entity_id_property, namespace_name,
repository_name):
""" This method works for both users and teams. """
return (RepositoryPermission
.select(entity_table, Repository, Namespace, Role, RepositoryPermission)
.join(entity_table)
.switch(RepositoryPermission)
.join(Role)
.switch(RepositoryPermission)
.join(Repository)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.where(Repository.name == repository_name, Namespace.username == namespace_name,
entity_id_property == entity_id))
def get_user_reponame_permission(username, namespace_name, repository_name):
fetched = list(__entity_permission_repo_query(username, User, User.username, namespace_name,
repository_name))
if not fetched:
raise DataModelException('User does not have permission for repo.')
return fetched[0]
def get_team_reponame_permission(team_name, namespace_name, repository_name):
fetched = list(__entity_permission_repo_query(team_name, Team, Team.name, namespace_name,
repository_name))
if not fetched:
raise DataModelException('Team does not have permission for repo.')
return fetched[0]
def delete_user_permission(username, namespace_name, repository_name):
if username == namespace_name:
raise DataModelException('Namespace owner must always be admin.')
fetched = list(__entity_permission_repo_query(username, User, User.username, namespace_name,
repository_name))
if not fetched:
raise DataModelException('User does not have permission for repo.')
fetched[0].delete_instance()
def delete_team_permission(team_name, namespace_name, repository_name):
fetched = list(__entity_permission_repo_query(team_name, Team, Team.name, namespace_name,
repository_name))
if not fetched:
raise DataModelException('Team does not have permission for repo.')
fetched[0].delete_instance()
def __set_entity_repo_permission(entity, permission_entity_property,
namespace_name, repository_name, role_name):
repo = _basequery.get_existing_repository(namespace_name, repository_name)
new_role = Role.get(Role.name == role_name)
# Fetch any existing permission for this entity on the repo
try:
entity_attr = getattr(RepositoryPermission, permission_entity_property)
perm = RepositoryPermission.get(entity_attr == entity, RepositoryPermission.repository == repo)
perm.role = new_role
perm.save()
return perm
except RepositoryPermission.DoesNotExist:
set_entity_kwargs = {permission_entity_property: entity}
new_perm = RepositoryPermission.create(repository=repo, role=new_role, **set_entity_kwargs)
return new_perm
def set_user_repo_permission(username, namespace_name, repository_name, role_name):
if username == namespace_name:
raise DataModelException('Namespace owner must always be admin.')
try:
user = User.get(User.username == username)
except User.DoesNotExist:
raise DataModelException('Invalid username: %s' % username)
if user.robot:
parts = parse_robot_username(user.username)
if not parts:
raise DataModelException('Invalid robot: %s' % username)
robot_namespace, _ = parts
if robot_namespace != namespace_name:
raise DataModelException('Cannot add robot %s under namespace %s' %
(username, namespace_name))
return __set_entity_repo_permission(user, 'user', namespace_name, repository_name, role_name)
def set_team_repo_permission(team_name, namespace_name, repository_name, role_name):
try:
team = (Team
.select()
.join(User)
.where(Team.name == team_name, User.username == namespace_name)
.get())
except Team.DoesNotExist:
raise DataModelException('No team %s in organization %s' % (team_name, namespace_name))
return __set_entity_repo_permission(team, 'team', namespace_name, repository_name, role_name)

21
data/model/release.py Normal file
View file

@ -0,0 +1,21 @@
from data.database import QuayRelease, QuayRegion, QuayService
def set_region_release(service_name, region_name, version):
service, _ = QuayService.get_or_create(name=service_name)
region, _ = QuayRegion.get_or_create(name=region_name)
return QuayRelease.get_or_create(service=service, version=version, region=region)
def get_recent_releases(service_name, region_name):
return (QuayRelease
.select(QuayRelease)
.join(QuayService)
.switch(QuayRelease)
.join(QuayRegion)
.where(QuayService.name == service_name,
QuayRegion.name == region_name,
QuayRelease.reverted == False,
)
.order_by(QuayRelease.created.desc()))
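# Illustrative usage sketch (not part of the original module): recording a
# release for a region and then listing the recent, non-reverted releases for
# that service/region pair. The service, region and version strings are
# placeholder values.
def _example_record_and_list_releases():
  set_region_release('quay', 'us-east-1', 'v3.0.0')
  return [release.version for release in get_recent_releases('quay', 'us-east-1')]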

519
data/model/repo_mirror.py Normal file
View file

@ -0,0 +1,519 @@
import re
from datetime import datetime, timedelta
from peewee import IntegrityError, fn
from jsonschema import ValidationError
from data.database import (RepoMirrorConfig, RepoMirrorRule, RepoMirrorRuleType, RepoMirrorStatus,
RepositoryState, Repository, uuid_generator, db_transaction)
from data.fields import DecryptedValue
from data.model import DataModelException
from util.names import parse_robot_username
# TODO: Move these to the configuration
MAX_SYNC_RETRIES = 3
MAX_SYNC_DURATION = 60*60*2 # 2 Hours
def get_eligible_mirrors():
"""
  Returns the RepoMirrorConfigs that are ready to run now. This includes those that are:
1. Not currently syncing but whose start time is in the past
2. Status of "sync now"
3. Currently marked as syncing but whose expiration time is in the past
"""
now = datetime.utcnow()
immediate_candidates_filter = ((RepoMirrorConfig.sync_status == RepoMirrorStatus.SYNC_NOW) &
(RepoMirrorConfig.sync_expiration_date >> None))
ready_candidates_filter = ((RepoMirrorConfig.sync_start_date <= now) &
(RepoMirrorConfig.sync_retries_remaining > 0) &
(RepoMirrorConfig.sync_status != RepoMirrorStatus.SYNCING) &
(RepoMirrorConfig.sync_expiration_date >> None) &
(RepoMirrorConfig.is_enabled == True))
expired_candidates_filter = ((RepoMirrorConfig.sync_start_date <= now) &
(RepoMirrorConfig.sync_retries_remaining > 0) &
(RepoMirrorConfig.sync_status == RepoMirrorStatus.SYNCING) &
(RepoMirrorConfig.sync_expiration_date <= now) &
(RepoMirrorConfig.is_enabled == True))
return (RepoMirrorConfig
.select()
.join(Repository)
.where(Repository.state == RepositoryState.MIRROR)
.where(immediate_candidates_filter | ready_candidates_filter | expired_candidates_filter)
.order_by(RepoMirrorConfig.sync_start_date.asc()))
def get_max_id_for_repo_mirror_config():
""" Gets the maximum id for repository mirroring """
return RepoMirrorConfig.select(fn.Max(RepoMirrorConfig.id)).scalar()
def get_min_id_for_repo_mirror_config():
""" Gets the minimum id for a repository mirroring """
return RepoMirrorConfig.select(fn.Min(RepoMirrorConfig.id)).scalar()
def claim_mirror(mirror):
"""
Attempt to create an exclusive lock on the RepoMirrorConfig and return it.
If unable to create the lock, `None` will be returned.
"""
# Attempt to update the RepoMirrorConfig to mark it as "claimed"
now = datetime.utcnow()
expiration_date = now + timedelta(seconds=MAX_SYNC_DURATION)
query = (RepoMirrorConfig
.update(sync_status=RepoMirrorStatus.SYNCING,
sync_expiration_date=expiration_date,
sync_transaction_id=uuid_generator())
.where(RepoMirrorConfig.id == mirror.id,
RepoMirrorConfig.sync_transaction_id == mirror.sync_transaction_id))
# If the update was successful, then it was claimed. Return the updated instance.
if query.execute():
return RepoMirrorConfig.get_by_id(mirror.id)
return None # Another process must have claimed the mirror faster.
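# Illustrative worker-loop sketch (not part of the original module): how the
# optimistic claim above might be driven by a mirroring worker. The
# perform_sync callable is a hypothetical placeholder for the actual sync
# logic; everything else uses only functions defined in this module.
def _example_mirror_worker(perform_sync):
  for candidate in get_eligible_mirrors():
    mirror = claim_mirror(candidate)
    if mirror is None:
      continue  # Another worker won the claim; move on.
    try:
      perform_sync(mirror)
      release_mirror(mirror, RepoMirrorStatus.SUCCESS)
    except Exception:
      release_mirror(mirror, RepoMirrorStatus.FAIL)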
def release_mirror(mirror, sync_status):
"""
Return a mirror to the queue and update its status.
  On success, the next sync is moved to the next interval in the future. On failure, the
  start date is left unchanged so the mirror is picked up again for a repeat attempt. After MAX_SYNC_RETRIES,
the next sync will be moved ahead as if it were a success. This is to allow a daily sync,
for example, to retry the next day. Without this, users would need to manually run syncs
to clear failure state.
"""
if sync_status == RepoMirrorStatus.FAIL:
retries = max(0, mirror.sync_retries_remaining - 1)
if sync_status == RepoMirrorStatus.SUCCESS or retries < 1:
now = datetime.utcnow()
delta = now - mirror.sync_start_date
delta_seconds = (delta.days * 24 * 60 * 60) + delta.seconds
next_start_date = now + timedelta(seconds=mirror.sync_interval - (delta_seconds % mirror.sync_interval))
retries = MAX_SYNC_RETRIES
else:
next_start_date = mirror.sync_start_date
query = (RepoMirrorConfig
.update(sync_transaction_id=uuid_generator(),
sync_status=sync_status,
sync_start_date=next_start_date,
sync_expiration_date=None,
sync_retries_remaining=retries)
.where(RepoMirrorConfig.id == mirror.id,
RepoMirrorConfig.sync_transaction_id == mirror.sync_transaction_id))
if query.execute():
return RepoMirrorConfig.get_by_id(mirror.id)
# Unable to release Mirror. Has it been claimed by another process?
return None
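# Worked example (illustrative) of the scheduling arithmetic above, assuming a
# daily sync_interval of 86400 seconds: if the sync started at 02:00 UTC and is
# released successfully at 02:25 UTC, delta_seconds is 1500, so the next start
# is now + (86400 - 1500 % 86400) = now + 84900 seconds, i.e. 02:00 UTC the
# following day. A failed sync with retries remaining keeps the original
# sync_start_date so it is retried promptly.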
def expire_mirror(mirror):
"""
Set the mirror to synchronize ASAP and reset its failure count.
"""
# Set the next-sync date to now
# TODO: Verify the `where` conditions would not expire a currently syncing mirror.
query = (RepoMirrorConfig
.update(sync_transaction_id=uuid_generator(),
sync_expiration_date=datetime.utcnow(),
sync_retries_remaining=MAX_SYNC_RETRIES)
.where(RepoMirrorConfig.sync_transaction_id == mirror.sync_transaction_id,
RepoMirrorConfig.id == mirror.id,
                  RepoMirrorConfig.sync_status != RepoMirrorStatus.SYNCING))
# Fetch and return the latest updates
if query.execute():
return RepoMirrorConfig.get_by_id(mirror.id)
# Unable to update expiration date. Perhaps another process has claimed it?
return None # TODO: Raise some Exception?
def create_mirroring_rule(repository, rule_value, rule_type=RepoMirrorRuleType.TAG_GLOB_CSV):
"""
Create a RepoMirrorRule for a given Repository.
"""
if rule_type != RepoMirrorRuleType.TAG_GLOB_CSV:
raise ValidationError('validation failed: rule_type must be TAG_GLOB_CSV')
if not isinstance(rule_value, list) or len(rule_value) < 1:
raise ValidationError('validation failed: rule_value for TAG_GLOB_CSV must be a list with at least one rule')
rule = RepoMirrorRule.create(repository=repository, rule_type=rule_type, rule_value=rule_value)
return rule
def enable_mirroring_for_repository(repository,
root_rule,
internal_robot,
external_reference,
sync_interval,
external_registry_username=None,
external_registry_password=None,
external_registry_config=None,
is_enabled=True,
sync_start_date=None):
"""
Create a RepoMirrorConfig and set the Repository to the MIRROR state.
"""
assert internal_robot.robot
namespace, _ = parse_robot_username(internal_robot.username)
if namespace != repository.namespace_user.username:
raise DataModelException('Cannot use robot for mirroring')
with db_transaction():
# Create the RepoMirrorConfig
try:
username = DecryptedValue(external_registry_username) if external_registry_username else None
password = DecryptedValue(external_registry_password) if external_registry_password else None
mirror = RepoMirrorConfig.create(repository=repository,
root_rule=root_rule,
is_enabled=is_enabled,
internal_robot=internal_robot,
external_reference=external_reference,
external_registry_username=username,
external_registry_password=password,
external_registry_config=external_registry_config or {},
sync_interval=sync_interval,
sync_start_date=sync_start_date or datetime.utcnow())
except IntegrityError:
return RepoMirrorConfig.get(repository=repository)
# Change Repository state to mirroring mode as needed
if repository.state != RepositoryState.MIRROR:
query = (Repository
.update(state=RepositoryState.MIRROR)
.where(Repository.id == repository.id))
if not query.execute():
raise DataModelException('Could not change the state of the repository')
return mirror
def update_sync_status(mirror, sync_status):
"""
Update the sync status
"""
query = (RepoMirrorConfig
.update(sync_transaction_id=uuid_generator(),
sync_status=sync_status)
.where(RepoMirrorConfig.sync_transaction_id == mirror.sync_transaction_id,
RepoMirrorConfig.id == mirror.id))
if query.execute():
return RepoMirrorConfig.get_by_id(mirror.id)
return None
def update_sync_status_to_sync_now(mirror):
"""
This will change the sync status to SYNC_NOW and set the retries remaining to one, if it is
less than one. None will be returned in cases where this is not possible, such as if the
mirror is in the SYNCING state.
"""
if mirror.sync_status == RepoMirrorStatus.SYNCING:
return None
retries = max(mirror.sync_retries_remaining, 1)
query = (RepoMirrorConfig
.update(sync_transaction_id=uuid_generator(),
sync_status=RepoMirrorStatus.SYNC_NOW,
sync_expiration_date=None,
sync_retries_remaining=retries)
.where(RepoMirrorConfig.id == mirror.id,
RepoMirrorConfig.sync_transaction_id == mirror.sync_transaction_id))
if query.execute():
return RepoMirrorConfig.get_by_id(mirror.id)
return None
def update_sync_status_to_cancel(mirror):
"""
If the mirror is SYNCING, it will be force-claimed (ignoring existing transaction id), and the
  state will be set to NEVER_RUN. None will be returned in cases where this is not possible, such
as if the mirror is not in the SYNCING state.
"""
if mirror.sync_status != RepoMirrorStatus.SYNCING and mirror.sync_status != RepoMirrorStatus.SYNC_NOW:
return None
query = (RepoMirrorConfig
.update(sync_transaction_id=uuid_generator(),
sync_status=RepoMirrorStatus.NEVER_RUN,
sync_expiration_date=None)
.where(RepoMirrorConfig.id == mirror.id))
if query.execute():
return RepoMirrorConfig.get_by_id(mirror.id)
return None
def update_with_transaction(mirror, **kwargs):
"""
Helper function which updates a Repository's RepoMirrorConfig while also rolling its
sync_transaction_id for locking purposes.
"""
# RepoMirrorConfig attributes which can be modified
mutable_attributes = (
'is_enabled',
'mirror_type',
'external_reference',
'external_registry_username',
'external_registry_password',
'external_registry_config',
'sync_interval',
'sync_start_date',
'sync_expiration_date',
'sync_retries_remaining',
'sync_status',
'sync_transaction_id'
)
# Key-Value map of changes to make
filtered_kwargs = {key:kwargs.pop(key) for key in mutable_attributes if key in kwargs}
# Roll the sync_transaction_id to a new value
filtered_kwargs['sync_transaction_id'] = uuid_generator()
# Generate the query to perform the updates
query = (RepoMirrorConfig
.update(filtered_kwargs)
.where(RepoMirrorConfig.sync_transaction_id == mirror.sync_transaction_id,
RepoMirrorConfig.id == mirror.id))
# Apply the change(s) and return the object if successful
if query.execute():
return RepoMirrorConfig.get_by_id(mirror.id)
else:
return None
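# Illustrative usage sketch (not part of the original module):
# update_with_transaction only applies an update while the caller still holds
# the sync_transaction_id it read, so a None return means another process won
# the race. The sketch below re-reads the mirror and retries a few times; the
# helper name and retry count are arbitrary.
def _example_set_status_with_retry(repository, sync_status, attempts=3):
  for _ in range(attempts):
    mirror = get_mirror(repository)
    if mirror is None:
      return None
    updated = update_with_transaction(mirror, sync_status=sync_status)
    if updated is not None:
      return updated
  return None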
def get_mirror(repository):
"""
Return the RepoMirrorConfig associated with the given Repository, or None if it doesn't exist.
"""
try:
return RepoMirrorConfig.get(repository=repository)
except RepoMirrorConfig.DoesNotExist:
return None
def enable_mirror(repository):
"""
Enables a RepoMirrorConfig.
"""
mirror = get_mirror(repository)
return bool(update_with_transaction(mirror, is_enabled=True))
def disable_mirror(repository):
"""
Disables a RepoMirrorConfig.
"""
mirror = get_mirror(repository)
return bool(update_with_transaction(mirror, is_enabled=False))
def delete_mirror(repository):
"""
Delete a Repository Mirroring configuration.
"""
raise NotImplementedError("TODO: Not Implemented")
def change_remote(repository, remote_repository):
"""
Update the external repository for Repository Mirroring.
"""
mirror = get_mirror(repository)
updates = {
'external_reference': remote_repository
}
return bool(update_with_transaction(mirror, **updates))
def change_credentials(repository, username, password):
"""
Update the credentials used to access the remote repository.
"""
mirror = get_mirror(repository)
updates = {
'external_registry_username': username,
'external_registry_password': password,
}
return bool(update_with_transaction(mirror, **updates))
def change_username(repository, username):
"""
Update the Username used to access the external repository.
"""
mirror = get_mirror(repository)
return bool(update_with_transaction(mirror, external_registry_username=username))
def change_sync_interval(repository, interval):
"""
Update the interval at which a repository will be synchronized.
"""
mirror = get_mirror(repository)
return bool(update_with_transaction(mirror, sync_interval=interval))
def change_sync_start_date(repository, dt):
"""
Specify when the repository should be synchronized next.
"""
mirror = get_mirror(repository)
return bool(update_with_transaction(mirror, sync_start_date=dt))
def change_root_rule(repository, rule):
"""
Specify which rule should be used for repository mirroring.
"""
assert rule.repository == repository
mirror = get_mirror(repository)
return bool(update_with_transaction(mirror, root_rule=rule))
def change_sync_status(repository, sync_status):
"""
Change Repository's mirroring status.
"""
mirror = get_mirror(repository)
return update_with_transaction(mirror, sync_status=sync_status)
def change_retries_remaining(repository, retries_remaining):
"""
Change the number of retries remaining for mirroring a repository.
"""
mirror = get_mirror(repository)
return update_with_transaction(mirror, sync_retries_remaining=retries_remaining)
def change_external_registry_config(repository, config_updates):
"""
  Update the 'external_registry_config' with the passed-in fields. Config has:
    verify_tls: True|False
    proxy: JSON fields 'http_proxy', 'https_proxy', and 'no_proxy'
"""
mirror = get_mirror(repository)
external_registry_config = mirror.external_registry_config
if 'verify_tls' in config_updates:
external_registry_config['verify_tls'] = config_updates['verify_tls']
if 'proxy' in config_updates:
proxy_updates = config_updates['proxy']
    for key in ('http_proxy', 'https_proxy', 'no_proxy'):
      if key in config_updates['proxy']:
        if 'proxy' not in external_registry_config:
          external_registry_config['proxy'] = {}
        external_registry_config['proxy'][key] = proxy_updates[key]
return update_with_transaction(mirror, external_registry_config=external_registry_config)
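# Example of the config_updates shape accepted above (illustrative values):
#
#   change_external_registry_config(repo, {
#     'verify_tls': False,
#     'proxy': {
#       'http_proxy': 'http://proxy.example.com:3128',
#       'https_proxy': 'http://proxy.example.com:3128',
#       'no_proxy': 'localhost,127.0.0.1',
#     },
#   })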
def get_mirroring_robot(repository):
"""
Return the robot used for mirroring. Returns None if the repository does not have an associated
RepoMirrorConfig or the robot does not exist.
"""
mirror = get_mirror(repository)
if mirror:
return mirror.internal_robot
return None
def set_mirroring_robot(repository, robot):
"""
Sets the mirroring robot for the repository.
"""
assert robot.robot
namespace, _ = parse_robot_username(robot.username)
if namespace != repository.namespace_user.username:
raise DataModelException('Cannot use robot for mirroring')
mirror = get_mirror(repository)
mirror.internal_robot = robot
mirror.save()
# -------------------- Mirroring Rules --------------------------#
def create_rule(repository, rule_value, rule_type=RepoMirrorRuleType.TAG_GLOB_CSV, left_child=None, right_child=None):
"""
Create a new Rule for mirroring a Repository
"""
if rule_type != RepoMirrorRuleType.TAG_GLOB_CSV:
raise ValidationError('validation failed: rule_type must be TAG_GLOB_CSV')
if not isinstance(rule_value, list) or len(rule_value) < 1:
raise ValidationError('validation failed: rule_value for TAG_GLOB_CSV must be a list with at least one rule')
rule_kwargs = {
'repository': repository,
'rule_value': rule_value,
'rule_type': rule_type,
'left_child': left_child,
'right_child': right_child,
}
rule = RepoMirrorRule.create(**rule_kwargs)
return rule
def list_rules(repository):
"""
Returns all RepoMirrorRules associated with a Repository.
"""
  rules = RepoMirrorRule.select().where(RepoMirrorRule.repository == repository)
return rules
def get_root_rule(repository):
"""
Return the primary mirroring Rule
"""
mirror = get_mirror(repository)
try:
rule = RepoMirrorRule.get(repository=repository)
return rule
except RepoMirrorRule.DoesNotExist:
return None
def change_rule_value(rule, value):
"""
Update the value of an existing rule.
"""
query = (RepoMirrorRule
.update(rule_value=value)
.where(RepoMirrorRule.id == rule.id))
return query.execute()

457
data/model/repository.py Normal file
View file

@ -0,0 +1,457 @@
import logging
import random
from enum import Enum
from datetime import timedelta, datetime
from peewee import Case, JOIN, fn, SQL, IntegrityError
from cachetools.func import ttl_cache
from data.model import (
config, DataModelException, tag, db_transaction, storage, permission, _basequery)
from data.database import (
Repository, Namespace, RepositoryTag, Star, Image, ImageStorage, User, Visibility,
RepositoryPermission, RepositoryActionCount, Role, RepositoryAuthorizedEmail,
DerivedStorageForImage, Label, db_for_update, get_epoch_timestamp,
db_random_func, db_concat_func, RepositorySearchScore, RepositoryKind, ApprTag,
ManifestLegacyImage, Manifest, ManifestChild)
from data.text import prefix_search
from util.itertoolrecipes import take
logger = logging.getLogger(__name__)
SEARCH_FIELDS = Enum("SearchFields", ["name", "description"])
class RepoStateConfigException(Exception):
""" Repository.state value requires further configuration to operate. """
pass
def get_repo_kind_name(repo):
return Repository.kind.get_name(repo.kind_id)
def get_repository_count():
return Repository.select().count()
def get_public_repo_visibility():
return _basequery.get_public_repo_visibility()
def create_repository(namespace, name, creating_user, visibility='private', repo_kind='image',
description=None):
namespace_user = User.get(username=namespace)
yesterday = datetime.now() - timedelta(days=1)
with db_transaction():
repo = Repository.create(name=name, visibility=Repository.visibility.get_id(visibility),
namespace_user=namespace_user,
kind=Repository.kind.get_id(repo_kind),
description=description)
RepositoryActionCount.create(repository=repo, count=0, date=yesterday)
RepositorySearchScore.create(repository=repo, score=0)
# Note: We put the admin create permission under the transaction to ensure it is created.
if creating_user and not creating_user.organization:
admin = Role.get(name='admin')
RepositoryPermission.create(user=creating_user, repository=repo, role=admin)
# Apply default permissions (only occurs for repositories under organizations)
if creating_user and not creating_user.organization and creating_user.username != namespace:
permission.apply_default_permissions(repo, creating_user)
return repo
def get_repository(namespace_name, repository_name, kind_filter=None):
try:
return _basequery.get_existing_repository(namespace_name, repository_name,
kind_filter=kind_filter)
except Repository.DoesNotExist:
return None
def get_or_create_repository(namespace, name, creating_user, visibility='private',
repo_kind='image'):
repo = get_repository(namespace, name, repo_kind)
if repo is None:
repo = create_repository(namespace, name, creating_user, visibility, repo_kind)
return repo
@ttl_cache(maxsize=1, ttl=600)
def _get_gc_expiration_policies():
policy_tuples_query = (
Namespace.select(Namespace.removed_tag_expiration_s).distinct()
.limit(100) # This sucks but it's the only way to limit memory
.tuples())
return [policy[0] for policy in policy_tuples_query]
def get_random_gc_policy():
""" Return a single random policy from the database to use when garbage collecting.
"""
return random.choice(_get_gc_expiration_policies())
def find_repository_with_garbage(limit_to_gc_policy_s):
expiration_timestamp = get_epoch_timestamp() - limit_to_gc_policy_s
try:
candidates = (RepositoryTag.select(RepositoryTag.repository).join(Repository)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.where(~(RepositoryTag.lifetime_end_ts >> None),
(RepositoryTag.lifetime_end_ts <= expiration_timestamp),
(Namespace.removed_tag_expiration_s == limit_to_gc_policy_s)).limit(500)
.distinct().alias('candidates'))
found = (RepositoryTag.select(candidates.c.repository_id).from_(candidates)
.order_by(db_random_func()).get())
if found is None:
return
return Repository.get(Repository.id == found.repository_id)
except RepositoryTag.DoesNotExist:
return None
except Repository.DoesNotExist:
return None
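# Illustrative usage sketch (not part of the original module): a garbage
# collection worker could pick a random expiration policy and then look for a
# repository holding tags expired beyond that policy. The gc_repository
# callable is a hypothetical placeholder for the actual collection step.
def _example_gc_pass(gc_repository):
  policy_s = get_random_gc_policy()
  repo = find_repository_with_garbage(policy_s)
  if repo is not None:
    gc_repository(repo)
  return repo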
def star_repository(user, repository):
""" Stars a repository. """
star = Star.create(user=user.id, repository=repository.id)
star.save()
def unstar_repository(user, repository):
""" Unstars a repository. """
try:
(Star.delete().where(Star.repository == repository.id, Star.user == user.id).execute())
except Star.DoesNotExist:
raise DataModelException('Star not found.')
def set_trust(repo, trust_enabled):
repo.trust_enabled = trust_enabled
repo.save()
def set_description(repo, description):
repo.description = description
repo.save()
def get_user_starred_repositories(user, kind_filter='image'):
""" Retrieves all of the repositories a user has starred. """
try:
repo_kind = Repository.kind.get_id(kind_filter)
except RepositoryKind.DoesNotExist:
raise DataModelException('Unknown kind of repository')
query = (Repository.select(Repository, User, Visibility, Repository.id.alias('rid')).join(Star)
.switch(Repository).join(User).switch(Repository).join(Visibility)
.where(Star.user == user, Repository.kind == repo_kind))
return query
def repository_is_starred(user, repository):
""" Determines whether a user has starred a repository or not. """
try:
(Star.select().where(Star.repository == repository.id, Star.user == user.id).get())
return True
except Star.DoesNotExist:
return False
def get_stars(repository_ids):
""" Returns a map from repository ID to the number of stars for each repository in the
given repository IDs list.
"""
if not repository_ids:
return {}
tuples = (Star.select(Star.repository, fn.Count(Star.id))
.where(Star.repository << repository_ids).group_by(Star.repository).tuples())
star_map = {}
for record in tuples:
star_map[record[0]] = record[1]
return star_map
def get_visible_repositories(username, namespace=None, kind_filter='image', include_public=False,
start_id=None, limit=None):
""" Returns the repositories visible to the given user (if any).
"""
if not include_public and not username:
# Short circuit by returning a query that will find no repositories. We need to return a query
# here, as it will be modified by other queries later on.
return Repository.select(Repository.id.alias('rid')).where(Repository.id == -1)
query = (Repository.select(Repository.name,
Repository.id.alias('rid'), Repository.description,
Namespace.username, Repository.visibility, Repository.kind)
.switch(Repository).join(Namespace, on=(Repository.namespace_user == Namespace.id)))
user_id = None
if username:
# Note: We only need the permissions table if we will filter based on a user's permissions.
query = query.switch(Repository).distinct().join(RepositoryPermission, JOIN.LEFT_OUTER)
found_namespace = _get_namespace_user(username)
if not found_namespace:
return Repository.select(Repository.id.alias('rid')).where(Repository.id == -1)
user_id = found_namespace.id
query = _basequery.filter_to_repos_for_user(query, user_id, namespace, kind_filter,
include_public, start_id=start_id)
if limit is not None:
query = query.limit(limit).order_by(SQL('rid'))
return query
def get_app_repository(namespace_name, repository_name):
""" Find an application repository. """
try:
return _basequery.get_existing_repository(namespace_name, repository_name,
kind_filter='application')
except Repository.DoesNotExist:
return None
def get_app_search(lookup, search_fields=None, username=None, limit=50):
if search_fields is None:
search_fields = set([SEARCH_FIELDS.name.name])
return get_filtered_matching_repositories(lookup, filter_username=username,
search_fields=search_fields, repo_kind='application',
offset=0, limit=limit)
def _get_namespace_user(username):
try:
return User.get(username=username)
except User.DoesNotExist:
return None
def get_filtered_matching_repositories(lookup_value, filter_username=None, repo_kind='image',
offset=0, limit=25, search_fields=None):
""" Returns an iterator of all repositories matching the given lookup value, with optional
filtering to a specific user. If the user is unspecified, only public repositories will
be returned.
"""
if search_fields is None:
search_fields = set([SEARCH_FIELDS.description.name, SEARCH_FIELDS.name.name])
# Build the unfiltered search query.
unfiltered_query = _get_sorted_matching_repositories(lookup_value, repo_kind=repo_kind,
search_fields=search_fields,
include_private=filter_username is not None,
ids_only=filter_username is not None)
# Add a filter to the iterator, if necessary.
if filter_username is not None:
filter_user = _get_namespace_user(filter_username)
if filter_user is None:
return []
iterator = _filter_repositories_visible_to_user(unfiltered_query, filter_user.id, limit,
repo_kind)
if offset > 0:
take(offset, iterator)
# Return the results.
return list(take(limit, iterator))
return list(unfiltered_query.offset(offset).limit(limit))
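# Illustrative usage sketch (not part of the original module): a paginated
# search over repositories visible to a given user, restricted to name matches
# only. The username and lookup string are placeholder values.
def _example_search_page(page=0, page_size=25):
  return get_filtered_matching_repositories('postgres',
                                            filter_username='devtable',
                                            search_fields=set([SEARCH_FIELDS.name.name]),
                                            offset=page * page_size,
                                            limit=page_size)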
def _filter_repositories_visible_to_user(unfiltered_query, filter_user_id, limit, repo_kind):
encountered = set()
chunk_count = limit * 2
unfiltered_page = 0
iteration_count = 0
while iteration_count < 10: # Just to be safe
# Find the next chunk's worth of repository IDs, paginated by the chunk size.
unfiltered_page = unfiltered_page + 1
found_ids = [r.id for r in unfiltered_query.paginate(unfiltered_page, chunk_count)]
# Make sure we haven't encountered these results before. This code is used to handle
    # the case where we've previously seen a result, as pagination is not necessarily
# stable in SQL databases.
unfiltered_repository_ids = set(found_ids)
new_unfiltered_ids = unfiltered_repository_ids - encountered
if not new_unfiltered_ids:
break
encountered.update(new_unfiltered_ids)
# Filter the repositories found to only those visible to the current user.
query = (Repository
.select(Repository, Namespace)
.distinct()
.join(Namespace, on=(Namespace.id == Repository.namespace_user)).switch(Repository)
.join(RepositoryPermission).where(Repository.id << list(new_unfiltered_ids)))
filtered = _basequery.filter_to_repos_for_user(query, filter_user_id, repo_kind=repo_kind)
# Sort the filtered repositories by their initial order.
all_filtered_repos = list(filtered)
all_filtered_repos.sort(key=lambda repo: found_ids.index(repo.id))
# Yield the repositories in sorted order.
for filtered_repo in all_filtered_repos:
yield filtered_repo
# If the number of found IDs is less than the chunk count, then we're done.
if len(found_ids) < chunk_count:
break
iteration_count = iteration_count + 1
def _get_sorted_matching_repositories(lookup_value, repo_kind='image', include_private=False,
search_fields=None, ids_only=False):
""" Returns a query of repositories matching the given lookup string, with optional inclusion of
private repositories. Note that this method does *not* filter results based on visibility
to users.
"""
select_fields = [Repository.id] if ids_only else [Repository, Namespace]
if not lookup_value:
# This is a generic listing of repositories. Simply return the sorted repositories based
# on RepositorySearchScore.
query = (Repository
.select(*select_fields)
.join(RepositorySearchScore)
.order_by(RepositorySearchScore.score.desc()))
else:
if search_fields is None:
search_fields = set([SEARCH_FIELDS.description.name, SEARCH_FIELDS.name.name])
# Always search at least on name (init clause)
clause = Repository.name.match(lookup_value)
computed_score = RepositorySearchScore.score.alias('score')
# If the description field is in the search fields, then we need to compute a synthetic score
# to discount the weight of the description more than the name.
if SEARCH_FIELDS.description.name in search_fields:
clause = Repository.description.match(lookup_value) | clause
cases = [(Repository.name.match(lookup_value), 100 * RepositorySearchScore.score),]
computed_score = Case(None, cases, RepositorySearchScore.score).alias('score')
select_fields.append(computed_score)
query = (Repository.select(*select_fields)
.join(RepositorySearchScore)
.where(clause)
.order_by(SQL('score').desc()))
if repo_kind is not None:
query = query.where(Repository.kind == Repository.kind.get_id(repo_kind))
if not include_private:
query = query.where(Repository.visibility == _basequery.get_public_repo_visibility())
if not ids_only:
query = (query
.switch(Repository)
.join(Namespace, on=(Namespace.id == Repository.namespace_user)))
return query
def lookup_repository(repo_id):
try:
return Repository.get(Repository.id == repo_id)
except Repository.DoesNotExist:
return None
def is_repository_public(repository):
return repository.visibility_id == _basequery.get_public_repo_visibility().id
def repository_is_public(namespace_name, repository_name):
try:
(Repository.select().join(Namespace, on=(Repository.namespace_user == Namespace.id))
.switch(Repository).join(Visibility).where(Namespace.username == namespace_name,
Repository.name == repository_name,
Visibility.name == 'public').get())
return True
except Repository.DoesNotExist:
return False
def set_repository_visibility(repo, visibility):
visibility_obj = Visibility.get(name=visibility)
if not visibility_obj:
return
repo.visibility = visibility_obj
repo.save()
def get_email_authorized_for_repo(namespace, repository, email):
try:
return (RepositoryAuthorizedEmail.select(RepositoryAuthorizedEmail, Repository, Namespace)
.join(Repository).join(Namespace, on=(Repository.namespace_user == Namespace.id))
.where(Namespace.username == namespace, Repository.name == repository,
RepositoryAuthorizedEmail.email == email).get())
except RepositoryAuthorizedEmail.DoesNotExist:
return None
def create_email_authorization_for_repo(namespace_name, repository_name, email):
try:
repo = _basequery.get_existing_repository(namespace_name, repository_name)
except Repository.DoesNotExist:
raise DataModelException('Invalid repository %s/%s' % (namespace_name, repository_name))
return RepositoryAuthorizedEmail.create(repository=repo, email=email, confirmed=False)
def confirm_email_authorization_for_repo(code):
try:
found = (RepositoryAuthorizedEmail.select(RepositoryAuthorizedEmail, Repository, Namespace)
.join(Repository).join(Namespace, on=(Repository.namespace_user == Namespace.id))
.where(RepositoryAuthorizedEmail.code == code).get())
except RepositoryAuthorizedEmail.DoesNotExist:
raise DataModelException('Invalid confirmation code.')
found.confirmed = True
found.save()
return found
def is_empty(namespace_name, repository_name):
""" Returns if the repository referenced by the given namespace and name is empty. If the repo
doesn't exist, returns True.
"""
try:
tag.list_repository_tags(namespace_name, repository_name).limit(1).get()
return False
except RepositoryTag.DoesNotExist:
return True
def get_repository_state(namespace_name, repository_name):
""" Return the Repository State if the Repository exists. Otherwise, returns None. """
repo = get_repository(namespace_name, repository_name)
if repo:
return repo.state
return None
def set_repository_state(repo, state):
repo.state = state
repo.save()

129
data/model/repositoryactioncount.py Normal file
View file

@ -0,0 +1,129 @@
import logging
from collections import namedtuple
from peewee import IntegrityError
from datetime import date, timedelta, datetime
from data.database import (Repository, LogEntry, LogEntry2, LogEntry3, RepositoryActionCount,
RepositorySearchScore, db_random_func, fn)
logger = logging.getLogger(__name__)
search_bucket = namedtuple('SearchBucket', ['delta', 'days', 'weight'])
# Defines the various buckets for search scoring. Each bucket is computed using the given time
# delta from today *minus the previous bucket's time period*. Once all the actions over the
# bucket's time period have been collected, they are multiplied by the given modifier. The modifiers
# for this bucket were determined via the integral of (2/((x/183)+1)^2)/183 over the period of days
# in the bucket; this integral over 0..183 has a sum of 1, so we get a well-normalized score result.
SEARCH_BUCKETS = [
search_bucket(timedelta(days=1), 1, 0.010870),
search_bucket(timedelta(days=7), 6, 0.062815),
search_bucket(timedelta(days=31), 24, 0.21604),
search_bucket(timedelta(days=183), 152, 0.71028),
]
def find_uncounted_repository():
""" Returns a repository that has not yet had an entry added into the RepositoryActionCount
table for yesterday.
"""
try:
# Get a random repository to count.
today = date.today()
yesterday = today - timedelta(days=1)
has_yesterday_actions = (RepositoryActionCount
.select(RepositoryActionCount.repository)
.where(RepositoryActionCount.date == yesterday))
to_count = (Repository
.select()
.where(~(Repository.id << (has_yesterday_actions)))
.order_by(db_random_func()).get())
return to_count
except Repository.DoesNotExist:
return None
def count_repository_actions(to_count, day):
""" Aggregates repository actions from the LogEntry table for the specified day. Returns the
count or None on error.
"""
# TODO: Clean this up a bit.
def lookup_action_count(model):
return (model
.select()
.where(model.repository == to_count,
model.datetime >= day,
model.datetime < (day + timedelta(days=1)))
.count())
actions = (lookup_action_count(LogEntry3) + lookup_action_count(LogEntry2) +
lookup_action_count(LogEntry))
return actions
def store_repository_action_count(repository, day, action_count):
""" Stores the action count for a repository for a specific day. Returns False if the
repository already has an entry for the specified day.
"""
try:
RepositoryActionCount.create(repository=repository, date=day, count=action_count)
return True
except IntegrityError:
logger.debug('Count already written for repository %s', repository.id)
return False
def update_repository_score(repo):
""" Updates the repository score entry for the given table by retrieving information from
the RepositoryActionCount table. Note that count_repository_actions for the repo should
be called first. Returns True if the row was updated and False otherwise.
"""
today = date.today()
# Retrieve the counts for each bucket and calculate the final score.
final_score = 0.0
last_end_timedelta = timedelta(days=0)
for bucket in SEARCH_BUCKETS:
start_date = today - bucket.delta
end_date = today - last_end_timedelta
last_end_timedelta = bucket.delta
query = (RepositoryActionCount
.select(fn.Sum(RepositoryActionCount.count), fn.Count(RepositoryActionCount.id))
.where(RepositoryActionCount.date >= start_date,
RepositoryActionCount.date < end_date,
RepositoryActionCount.repository == repo))
bucket_tuple = query.tuples()[0]
logger.debug('Got bucket tuple %s for bucket %s for repository %s', bucket_tuple, bucket,
repo.id)
if bucket_tuple[0] is None:
continue
bucket_sum = float(bucket_tuple[0])
bucket_count = int(bucket_tuple[1])
if not bucket_count:
continue
bucket_score = bucket_sum / (bucket_count * 1.0)
final_score += bucket_score * bucket.weight
# Update the existing repo search score row or create a new one.
normalized_score = int(final_score * 100.0)
try:
try:
search_score_row = RepositorySearchScore.get(repository=repo)
search_score_row.last_updated = datetime.now()
search_score_row.score = normalized_score
search_score_row.save()
return True
except RepositorySearchScore.DoesNotExist:
RepositorySearchScore.create(repository=repo, score=normalized_score, last_updated=today)
return True
except IntegrityError:
logger.debug('RepositorySearchScore row already existed; skipping')
return False
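# Worked example (illustrative) of the bucket scoring above: suppose a
# repository averaged 100 actions/day over the last day, 50/day over the rest
# of the week, 20/day over the rest of the month and 5/day over the rest of
# the half-year. The final score would be approximately
#   100 * 0.010870 + 50 * 0.062815 + 20 * 0.21604 + 5 * 0.71028  ~= 12.1,
# which int(final_score * 100.0) stores as a normalized score of roughly 1210.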

205
data/model/service_keys.py Normal file
View file

@ -0,0 +1,205 @@
import re
from calendar import timegm
from datetime import datetime, timedelta
from peewee import JOIN
from Crypto.PublicKey import RSA
from jwkest.jwk import RSAKey
from data.database import db_for_update, User, ServiceKey, ServiceKeyApproval
from data.model import (ServiceKeyDoesNotExist, ServiceKeyAlreadyApproved, ServiceNameInvalid,
db_transaction, config)
from data.model.notification import create_notification, delete_all_notifications_by_path_prefix
from util.security.fingerprint import canonical_kid
_SERVICE_NAME_REGEX = re.compile(r'^[a-z0-9_]+$')
def _expired_keys_clause(service):
return ((ServiceKey.service == service) &
(ServiceKey.expiration_date <= datetime.utcnow()))
def _stale_expired_keys_service_clause(service):
return ((ServiceKey.service == service) & _stale_expired_keys_clause())
def _stale_expired_keys_clause():
expired_ttl = timedelta(seconds=config.app_config['EXPIRED_SERVICE_KEY_TTL_SEC'])
return (ServiceKey.expiration_date <= (datetime.utcnow() - expired_ttl))
def _stale_unapproved_keys_clause(service):
unapproved_ttl = timedelta(seconds=config.app_config['UNAPPROVED_SERVICE_KEY_TTL_SEC'])
return ((ServiceKey.service == service) &
(ServiceKey.approval >> None) &
(ServiceKey.created_date <= (datetime.utcnow() - unapproved_ttl)))
def _gc_expired(service):
ServiceKey.delete().where(_stale_expired_keys_service_clause(service) |
_stale_unapproved_keys_clause(service)).execute()
def _verify_service_name(service_name):
if not _SERVICE_NAME_REGEX.match(service_name):
raise ServiceNameInvalid
def _notify_superusers(key):
notification_metadata = {
'name': key.name,
'kid': key.kid,
'service': key.service,
'jwk': key.jwk,
'metadata': key.metadata,
'created_date': timegm(key.created_date.utctimetuple()),
}
if key.expiration_date is not None:
notification_metadata['expiration_date'] = timegm(key.expiration_date.utctimetuple())
if len(config.app_config['SUPER_USERS']) > 0:
superusers = User.select().where(User.username << config.app_config['SUPER_USERS'])
for superuser in superusers:
create_notification('service_key_submitted', superuser, metadata=notification_metadata,
lookup_path='/service_key_approval/{0}/{1}'.format(key.kid, superuser.id))
def create_service_key(name, kid, service, jwk, metadata, expiration_date, rotation_duration=None):
_verify_service_name(service)
_gc_expired(service)
key = ServiceKey.create(name=name, kid=kid, service=service, jwk=jwk, metadata=metadata,
expiration_date=expiration_date, rotation_duration=rotation_duration)
_notify_superusers(key)
return key
def generate_service_key(service, expiration_date, kid=None, name='', metadata=None,
rotation_duration=None):
private_key = RSA.generate(2048)
jwk = RSAKey(key=private_key.publickey()).serialize()
if kid is None:
kid = canonical_kid(jwk)
key = create_service_key(name, kid, service, jwk, metadata or {}, expiration_date,
rotation_duration=rotation_duration)
return (private_key, key)
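# Illustrative usage sketch (not part of the original module): generating a key
# for a service and then approving it, as an automated setup flow might. The
# service name, key name and approval-type string are hypothetical placeholders;
# the real approval type values are defined elsewhere in the codebase.
def _example_bootstrap_key(expiration_date):
  private_key, key = generate_service_key('quay', expiration_date, name='bootstrap')
  approve_service_key(key.kid, 'automatic', notes='generated during setup')
  return (private_key, key)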
def replace_service_key(old_kid, kid, jwk, metadata, expiration_date):
try:
with db_transaction():
key = db_for_update(ServiceKey.select().where(ServiceKey.kid == old_kid)).get()
key.metadata.update(metadata)
ServiceKey.create(name=key.name, kid=kid, service=key.service, jwk=jwk,
metadata=key.metadata, expiration_date=expiration_date,
rotation_duration=key.rotation_duration, approval=key.approval)
key.delete_instance()
except ServiceKey.DoesNotExist:
raise ServiceKeyDoesNotExist
_notify_superusers(key)
delete_all_notifications_by_path_prefix('/service_key_approval/{0}'.format(old_kid))
_gc_expired(key.service)
def update_service_key(kid, name=None, metadata=None):
try:
with db_transaction():
key = db_for_update(ServiceKey.select().where(ServiceKey.kid == kid)).get()
if name is not None:
key.name = name
if metadata is not None:
key.metadata.update(metadata)
key.save()
except ServiceKey.DoesNotExist:
raise ServiceKeyDoesNotExist
def delete_service_key(kid):
try:
key = ServiceKey.get(kid=kid)
ServiceKey.delete().where(ServiceKey.kid == kid).execute()
except ServiceKey.DoesNotExist:
raise ServiceKeyDoesNotExist
delete_all_notifications_by_path_prefix('/service_key_approval/{0}'.format(kid))
_gc_expired(key.service)
return key
def set_key_expiration(kid, expiration_date):
try:
service_key = get_service_key(kid, alive_only=False, approved_only=False)
except ServiceKey.DoesNotExist:
raise ServiceKeyDoesNotExist
service_key.expiration_date = expiration_date
service_key.save()
def approve_service_key(kid, approval_type, approver=None, notes=''):
try:
with db_transaction():
key = db_for_update(ServiceKey.select().where(ServiceKey.kid == kid)).get()
if key.approval is not None:
raise ServiceKeyAlreadyApproved
approval = ServiceKeyApproval.create(approver=approver, approval_type=approval_type,
notes=notes)
key.approval = approval
key.save()
except ServiceKey.DoesNotExist:
raise ServiceKeyDoesNotExist
delete_all_notifications_by_path_prefix('/service_key_approval/{0}'.format(kid))
return key
def _list_service_keys_query(kid=None, service=None, approved_only=True, alive_only=True,
approval_type=None):
query = ServiceKey.select().join(ServiceKeyApproval, JOIN.LEFT_OUTER)
if approved_only:
query = query.where(~(ServiceKey.approval >> None))
if alive_only:
query = query.where((ServiceKey.expiration_date > datetime.utcnow()) |
(ServiceKey.expiration_date >> None))
if approval_type is not None:
query = query.where(ServiceKeyApproval.approval_type == approval_type)
if service is not None:
query = query.where(ServiceKey.service == service)
query = query.where(~(_expired_keys_clause(service)) |
~(_stale_unapproved_keys_clause(service)))
if kid is not None:
query = query.where(ServiceKey.kid == kid)
query = query.where(~(_stale_expired_keys_clause()) | (ServiceKey.expiration_date >> None))
return query
def list_all_keys():
return list(_list_service_keys_query(approved_only=False, alive_only=False))
def list_service_keys(service):
return list(_list_service_keys_query(service=service))
def get_service_key(kid, service=None, alive_only=True, approved_only=True):
try:
return _list_service_keys_query(kid=kid, service=service, approved_only=approved_only,
alive_only=alive_only).get()
except ServiceKey.DoesNotExist:
raise ServiceKeyDoesNotExist

View file

@ -0,0 +1,94 @@
from sqlalchemy import (Table, MetaData, Column, ForeignKey, Integer, String, Boolean, Text,
DateTime, Date, BigInteger, Index, text)
from peewee import (PrimaryKeyField, CharField, BooleanField, DateTimeField, TextField,
ForeignKeyField, BigIntegerField, IntegerField, DateField)
OPTIONS_TO_COPY = [
'null',
'default',
'primary_key',
]
OPTION_TRANSLATIONS = {
'null': 'nullable',
}
def gen_sqlalchemy_metadata(peewee_model_list):
metadata = MetaData(naming_convention={
"ix": 'ix_%(column_0_label)s',
"uq": "uq_%(table_name)s_%(column_0_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s"
})
for model in peewee_model_list:
meta = model._meta
all_indexes = set(meta.indexes)
fulltext_indexes = []
columns = []
for field in meta.sorted_fields:
alchemy_type = None
col_args = []
col_kwargs = {}
if isinstance(field, PrimaryKeyField):
alchemy_type = Integer
elif isinstance(field, CharField):
alchemy_type = String(field.max_length)
elif isinstance(field, BooleanField):
alchemy_type = Boolean
elif isinstance(field, DateTimeField):
alchemy_type = DateTime
elif isinstance(field, DateField):
alchemy_type = Date
elif isinstance(field, TextField):
alchemy_type = Text
elif isinstance(field, ForeignKeyField):
alchemy_type = Integer
all_indexes.add(((field.name, ), field.unique))
if not field.deferred:
target_name = '%s.%s' % (field.rel_model._meta.table_name, field.rel_field.column_name)
col_args.append(ForeignKey(target_name))
elif isinstance(field, BigIntegerField):
alchemy_type = BigInteger
elif isinstance(field, IntegerField):
alchemy_type = Integer
else:
raise RuntimeError('Unknown column type: %s' % field)
if hasattr(field, '__fulltext__'):
# Add the fulltext index for the field, based on whether we are under MySQL or Postgres.
fulltext_indexes.append(field.name)
for option_name in OPTIONS_TO_COPY:
alchemy_option_name = (OPTION_TRANSLATIONS[option_name]
if option_name in OPTION_TRANSLATIONS else option_name)
if alchemy_option_name not in col_kwargs:
option_val = getattr(field, option_name)
col_kwargs[alchemy_option_name] = option_val
if field.unique or field.index:
all_indexes.add(((field.name, ), field.unique))
new_col = Column(field.column_name, alchemy_type, *col_args, **col_kwargs)
columns.append(new_col)
new_table = Table(meta.table_name, metadata, *columns)
for col_prop_names, unique in all_indexes:
col_names = [meta.fields[prop_name].column_name for prop_name in col_prop_names]
index_name = '%s_%s' % (meta.table_name, '_'.join(col_names))
col_refs = [getattr(new_table.c, col_name) for col_name in col_names]
Index(index_name, *col_refs, unique=unique)
for col_field_name in fulltext_indexes:
index_name = '%s_%s__fulltext' % (meta.table_name, col_field_name)
col_ref = getattr(new_table.c, col_field_name)
Index(index_name, col_ref, postgresql_ops={col_field_name: 'gin_trgm_ops'},
postgresql_using='gin',
mysql_prefix='FULLTEXT')
return metadata
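# Illustrative usage sketch: generating SQLAlchemy metadata from peewee models. The
# caller supplies the model classes (for example User and Repository from
# data.database), so no extra imports are needed here.
def _example_generate_metadata(models):
  # `models` is any iterable of peewee model classes with a standard `_meta`.
  metadata = gen_sqlalchemy_metadata(list(models))
  return sorted(metadata.tables.keys())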

373
data/model/storage.py Normal file
View file

@ -0,0 +1,373 @@
import logging
from peewee import SQL, IntegrityError
from cachetools.func import lru_cache
from collections import namedtuple
from data.model import (config, db_transaction, InvalidImageException, TorrentInfoDoesNotExist,
DataModelException, _basequery)
from data.database import (ImageStorage, Image, ImageStoragePlacement, ImageStorageLocation,
ImageStorageTransformation, ImageStorageSignature,
ImageStorageSignatureKind, Repository, Namespace, TorrentInfo, ApprBlob,
ensure_under_transaction, ManifestBlob)
logger = logging.getLogger(__name__)
_Location = namedtuple('location', ['id', 'name'])
@lru_cache(maxsize=1)
def get_image_locations():
location_map = {}
for location in ImageStorageLocation.select():
location_tuple = _Location(location.id, location.name)
location_map[location.id] = location_tuple
location_map[location.name] = location_tuple
return location_map
def get_image_location_for_name(location_name):
locations = get_image_locations()
return locations[location_name]
def get_image_location_for_id(location_id):
locations = get_image_locations()
return locations[location_id]
def add_storage_placement(storage, location_name):
""" Adds a storage placement for the given storage at the given location. """
location = get_image_location_for_name(location_name)
try:
ImageStoragePlacement.create(location=location.id, storage=storage)
except IntegrityError:
# Placement already exists. Nothing to do.
pass
def _orphaned_storage_query(candidate_ids):
""" Returns the subset of the candidate ImageStorage IDs representing storages that are no
longer referenced by images.
"""
# Issue a union query to find all storages that are still referenced by a candidate storage. This
# is much faster than the group_by and having call we used to use here.
nonorphaned_queries = []
for counter, candidate_id in enumerate(candidate_ids):
query_alias = 'q{0}'.format(counter)
# TODO: remove the join with Image once fully on the OCI data model.
storage_subq = (ImageStorage
.select(ImageStorage.id)
.join(Image)
.where(ImageStorage.id == candidate_id)
.limit(1)
.alias(query_alias))
nonorphaned_queries.append(ImageStorage
.select(SQL('*'))
.from_(storage_subq))
manifest_storage_subq = (ImageStorage
.select(ImageStorage.id)
.join(ManifestBlob)
.where(ImageStorage.id == candidate_id)
.limit(1)
.alias(query_alias))
nonorphaned_queries.append(ImageStorage
.select(SQL('*'))
.from_(manifest_storage_subq))
# Build the set of storages that are missing. These storages are orphaned.
nonorphaned_storage_ids = {storage.id for storage
in _basequery.reduce_as_tree(nonorphaned_queries)}
return list(candidate_ids - nonorphaned_storage_ids)
def garbage_collect_storage(storage_id_whitelist):
""" Performs GC on a possible subset of the storage's with the IDs found in the
whitelist. The storages in the whitelist will be checked, and any orphaned will
be removed, with those IDs being returned.
"""
if len(storage_id_whitelist) == 0:
return []
def placements_to_filtered_paths_set(placements_list):
""" Returns the list of paths to remove from storage, filtered from the given placements
query by removing any CAS paths that are still referenced by storage(s) in the database.
"""
with ensure_under_transaction():
if not placements_list:
return set()
# Find the content checksums not referenced by other storages. Any that are, we cannot
# remove.
content_checksums = set([placement.storage.content_checksum for placement in placements_list
if placement.storage.cas_path])
unreferenced_checksums = set()
if content_checksums:
# Check the current image storage.
query = (ImageStorage
.select(ImageStorage.content_checksum)
.where(ImageStorage.content_checksum << list(content_checksums)))
is_referenced_checksums = set([image_storage.content_checksum for image_storage in query])
if is_referenced_checksums:
logger.warning('GC attempted to remove CAS checksums %s, which are still IS referenced',
is_referenced_checksums)
# Check the ApprBlob table as well.
query = ApprBlob.select(ApprBlob.digest).where(ApprBlob.digest << list(content_checksums))
appr_blob_referenced_checksums = set([blob.digest for blob in query])
if appr_blob_referenced_checksums:
logger.warning('GC attempted to remove CAS checksums %s, which are ApprBlob referenced',
appr_blob_referenced_checksums)
unreferenced_checksums = (content_checksums - appr_blob_referenced_checksums -
is_referenced_checksums)
# Return all placements for all image storages found not at a CAS path or with a content
# checksum that is referenced.
return {(get_image_location_for_id(placement.location_id).name,
get_layer_path(placement.storage))
for placement in placements_list
if not placement.storage.cas_path or
placement.storage.content_checksum in unreferenced_checksums}
# Note: Both of these deletes must occur in the same transaction (unfortunately) because a
# storage without any placement is invalid, and a placement cannot exist without a storage.
# TODO: We might want to allow for null storages on placements, which would allow us to
# delete the storages, then delete the placements in a non-transaction.
logger.debug('Garbage collecting storages from candidates: %s', storage_id_whitelist)
with db_transaction():
orphaned_storage_ids = _orphaned_storage_query(storage_id_whitelist)
if len(orphaned_storage_ids) == 0:
# Nothing to GC.
return []
placements_to_remove = list(ImageStoragePlacement
.select(ImageStoragePlacement, ImageStorage)
.join(ImageStorage)
.where(ImageStorage.id << orphaned_storage_ids))
# Remove the placements for orphaned storages
if len(placements_to_remove) > 0:
placement_ids_to_remove = [placement.id for placement in placements_to_remove]
placements_removed = (ImageStoragePlacement
.delete()
.where(ImageStoragePlacement.id << placement_ids_to_remove)
.execute())
logger.debug('Removed %s image storage placements', placements_removed)
# Remove all orphaned storages
torrents_removed = (TorrentInfo
.delete()
.where(TorrentInfo.storage << orphaned_storage_ids)
.execute())
logger.debug('Removed %s torrent info records', torrents_removed)
signatures_removed = (ImageStorageSignature
.delete()
.where(ImageStorageSignature.storage << orphaned_storage_ids)
.execute())
logger.debug('Removed %s image storage signatures', signatures_removed)
storages_removed = (ImageStorage
.delete()
.where(ImageStorage.id << orphaned_storage_ids)
.execute())
logger.debug('Removed %s image storage records', storages_removed)
# Determine the paths to remove. We cannot simply remove all paths matching storages, as CAS
# can share the same path. We further filter these paths by checking for any storages still in
# the database with the same content checksum.
paths_to_remove = placements_to_filtered_paths_set(placements_to_remove)
# We are going to make the conscious decision to not delete image storage blobs inside
# transactions.
# This may end up producing garbage in s3, trading off for higher availability in the database.
for location_name, image_path in paths_to_remove:
logger.debug('Removing %s from %s', image_path, location_name)
config.store.remove({location_name}, image_path)
return orphaned_storage_ids
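# Illustrative usage sketch (placeholder IDs): callers pass the set of ImageStorage IDs
# that might have become orphaned; only those confirmed unreferenced are deleted and
# their IDs returned.
def _example_gc_candidates(candidate_storage_ids):
  # candidate_storage_ids is a set of ImageStorage primary keys, e.g. {123, 456}.
  removed_ids = garbage_collect_storage(set(candidate_storage_ids))
  logger.debug('GC removed %s orphaned storages', len(removed_ids))
  return removed_ids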
def create_v1_storage(location_name):
storage = ImageStorage.create(cas_path=False, uploading=True)
location = get_image_location_for_name(location_name)
ImageStoragePlacement.create(location=location.id, storage=storage)
storage.locations = {location_name}
return storage
def find_or_create_storage_signature(storage, signature_kind_name):
found = lookup_storage_signature(storage, signature_kind_name)
if found is None:
kind = ImageStorageSignatureKind.get(name=signature_kind_name)
found = ImageStorageSignature.create(storage=storage, kind=kind)
return found
def lookup_storage_signature(storage, signature_kind_name):
kind = ImageStorageSignatureKind.get(name=signature_kind_name)
try:
return (ImageStorageSignature
.select()
.where(ImageStorageSignature.storage == storage, ImageStorageSignature.kind == kind)
.get())
except ImageStorageSignature.DoesNotExist:
return None
def _get_storage(query_modifier):
query = (ImageStoragePlacement
.select(ImageStoragePlacement, ImageStorage)
.switch(ImageStoragePlacement)
.join(ImageStorage))
placements = list(query_modifier(query))
if not placements:
raise InvalidImageException()
found = placements[0].storage
found.locations = {get_image_location_for_id(placement.location_id).name
for placement in placements}
return found
def get_storage_by_uuid(storage_uuid):
def filter_to_uuid(query):
return query.where(ImageStorage.uuid == storage_uuid)
try:
return _get_storage(filter_to_uuid)
except InvalidImageException:
raise InvalidImageException('No storage found with uuid: %s' % storage_uuid)
def get_layer_path(storage_record):
""" Returns the path in the storage engine to the layer data referenced by the storage row. """
assert storage_record.cas_path is not None
return get_layer_path_for_storage(storage_record.uuid, storage_record.cas_path,
storage_record.content_checksum)
def get_layer_path_for_storage(storage_uuid, cas_path, content_checksum):
""" Returns the path in the storage engine to the layer data referenced by the storage
information. """
store = config.store
if not cas_path:
logger.debug('Serving layer from legacy v1 path for storage %s', storage_uuid)
return store.v1_image_layer_path(storage_uuid)
return store.blob_path(content_checksum)
def lookup_repo_storages_by_content_checksum(repo, checksums, by_manifest=False):
""" Looks up repository storages (without placements) matching the given repository
and checksums. """
if not checksums:
return []
# There may be many duplicates of the checksums, so for performance reasons we are going
# to use a union to select just one storage with each checksum
queries = []
for counter, checksum in enumerate(set(checksums)):
query_alias = 'q{0}'.format(counter)
# TODO: Remove once we have a new-style model for tracking temp uploaded blobs and
# all legacy tables have been removed.
if by_manifest:
candidate_subq = (ImageStorage
.select(ImageStorage.id, ImageStorage.content_checksum,
ImageStorage.image_size, ImageStorage.uuid, ImageStorage.cas_path,
ImageStorage.uncompressed_size, ImageStorage.uploading)
.join(ManifestBlob)
.where(ManifestBlob.repository == repo,
ImageStorage.content_checksum == checksum)
.limit(1)
.alias(query_alias))
else:
candidate_subq = (ImageStorage
.select(ImageStorage.id, ImageStorage.content_checksum,
ImageStorage.image_size, ImageStorage.uuid, ImageStorage.cas_path,
ImageStorage.uncompressed_size, ImageStorage.uploading)
.join(Image)
.where(Image.repository == repo, ImageStorage.content_checksum == checksum)
.limit(1)
.alias(query_alias))
queries.append(ImageStorage
.select(SQL('*'))
.from_(candidate_subq))
return _basequery.reduce_as_tree(queries)
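# Illustrative usage sketch (placeholder digests): resolving a manifest's blob digests
# to their ImageStorage rows for one repository. `repo` and `blob_digests` are supplied
# by the caller; the helper name is purely for illustration.
def _example_resolve_blobs(repo, blob_digests):
  storages = lookup_repo_storages_by_content_checksum(repo, blob_digests, by_manifest=True)
  return {found.content_checksum: found for found in storages}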
def set_image_storage_metadata(docker_image_id, namespace_name, repository_name, image_size,
uncompressed_size):
""" Sets metadata that is specific to the binary storage of the data, irrespective of how it
is used in the layer tree.
"""
if image_size is None:
raise DataModelException('Empty image size field')
try:
image = (Image
.select(Image, ImageStorage)
.join(Repository)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.switch(Image)
.join(ImageStorage)
.where(Repository.name == repository_name, Namespace.username == namespace_name,
Image.docker_image_id == docker_image_id)
.get())
except ImageStorage.DoesNotExist:
raise InvalidImageException('No image with specified id and repository')
# We MUST do this here, it can't be done in the corresponding image call because the storage
# has not yet been pushed
image.aggregate_size = _basequery.calculate_image_aggregate_size(image.ancestors, image_size,
image.parent)
image.save()
image.storage.image_size = image_size
image.storage.uncompressed_size = uncompressed_size
image.storage.save()
return image.storage
def get_storage_locations(uuid):
query = (ImageStoragePlacement
.select()
.join(ImageStorage)
.where(ImageStorage.uuid == uuid))
return [get_image_location_for_id(placement.location_id).name for placement in query]
def save_torrent_info(storage_object, piece_length, pieces):
try:
return TorrentInfo.get(storage=storage_object, piece_length=piece_length)
except TorrentInfo.DoesNotExist:
try:
return TorrentInfo.create(storage=storage_object, piece_length=piece_length, pieces=pieces)
except IntegrityError:
# TorrentInfo already exists for this storage.
return TorrentInfo.get(storage=storage_object, piece_length=piece_length)
def get_torrent_info(blob):
try:
return (TorrentInfo
.select()
.where(TorrentInfo.storage == blob)
.get())
except TorrentInfo.DoesNotExist:
raise TorrentInfoDoesNotExist

816
data/model/tag.py Normal file
View file

@ -0,0 +1,816 @@
import logging
from calendar import timegm
from datetime import datetime
from uuid import uuid4
from peewee import IntegrityError, JOIN, fn
from data.model import (image, storage, db_transaction, DataModelException, _basequery,
InvalidManifestException, TagAlreadyCreatedException, StaleTagException,
config)
from data.database import (RepositoryTag, Repository, Image, ImageStorage, Namespace, TagManifest,
RepositoryNotification, Label, TagManifestLabel, get_epoch_timestamp,
db_for_update, Manifest, ManifestLabel, ManifestBlob,
ManifestLegacyImage, TagManifestToManifest,
TagManifestLabelMap, TagToRepositoryTag, Tag, get_epoch_timestamp_ms)
from util.timedeltastring import convert_to_timedelta
logger = logging.getLogger(__name__)
def get_max_id_for_sec_scan():
""" Gets the maximum id for security scanning """
return RepositoryTag.select(fn.Max(RepositoryTag.id)).scalar()
def get_min_id_for_sec_scan(version):
""" Gets the minimum id for a security scanning """
return _tag_alive(RepositoryTag
.select(fn.Min(RepositoryTag.id))
.join(Image)
.where(Image.security_indexed_engine < version)).scalar()
def get_tag_pk_field():
""" Returns the primary key for Image DB model """
return RepositoryTag.id
def get_tags_images_eligible_for_scan(clair_version):
Parent = Image.alias()
ParentImageStorage = ImageStorage.alias()
return _tag_alive(RepositoryTag
.select(Image, ImageStorage, Parent, ParentImageStorage, RepositoryTag)
.join(Image, on=(RepositoryTag.image == Image.id))
.join(ImageStorage, on=(Image.storage == ImageStorage.id))
.switch(Image)
.join(Parent, JOIN.LEFT_OUTER, on=(Image.parent == Parent.id))
.join(ParentImageStorage, JOIN.LEFT_OUTER, on=(ParentImageStorage.id == Parent.storage))
.where(RepositoryTag.hidden == False)
.where(Image.security_indexed_engine < clair_version))
def _tag_alive(query, now_ts=None):
if now_ts is None:
now_ts = get_epoch_timestamp()
return query.where((RepositoryTag.lifetime_end_ts >> None) |
(RepositoryTag.lifetime_end_ts > now_ts))
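# Illustrative usage sketch: _tag_alive simply adds the liveness clause to any
# RepositoryTag query, so it composes with other filters. The repository value is a
# placeholder supplied by the caller.
def _example_alive_tag_names(repo):
  query = _tag_alive(RepositoryTag
                     .select(RepositoryTag.name)
                     .where(RepositoryTag.repository == repo,
                            RepositoryTag.hidden == False))
  return [found.name for found in query]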
def filter_has_repository_event(query, event):
""" Filters the query by ensuring the repositories returned have the given event. """
return (query
.join(Repository)
.join(RepositoryNotification)
.where(RepositoryNotification.event == event))
def filter_tags_have_repository_event(query, event):
""" Filters the query by ensuring the repository tags live in a repository that has the given
event. Also returns the image storage for the tag's image and orders the results by
lifetime_start_ts.
"""
query = filter_has_repository_event(query, event)
query = query.switch(RepositoryTag).join(Image).join(ImageStorage)
query = query.switch(RepositoryTag).order_by(RepositoryTag.lifetime_start_ts.desc())
return query
_MAX_SUB_QUERIES = 100
_MAX_IMAGE_LOOKUP_COUNT = 500
def get_matching_tags_for_images(image_pairs, filter_images=None, filter_tags=None,
selections=None):
""" Returns all tags that contain the images with the given docker_image_id and storage_uuid,
as specified as an iterable of pairs. """
if not image_pairs:
return []
image_pairs_set = set(image_pairs)
# Find all possible matching image+storages.
images = []
while image_pairs:
image_pairs_slice = image_pairs[:_MAX_IMAGE_LOOKUP_COUNT]
ids = [pair[0] for pair in image_pairs_slice]
uuids = [pair[1] for pair in image_pairs_slice]
images_query = (Image
.select(Image.id, Image.docker_image_id, Image.ancestors, ImageStorage.uuid)
.join(ImageStorage)
.where(Image.docker_image_id << ids, ImageStorage.uuid << uuids)
.switch(Image))
if filter_images is not None:
images_query = filter_images(images_query)
images.extend(list(images_query))
image_pairs = image_pairs[_MAX_IMAGE_LOOKUP_COUNT:]
# Filter down to those images actually in the pairs set and build the set of queries to run.
individual_image_queries = []
for img in images:
# Make sure the image found is in the set of those requested, and that we haven't already
# processed it. We need this check because the query above checks for images with matching
# IDs OR storage UUIDs, rather than the expected ID+UUID pair. We do this for efficiency
# reasons, and it is highly unlikely we'll find an image with a mismatch, but we need this
# check to be absolutely sure.
pair = (img.docker_image_id, img.storage.uuid)
if pair not in image_pairs_set:
continue
# Remove the pair so we don't try it again.
image_pairs_set.remove(pair)
ancestors_str = '%s%s/%%' % (img.ancestors, img.id)
query = (Image
.select(Image.id)
.where((Image.id == img.id) | (Image.ancestors ** ancestors_str)))
individual_image_queries.append(query)
if not individual_image_queries:
return []
# Shard based on the max subquery count. This is used to prevent going over the DB's max query
# size, as well as to prevent the DB from locking up on a massive query.
sharded_queries = []
while individual_image_queries:
shard = individual_image_queries[:_MAX_SUB_QUERIES]
sharded_queries.append(_basequery.reduce_as_tree(shard))
individual_image_queries = individual_image_queries[_MAX_SUB_QUERIES:]
# Collect IDs of the tags found for each query.
tags = {}
for query in sharded_queries:
ImageAlias = Image.alias()
tag_query = (_tag_alive(RepositoryTag
.select(*(selections or []))
.distinct()
.join(ImageAlias)
.where(RepositoryTag.hidden == False)
.where(ImageAlias.id << query)
.switch(RepositoryTag)))
if filter_tags is not None:
tag_query = filter_tags(tag_query)
for tag in tag_query:
tags[tag.id] = tag
return tags.values()
def get_matching_tags(docker_image_id, storage_uuid, *args):
""" Returns a query pointing to all tags that contain the image with the
given docker_image_id and storage_uuid. """
image_row = image.get_image_with_storage(docker_image_id, storage_uuid)
if image_row is None:
return RepositoryTag.select().where(RepositoryTag.id < 0) # Empty query.
ancestors_str = '%s%s/%%' % (image_row.ancestors, image_row.id)
return _tag_alive(RepositoryTag
.select(*args)
.distinct()
.join(Image)
.join(ImageStorage)
.where(RepositoryTag.hidden == False)
.where((Image.id == image_row.id) |
(Image.ancestors ** ancestors_str)))
def get_tags_for_image(image_id, *args):
return _tag_alive(RepositoryTag
.select(*args)
.distinct()
.where(RepositoryTag.image == image_id,
RepositoryTag.hidden == False))
def get_tag_manifest_digests(tags):
""" Returns a map from tag ID to its associated manifest digest, if any. """
if not tags:
return dict()
manifests = (TagManifest
.select(TagManifest.tag, TagManifest.digest)
.where(TagManifest.tag << [t.id for t in tags]))
return {manifest.tag_id: manifest.digest for manifest in manifests}
def list_active_repo_tags(repo, start_id=None, limit=None, include_images=True):
""" Returns all of the active, non-hidden tags in a repository, joined to they images
and (if present), their manifest.
"""
if include_images:
query = _tag_alive(RepositoryTag
.select(RepositoryTag, Image, ImageStorage, TagManifest.digest)
.join(Image)
.join(ImageStorage)
.where(RepositoryTag.repository == repo, RepositoryTag.hidden == False)
.switch(RepositoryTag)
.join(TagManifest, JOIN.LEFT_OUTER)
.order_by(RepositoryTag.id))
else:
query = _tag_alive(RepositoryTag
.select(RepositoryTag)
.where(RepositoryTag.repository == repo, RepositoryTag.hidden == False)
.order_by(RepositoryTag.id))
if start_id is not None:
query = query.where(RepositoryTag.id >= start_id)
if limit is not None:
query = query.limit(limit)
return query
def list_repository_tags(namespace_name, repository_name, include_hidden=False,
include_storage=False):
to_select = (RepositoryTag, Image)
if include_storage:
to_select = (RepositoryTag, Image, ImageStorage)
query = _tag_alive(RepositoryTag
.select(*to_select)
.join(Repository)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.switch(RepositoryTag)
.join(Image)
.where(Repository.name == repository_name,
Namespace.username == namespace_name))
if not include_hidden:
query = query.where(RepositoryTag.hidden == False)
if include_storage:
query = query.switch(Image).join(ImageStorage)
return query
def create_or_update_tag(namespace_name, repository_name, tag_name, tag_docker_image_id,
reversion=False, now_ms=None):
try:
repo = _basequery.get_existing_repository(namespace_name, repository_name)
except Repository.DoesNotExist:
raise DataModelException('Invalid repository %s/%s' % (namespace_name, repository_name))
return create_or_update_tag_for_repo(repo.id, tag_name, tag_docker_image_id, reversion=reversion,
now_ms=now_ms)
def create_or_update_tag_for_repo(repository_id, tag_name, tag_docker_image_id, reversion=False,
oci_manifest=None, now_ms=None):
now_ms = now_ms or get_epoch_timestamp_ms()
now_ts = int(now_ms / 1000)
with db_transaction():
try:
tag = db_for_update(_tag_alive(RepositoryTag
.select()
.where(RepositoryTag.repository == repository_id,
RepositoryTag.name == tag_name), now_ts)).get()
tag.lifetime_end_ts = now_ts
tag.save()
# Check for an OCI tag.
try:
oci_tag = db_for_update(Tag
.select()
.join(TagToRepositoryTag)
.where(TagToRepositoryTag.repository_tag == tag)).get()
oci_tag.lifetime_end_ms = now_ms
oci_tag.save()
except Tag.DoesNotExist:
pass
except RepositoryTag.DoesNotExist:
pass
except IntegrityError:
msg = 'Tag with name %s was stale when we tried to update it; Please retry the push'
raise StaleTagException(msg % tag_name)
try:
image_obj = Image.get(Image.docker_image_id == tag_docker_image_id,
Image.repository == repository_id)
except Image.DoesNotExist:
raise DataModelException('Invalid image with id: %s' % tag_docker_image_id)
try:
created = RepositoryTag.create(repository=repository_id, image=image_obj, name=tag_name,
lifetime_start_ts=now_ts, reversion=reversion)
if oci_manifest:
# Create the OCI tag as well.
oci_tag = Tag.create(repository=repository_id, manifest=oci_manifest, name=tag_name,
lifetime_start_ms=now_ms, reversion=reversion,
tag_kind=Tag.tag_kind.get_id('tag'))
TagToRepositoryTag.create(tag=oci_tag, repository_tag=created, repository=repository_id)
return created
except IntegrityError:
msg = 'Tag with name %s and lifetime start %s already exists'
raise TagAlreadyCreatedException(msg % (tag_name, now_ts))
def create_temporary_hidden_tag(repo, image_obj, expiration_s):
""" Create a tag with a defined timeline, that will not appear in the UI or CLI. Returns the name
of the temporary tag. """
now_ts = get_epoch_timestamp()
expire_ts = now_ts + expiration_s
tag_name = str(uuid4())
RepositoryTag.create(repository=repo, image=image_obj, name=tag_name, lifetime_start_ts=now_ts,
lifetime_end_ts=expire_ts, hidden=True)
return tag_name
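# Illustrative usage sketch (placeholder expiration): a temporary hidden tag is how
# callers keep an image alive while a longer operation, such as a push, completes.
def _example_hold_image(repo, image_obj):
  # Hold the image for ten minutes; the returned name can be used to find the tag later.
  return create_temporary_hidden_tag(repo, image_obj, expiration_s=10 * 60)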
def lookup_unrecoverable_tags(repo):
""" Returns the tags in a repository that are expired and past their time machine recovery
period. """
expired_clause = get_epoch_timestamp() - Namespace.removed_tag_expiration_s
return (RepositoryTag
.select()
.join(Repository)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.where(RepositoryTag.repository == repo)
.where(~(RepositoryTag.lifetime_end_ts >> None),
RepositoryTag.lifetime_end_ts <= expired_clause))
def delete_tag(namespace_name, repository_name, tag_name, now_ms=None):
now_ms = now_ms or get_epoch_timestamp_ms()
now_ts = int(now_ms / 1000)
with db_transaction():
try:
query = _tag_alive(RepositoryTag
.select(RepositoryTag, Repository)
.join(Repository)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.where(Repository.name == repository_name,
Namespace.username == namespace_name,
RepositoryTag.name == tag_name), now_ts)
found = db_for_update(query).get()
except RepositoryTag.DoesNotExist:
msg = ('Invalid repository tag \'%s\' on repository \'%s/%s\'' %
(tag_name, namespace_name, repository_name))
raise DataModelException(msg)
found.lifetime_end_ts = now_ts
found.save()
try:
oci_tag_query = TagToRepositoryTag.select().where(TagToRepositoryTag.repository_tag == found)
oci_tag = db_for_update(oci_tag_query).get().tag
oci_tag.lifetime_end_ms = now_ms
oci_tag.save()
except TagToRepositoryTag.DoesNotExist:
pass
return found
def _get_repo_tag_image(tag_name, include_storage, modifier):
query = Image.select().join(RepositoryTag)
if include_storage:
query = (Image
.select(Image, ImageStorage)
.join(ImageStorage)
.switch(Image)
.join(RepositoryTag))
images = _tag_alive(modifier(query.where(RepositoryTag.name == tag_name)))
if not images:
raise DataModelException('Unable to find image for tag.')
else:
return images[0]
def get_repo_tag_image(repo, tag_name, include_storage=False):
def modifier(query):
return query.where(RepositoryTag.repository == repo)
return _get_repo_tag_image(tag_name, include_storage, modifier)
def get_tag_image(namespace_name, repository_name, tag_name, include_storage=False):
def modifier(query):
return (query
.switch(RepositoryTag)
.join(Repository)
.join(Namespace)
.where(Namespace.username == namespace_name, Repository.name == repository_name))
return _get_repo_tag_image(tag_name, include_storage, modifier)
def list_repository_tag_history(repo_obj, page=1, size=100, specific_tag=None, active_tags_only=False, since_time=None):
# Only available on OCI model
if since_time is not None:
raise NotImplementedError
query = (RepositoryTag
.select(RepositoryTag, Image, ImageStorage)
.join(Image)
.join(ImageStorage)
.switch(RepositoryTag)
.where(RepositoryTag.repository == repo_obj)
.where(RepositoryTag.hidden == False)
.order_by(RepositoryTag.lifetime_start_ts.desc(), RepositoryTag.name)
.limit(size + 1)
.offset(size * (page - 1)))
if active_tags_only:
query = _tag_alive(query)
if specific_tag:
query = query.where(RepositoryTag.name == specific_tag)
tags = list(query)
if not tags:
return [], {}, False
manifest_map = get_tag_manifest_digests(tags)
return tags[0:size], manifest_map, len(tags) > size
def restore_tag_to_manifest(repo_obj, tag_name, manifest_digest):
""" Restores a tag to a specific manifest digest. """
with db_transaction():
# Verify that the manifest digest already existed under this repository under the
# tag.
try:
tag_manifest = (TagManifest
.select(TagManifest, RepositoryTag, Image)
.join(RepositoryTag)
.join(Image)
.where(RepositoryTag.repository == repo_obj)
.where(RepositoryTag.name == tag_name)
.where(TagManifest.digest == manifest_digest)
.get())
except TagManifest.DoesNotExist:
raise DataModelException('Cannot restore to unknown or invalid digest')
# Lookup the existing image, if any.
try:
existing_image = get_repo_tag_image(repo_obj, tag_name)
except DataModelException:
existing_image = None
docker_image_id = tag_manifest.tag.image.docker_image_id
oci_manifest = None
try:
oci_manifest = Manifest.get(repository=repo_obj, digest=manifest_digest)
except Manifest.DoesNotExist:
pass
# Change the tag and tag manifest to point to the updated image.
updated_tag = create_or_update_tag_for_repo(repo_obj, tag_name, docker_image_id,
reversion=True, oci_manifest=oci_manifest)
tag_manifest.tag = updated_tag
tag_manifest.save()
return existing_image
def restore_tag_to_image(repo_obj, tag_name, docker_image_id):
""" Restores a tag to a specific image ID. """
with db_transaction():
# Verify that the image ID already existed under this repository under the
# tag.
try:
(RepositoryTag
.select()
.join(Image)
.where(RepositoryTag.repository == repo_obj)
.where(RepositoryTag.name == tag_name)
.where(Image.docker_image_id == docker_image_id)
.get())
except RepositoryTag.DoesNotExist:
raise DataModelException('Cannot restore to unknown or invalid image')
# Lookup the existing image, if any.
try:
existing_image = get_repo_tag_image(repo_obj, tag_name)
except DataModelException:
existing_image = None
create_or_update_tag_for_repo(repo_obj, tag_name, docker_image_id, reversion=True)
return existing_image
def store_tag_manifest_for_testing(namespace_name, repository_name, tag_name, manifest,
leaf_layer_id, storage_id_map):
""" Stores a tag manifest for a specific tag name in the database. Returns the TagManifest
object, as well as a boolean indicating whether the TagManifest was created.
"""
try:
repo = _basequery.get_existing_repository(namespace_name, repository_name)
except Repository.DoesNotExist:
raise DataModelException('Invalid repository %s/%s' % (namespace_name, repository_name))
return store_tag_manifest_for_repo(repo.id, tag_name, manifest, leaf_layer_id, storage_id_map)
def store_tag_manifest_for_repo(repository_id, tag_name, manifest, leaf_layer_id, storage_id_map,
reversion=False):
""" Stores a tag manifest for a specific tag name in the database. Returns the TagManifest
object, as well as a boolean indicating whether the TagManifest was created.
"""
# Create the new-style OCI manifest and its blobs.
oci_manifest = _populate_manifest_and_blobs(repository_id, manifest, storage_id_map,
leaf_layer_id=leaf_layer_id)
# Create the tag for the tag manifest.
tag = create_or_update_tag_for_repo(repository_id, tag_name, leaf_layer_id,
reversion=reversion, oci_manifest=oci_manifest)
# Add a tag manifest pointing to that tag.
try:
manifest = TagManifest.get(digest=manifest.digest)
manifest.tag = tag
manifest.save()
return manifest, False
except TagManifest.DoesNotExist:
created = _associate_manifest(tag, oci_manifest)
return created, True
def get_active_tag(namespace, repo_name, tag_name):
return _tag_alive(RepositoryTag
.select()
.join(Repository)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.where(RepositoryTag.name == tag_name, Repository.name == repo_name,
Namespace.username == namespace)).get()
def get_active_tag_for_repo(repo, tag_name):
try:
return _tag_alive(RepositoryTag
.select(RepositoryTag, Image, ImageStorage)
.join(Image)
.join(ImageStorage)
.where(RepositoryTag.name == tag_name,
RepositoryTag.repository == repo,
RepositoryTag.hidden == False)).get()
except RepositoryTag.DoesNotExist:
return None
def get_expired_tag_in_repo(repo, tag_name):
return (RepositoryTag
.select()
.where(RepositoryTag.name == tag_name, RepositoryTag.repository == repo)
.where(~(RepositoryTag.lifetime_end_ts >> None))
.where(RepositoryTag.lifetime_end_ts <= get_epoch_timestamp())
.get())
def get_possibly_expired_tag(namespace, repo_name, tag_name):
return (RepositoryTag
.select()
.join(Repository)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.where(RepositoryTag.name == tag_name, Repository.name == repo_name,
Namespace.username == namespace)).get()
def associate_generated_tag_manifest_with_tag(tag, manifest, storage_id_map):
oci_manifest = _populate_manifest_and_blobs(tag.repository, manifest, storage_id_map)
with db_transaction():
try:
(Tag
.select()
.join(TagToRepositoryTag)
.where(TagToRepositoryTag.repository_tag == tag)).get()
except Tag.DoesNotExist:
oci_tag = Tag.create(repository=tag.repository, manifest=oci_manifest, name=tag.name,
reversion=tag.reversion,
lifetime_start_ms=tag.lifetime_start_ts * 1000,
lifetime_end_ms=(tag.lifetime_end_ts * 1000
if tag.lifetime_end_ts else None),
tag_kind=Tag.tag_kind.get_id('tag'))
TagToRepositoryTag.create(tag=oci_tag, repository_tag=tag, repository=tag.repository)
return _associate_manifest(tag, oci_manifest)
def _associate_manifest(tag, oci_manifest):
with db_transaction():
tag_manifest = TagManifest.create(tag=tag, digest=oci_manifest.digest,
json_data=oci_manifest.manifest_bytes)
TagManifestToManifest.create(tag_manifest=tag_manifest, manifest=oci_manifest)
return tag_manifest
def _populate_manifest_and_blobs(repository, manifest, storage_id_map, leaf_layer_id=None):
leaf_layer_id = leaf_layer_id or manifest.leaf_layer_v1_image_id
try:
legacy_image = Image.get(Image.docker_image_id == leaf_layer_id,
Image.repository == repository)
except Image.DoesNotExist:
raise DataModelException('Invalid image with id: %s' % leaf_layer_id)
storage_ids = set()
for blob_digest in manifest.local_blob_digests:
image_storage_id = storage_id_map.get(blob_digest)
if image_storage_id is None:
logger.error('Missing blob for manifest `%s` in: %s', blob_digest, storage_id_map)
raise DataModelException('Missing blob for manifest `%s`' % blob_digest)
if image_storage_id in storage_ids:
continue
storage_ids.add(image_storage_id)
return populate_manifest(repository, manifest, legacy_image, storage_ids)
def populate_manifest(repository, manifest, legacy_image, storage_ids):
""" Populates the rows for the manifest, including its blobs and legacy image. """
media_type = Manifest.media_type.get_id(manifest.media_type)
# Check for an existing manifest. If present, return it.
try:
return Manifest.get(repository=repository, digest=manifest.digest)
except Manifest.DoesNotExist:
pass
with db_transaction():
try:
manifest_row = Manifest.create(digest=manifest.digest, repository=repository,
manifest_bytes=manifest.bytes.as_encoded_str(),
media_type=media_type)
except IntegrityError as ie:
logger.debug('Got integrity error when trying to write manifest: %s', ie)
return Manifest.get(repository=repository, digest=manifest.digest)
ManifestLegacyImage.create(manifest=manifest_row, repository=repository, image=legacy_image)
blobs_to_insert = [dict(manifest=manifest_row, repository=repository,
blob=storage_id) for storage_id in storage_ids]
if blobs_to_insert:
ManifestBlob.insert_many(blobs_to_insert).execute()
return manifest_row
def get_tag_manifest(tag):
try:
return TagManifest.get(tag=tag)
except TagManifest.DoesNotExist:
return None
def load_tag_manifest(namespace, repo_name, tag_name):
try:
return (_load_repo_manifests(namespace, repo_name)
.where(RepositoryTag.name == tag_name)
.get())
except TagManifest.DoesNotExist:
msg = 'Manifest not found for tag {0} in repo {1}/{2}'.format(tag_name, namespace, repo_name)
raise InvalidManifestException(msg)
def delete_manifest_by_digest(namespace, repo_name, digest):
tag_manifests = list(_load_repo_manifests(namespace, repo_name)
.where(TagManifest.digest == digest))
now_ms = get_epoch_timestamp_ms()
for tag_manifest in tag_manifests:
try:
tag = _tag_alive(RepositoryTag.select().where(RepositoryTag.id == tag_manifest.tag_id)).get()
delete_tag(namespace, repo_name, tag_manifest.tag.name, now_ms)
except RepositoryTag.DoesNotExist:
pass
return [tag_manifest.tag for tag_manifest in tag_manifests]
def load_manifest_by_digest(namespace, repo_name, digest, allow_dead=False):
try:
return (_load_repo_manifests(namespace, repo_name, allow_dead=allow_dead)
.where(TagManifest.digest == digest)
.get())
except TagManifest.DoesNotExist:
msg = 'Manifest not found with digest {0} in repo {1}/{2}'.format(digest, namespace, repo_name)
raise InvalidManifestException(msg)
def _load_repo_manifests(namespace, repo_name, allow_dead=False):
query = (TagManifest
.select(TagManifest, RepositoryTag)
.join(RepositoryTag)
.join(Image)
.join(Repository)
.join(Namespace, on=(Namespace.id == Repository.namespace_user))
.where(Repository.name == repo_name, Namespace.username == namespace))
if not allow_dead:
query = _tag_alive(query)
return query
def change_repository_tag_expiration(namespace_name, repo_name, tag_name, expiration_date):
""" Changes the expiration of the tag with the given name to the given expiration datetime. If
the expiration datetime is None, then the tag is marked as not expiring.
"""
try:
tag = get_active_tag(namespace_name, repo_name, tag_name)
return change_tag_expiration(tag, expiration_date)
except RepositoryTag.DoesNotExist:
return (None, False)
def set_tag_expiration_for_manifest(tag_manifest, expiration_sec):
"""
Changes the expiration of the tag that points to the given manifest to be its lifetime start +
the expiration seconds.
"""
expiration_time_ts = tag_manifest.tag.lifetime_start_ts + expiration_sec
expiration_date = datetime.utcfromtimestamp(expiration_time_ts)
return change_tag_expiration(tag_manifest.tag, expiration_date)
def change_tag_expiration(tag, expiration_date):
""" Changes the expiration of the given tag to the given expiration datetime. If
the expiration datetime is None, then the tag is marked as not expiring.
"""
end_ts = None
min_expire_sec = convert_to_timedelta(config.app_config.get('LABELED_EXPIRATION_MINIMUM', '1h'))
max_expire_sec = convert_to_timedelta(config.app_config.get('LABELED_EXPIRATION_MAXIMUM', '104w'))
if expiration_date is not None:
offset = timegm(expiration_date.utctimetuple()) - tag.lifetime_start_ts
offset = min(max(offset, min_expire_sec.total_seconds()), max_expire_sec.total_seconds())
end_ts = tag.lifetime_start_ts + offset
if end_ts == tag.lifetime_end_ts:
return (None, True)
return set_tag_end_ts(tag, end_ts)
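# Worked example of the clamping above (assuming the default '1h' minimum): if a tag
# started at lifetime_start_ts = 1000000 and the requested expiration is only ten
# minutes later (offset 600s), the offset is raised to 3600s, so the stored
# lifetime_end_ts becomes 1003600 rather than 1000600.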
def set_tag_end_ts(tag, end_ts):
""" Sets the end timestamp for a tag. Should only be called by change_tag_expiration
or tests.
"""
end_ms = end_ts * 1000 if end_ts is not None else None
with db_transaction():
# Note: We check not just the ID of the tag but also its lifetime_end_ts, to ensure that it has
# not changed while we were updating its expiration.
result = (RepositoryTag
.update(lifetime_end_ts=end_ts)
.where(RepositoryTag.id == tag.id,
RepositoryTag.lifetime_end_ts == tag.lifetime_end_ts)
.execute())
# Check for a mapping to an OCI tag.
try:
oci_tag = (Tag
.select()
.join(TagToRepositoryTag)
.where(TagToRepositoryTag.repository_tag == tag)
.get())
(Tag
.update(lifetime_end_ms=end_ms)
.where(Tag.id == oci_tag.id,
Tag.lifetime_end_ms == oci_tag.lifetime_end_ms)
.execute())
except Tag.DoesNotExist:
pass
return (tag.lifetime_end_ts, result > 0)
def find_matching_tag(repo_id, tag_names):
""" Finds the most recently pushed alive tag in the repository with one of the given names,
if any.
"""
try:
return (_tag_alive(RepositoryTag
.select()
.where(RepositoryTag.repository == repo_id,
RepositoryTag.name << list(tag_names))
.order_by(RepositoryTag.lifetime_start_ts.desc()))
.get())
except RepositoryTag.DoesNotExist:
return None
def get_most_recent_tag(repo_id):
""" Returns the most recently pushed alive tag in the repository, or None if none. """
try:
return (_tag_alive(RepositoryTag
.select()
.where(RepositoryTag.repository == repo_id, RepositoryTag.hidden == False)
.order_by(RepositoryTag.lifetime_start_ts.desc()))
.get())
except RepositoryTag.DoesNotExist:
return None

519
data/model/team.py Normal file
View file

@ -0,0 +1,519 @@
import json
import re
import uuid
from datetime import datetime
from peewee import fn
from data.database import (Team, TeamMember, TeamRole, User, TeamMemberInvite, RepositoryPermission,
TeamSync, LoginService, FederatedLogin, db_random_func, db_transaction)
from data.model import (DataModelException, InvalidTeamException, UserAlreadyInTeam,
InvalidTeamMemberException, _basequery)
from data.text import prefix_search
from util.validation import validate_username
from util.morecollections import AttrDict
MIN_TEAMNAME_LENGTH = 2
MAX_TEAMNAME_LENGTH = 255
VALID_TEAMNAME_REGEX = r'^([a-z0-9]+(?:[._-][a-z0-9]+)*)$'
def validate_team_name(teamname):
if not re.match(VALID_TEAMNAME_REGEX, teamname):
return (False, 'Team name must match expression ' + VALID_TEAMNAME_REGEX)
length_match = (len(teamname) >= MIN_TEAMNAME_LENGTH and len(teamname) <= MAX_TEAMNAME_LENGTH)
if not length_match:
return (False, 'Team must be between %s and %s characters in length' %
(MIN_TEAMNAME_LENGTH, MAX_TEAMNAME_LENGTH))
return (True, '')
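# Illustrative examples of the validation above: 'build-team' satisfies both the regex
# and the length bounds, while a single character matches the regex but fails the
# two-character minimum.
def _example_validate_names():
  assert validate_team_name('build-team') == (True, '')
  assert validate_team_name('a')[0] is False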
def create_team(name, org_obj, team_role_name, description=''):
(teamname_valid, teamname_issue) = validate_team_name(name)
if not teamname_valid:
raise InvalidTeamException('Invalid team name %s: %s' % (name, teamname_issue))
if not org_obj.organization:
raise InvalidTeamException('Specified organization %s was not an organization' %
org_obj.username)
team_role = TeamRole.get(TeamRole.name == team_role_name)
return Team.create(name=name, organization=org_obj, role=team_role,
description=description)
def add_user_to_team(user_obj, team):
try:
return TeamMember.create(user=user_obj, team=team)
except Exception:
raise UserAlreadyInTeam('User %s is already a member of team %s' %
(user_obj.username, team.name))
def remove_user_from_team(org_name, team_name, username, removed_by_username):
Org = User.alias()
joined = TeamMember.select().join(User).switch(TeamMember).join(Team)
with_role = joined.join(TeamRole)
with_org = with_role.switch(Team).join(Org,
on=(Org.id == Team.organization))
found = list(with_org.where(User.username == username,
Org.username == org_name,
Team.name == team_name))
if not found:
raise DataModelException('User %s does not belong to team %s' %
(username, team_name))
if username == removed_by_username:
admin_team_query = __get_user_admin_teams(org_name, username)
admin_team_names = {team.name for team in admin_team_query}
if team_name in admin_team_names and len(admin_team_names) <= 1:
msg = 'User cannot remove themselves from their only admin team.'
raise DataModelException(msg)
user_in_team = found[0]
user_in_team.delete_instance()
def set_team_org_permission(team, team_role_name, set_by_username):
if team.role.name == 'admin' and team_role_name != 'admin':
# We need to make sure we're not removing the user's only admin role
user_admin_teams = __get_user_admin_teams(team.organization.username, set_by_username)
admin_team_set = {admin_team.name for admin_team in user_admin_teams}
if team.name in admin_team_set and len(admin_team_set) <= 1:
msg = (('Cannot remove admin from team \'%s\' because calling user ' +
'would no longer have admin on org \'%s\'') %
(team.name, team.organization.username))
raise DataModelException(msg)
new_role = TeamRole.get(TeamRole.name == team_role_name)
team.role = new_role
team.save()
return team
def __get_user_admin_teams(org_name, username):
Org = User.alias()
user_teams = Team.select().join(TeamMember).join(User)
with_org = user_teams.switch(Team).join(Org,
on=(Org.id == Team.organization))
with_role = with_org.switch(Team).join(TeamRole)
admin_teams = with_role.where(User.username == username,
Org.username == org_name,
TeamRole.name == 'admin')
return admin_teams
def remove_team(org_name, team_name, removed_by_username):
joined = Team.select(Team, TeamRole).join(User).switch(Team).join(TeamRole)
found = list(joined.where(User.organization == True,
User.username == org_name,
Team.name == team_name))
if not found:
raise InvalidTeamException('Team \'%s\' is not a team in org \'%s\'' %
(team_name, org_name))
team = found[0]
if team.role.name == 'admin':
admin_teams = list(__get_user_admin_teams(org_name, removed_by_username))
if len(admin_teams) <= 1:
# The team we are trying to remove is the only admin team containing this user.
msg = "Deleting team '%s' would remove admin ability for user '%s' in organization '%s'"
raise DataModelException(msg % (team_name, removed_by_username, org_name))
team.delete_instance(recursive=True, delete_nullable=True)
def add_or_invite_to_team(inviter, team, user_obj=None, email=None, requires_invite=True):
# If the user is a member of the organization, then we simply add the
# user directly to the team. Otherwise, an invite is created for the user/email.
# We return None if the user was directly added and the invite object if the user was invited.
if user_obj and requires_invite:
orgname = team.organization.username
# If the user is part of the organization (or a robot), then no invite is required.
if user_obj.robot:
requires_invite = False
if not user_obj.username.startswith(orgname + '+'):
raise InvalidTeamMemberException('Cannot add the specified robot to this team, ' +
'as it is not a member of the organization')
else:
query = (TeamMember
.select()
.where(TeamMember.user == user_obj)
.join(Team)
.join(User)
.where(User.username == orgname, User.organization == True))
requires_invite = not any(query)
# If we have a valid user and no invite is required, simply add the user to the team.
if user_obj and not requires_invite:
add_user_to_team(user_obj, team)
return None
email_address = email if not user_obj else None
return TeamMemberInvite.create(user=user_obj, email=email_address, team=team, inviter=inviter)
def get_matching_user_teams(team_prefix, user_obj, limit=10):
team_prefix_search = prefix_search(Team.name, team_prefix)
query = (Team
.select(Team.id.distinct(), Team)
.join(User)
.switch(Team)
.join(TeamMember)
.where(TeamMember.user == user_obj, team_prefix_search)
.limit(limit))
return query
def get_organization_team(orgname, teamname):
joined = Team.select().join(User)
query = joined.where(Team.name == teamname, User.organization == True,
User.username == orgname).limit(1)
result = list(query)
if not result:
raise InvalidTeamException('Team does not exist: %s/%s' %
(orgname, teamname))
return result[0]
def get_matching_admined_teams(team_prefix, user_obj, limit=10):
team_prefix_search = prefix_search(Team.name, team_prefix)
admined_orgs = (_basequery.get_user_organizations(user_obj.username)
.switch(Team)
.join(TeamRole)
.where(TeamRole.name == 'admin'))
query = (Team
.select(Team.id.distinct(), Team)
.join(User)
.switch(Team)
.join(TeamMember)
.where(team_prefix_search, Team.organization << (admined_orgs))
.limit(limit))
return query
def get_matching_teams(team_prefix, organization):
team_prefix_search = prefix_search(Team.name, team_prefix)
query = Team.select().where(team_prefix_search, Team.organization == organization)
return query.limit(10)
def get_teams_within_org(organization, has_external_auth=False):
""" Returns a AttrDict of team info (id, name, description), its role under the org,
the number of repositories on which it has permission, and the number of members.
"""
query = (Team.select()
.where(Team.organization == organization)
.join(TeamRole))
def _team_view(team):
return {
'id': team.id,
'name': team.name,
'description': team.description,
'role_name': Team.role.get_name(team.role_id),
'repo_count': 0,
'member_count': 0,
'is_synced': False,
}
teams = {team.id: _team_view(team) for team in query}
if not teams:
# Just in case. Should ideally never happen.
return []
# Add repository permissions count.
permission_tuples = (RepositoryPermission.select(RepositoryPermission.team,
fn.Count(RepositoryPermission.id))
.where(RepositoryPermission.team << teams.keys())
.group_by(RepositoryPermission.team)
.tuples())
for perm_tuple in permission_tuples:
teams[perm_tuple[0]]['repo_count'] = perm_tuple[1]
# Add the member count.
members_tuples = (TeamMember.select(TeamMember.team,
fn.Count(TeamMember.id))
.where(TeamMember.team << teams.keys())
.group_by(TeamMember.team)
.tuples())
for member_tuple in members_tuples:
teams[member_tuple[0]]['member_count'] = member_tuple[1]
# Add syncing information.
if has_external_auth:
sync_query = TeamSync.select(TeamSync.team).where(TeamSync.team << teams.keys())
for team_sync in sync_query:
teams[team_sync.team_id]['is_synced'] = True
return [AttrDict(team_info) for team_info in teams.values()]
def get_user_teams_within_org(username, organization):
joined = Team.select().join(TeamMember).join(User)
return joined.where(Team.organization == organization,
User.username == username)
def list_organization_members_by_teams(organization):
query = (TeamMember
.select(Team, User)
.join(Team)
.switch(TeamMember)
.join(User)
.where(Team.organization == organization))
return query
def get_organization_team_member_invites(teamid):
joined = TeamMemberInvite.select().join(Team).join(User)
query = joined.where(Team.id == teamid)
return query
def delete_team_email_invite(team, email):
try:
found = TeamMemberInvite.get(TeamMemberInvite.email == email, TeamMemberInvite.team == team)
except TeamMemberInvite.DoesNotExist:
return False
found.delete_instance()
return True
def delete_team_user_invite(team, user_obj):
try:
found = TeamMemberInvite.get(TeamMemberInvite.user == user_obj, TeamMemberInvite.team == team)
except TeamMemberInvite.DoesNotExist:
return False
found.delete_instance()
return True
def lookup_team_invites_by_email(email):
return TeamMemberInvite.select().where(TeamMemberInvite.email == email)
def lookup_team_invites(user_obj):
return TeamMemberInvite.select().where(TeamMemberInvite.user == user_obj)
def lookup_team_invite(code, user_obj=None):
# Lookup the invite code.
try:
found = TeamMemberInvite.get(TeamMemberInvite.invite_token == code)
except TeamMemberInvite.DoesNotExist:
raise DataModelException('Invalid confirmation code.')
if user_obj and found.user != user_obj:
raise DataModelException('Invalid confirmation code.')
return found
def delete_team_invite(code, user_obj=None):
found = lookup_team_invite(code, user_obj)
team = found.team
inviter = found.inviter
found.delete_instance()
return (team, inviter)
def find_matching_team_invite(code, user_obj):
""" Finds a team invite with the given code that applies to the given user and returns it or
raises a DataModelException if not found. """
found = lookup_team_invite(code)
# If the invite is for a specific user, we have to confirm that here.
if found.user is not None and found.user != user_obj:
message = """This invite is intended for user "%s".
Please login to that account and try again.""" % found.user.username
raise DataModelException(message)
return found
def find_organization_invites(organization, user_obj):
""" Finds all organization team invites for the given user under the given organization. """
invite_check = (TeamMemberInvite.user == user_obj)
if user_obj.verified:
invite_check = invite_check | (TeamMemberInvite.email == user_obj.email)
query = (TeamMemberInvite
.select()
.join(Team)
.where(invite_check, Team.organization == organization))
return query
def confirm_team_invite(code, user_obj):
""" Confirms the given team invite code for the given user by adding the user to the team
and deleting the code. Raises a DataModelException if the code was not found or does
not apply to the given user. If the user is invited to two or more teams under the
same organization, they are automatically confirmed for all of them. """
found = find_matching_team_invite(code, user_obj)
# Find all matching invitations for the user under the organization.
code_found = False
for invite in find_organization_invites(found.team.organization, user_obj):
# Add the user to the team.
try:
code_found = True
add_user_to_team(user_obj, invite.team)
except UserAlreadyInTeam:
# Ignore.
pass
# Delete the invite and return the team.
invite.delete_instance()
if not code_found:
if found.user:
message = """This invite is intended for user "%s".
Please login to that account and try again.""" % found.user.username
raise DataModelException(message)
else:
message = """This invite is intended for email "%s".
Please login to that account and try again.""" % found.email
raise DataModelException(message)
team = found.team
inviter = found.inviter
return (team, inviter)
def get_federated_team_member_mapping(team, login_service_name):
""" Returns a dict of all federated IDs for all team members in the team whose users are
bound to the login service with the given name. The dictionary maps each federated service
identifier (username) to the corresponding Quay User table ID.
"""
login_service = LoginService.get(name=login_service_name)
query = (FederatedLogin
.select(FederatedLogin.service_ident, User.id)
.join(User)
.join(TeamMember)
.join(Team)
.where(Team.id == team, User.robot == False, FederatedLogin.service == login_service))
return dict(query.tuples())
def list_team_users(team):
""" Returns an iterator of all the *users* found in a team. Does not include robots. """
return (User
.select()
.join(TeamMember)
.join(Team)
.where(Team.id == team, User.robot == False))
def list_team_robots(team):
""" Returns an iterator of all the *robots* found in a team. Does not include users. """
return (User
.select()
.join(TeamMember)
.join(Team)
.where(Team.id == team, User.robot == True))
def set_team_syncing(team, login_service_name, config):
""" Sets the given team to sync to the given service using the given config. """
login_service = LoginService.get(name=login_service_name)
return TeamSync.create(team=team, transaction_id='', service=login_service,
config=json.dumps(config))
def remove_team_syncing(orgname, teamname):
""" Removes syncing on the team matching the given organization name and team name. """
existing = get_team_sync_information(orgname, teamname)
if existing:
existing.delete_instance()
def get_stale_team(stale_timespan):
""" Returns a team that is setup to sync to an external group, and who has not been synced in
now - stale_timespan. Returns None if none found.
"""
stale_at = datetime.now() - stale_timespan
try:
candidates = (TeamSync
.select(TeamSync.id)
.where((TeamSync.last_updated <= stale_at) | (TeamSync.last_updated >> None))
.limit(500)
.alias('candidates'))
found = (TeamSync
.select(candidates.c.id)
.from_(candidates)
.order_by(db_random_func())
.get())
if found is None:
return
return TeamSync.select(TeamSync, Team).join(Team).where(TeamSync.id == found.id).get()
except TeamSync.DoesNotExist:
return None
def get_team_sync_information(orgname, teamname):
""" Returns the team syncing information for the team with the given name under the organization
with the given name or None if none.
"""
query = (TeamSync
.select(TeamSync, LoginService)
.join(Team)
.join(User)
.switch(TeamSync)
.join(LoginService)
.where(Team.name == teamname, User.organization == True, User.username == orgname))
try:
return query.get()
except TeamSync.DoesNotExist:
return None
def update_sync_status(team_sync_info):
""" Attempts to update the transaction ID and last updated time on a TeamSync object. If the
transaction ID on the entry in the DB does not match that found on the object, this method
returns False, which indicates another caller updated it first.
"""
new_transaction_id = str(uuid.uuid4())
query = (TeamSync
.update(transaction_id=new_transaction_id, last_updated=datetime.now())
.where(TeamSync.id == team_sync_info.id,
TeamSync.transaction_id == team_sync_info.transaction_id))
return query.execute() == 1
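# Illustrative usage sketch: the transaction ID check above gives optimistic concurrency,
# so a syncing worker first claims the row and only proceeds if the claim succeeded.
# The org and team names are placeholders supplied by the caller.
def _example_claim_team_for_sync(orgname, teamname):
  sync_info = get_team_sync_information(orgname, teamname)
  if sync_info is None or not update_sync_status(sync_info):
    return None  # Not a synced team, or another worker claimed it first.
  return sync_info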
def delete_members_not_present(team, member_id_set):
""" Deletes all members of the given team that are not found in the member ID set. """
with db_transaction():
user_ids = set([u.id for u in list_team_users(team)])
to_delete = list(user_ids - member_id_set)
if to_delete:
query = TeamMember.delete().where(TeamMember.team == team, TeamMember.user << to_delete)
return query.execute()
return 0

View file

@ -0,0 +1,126 @@
from datetime import datetime, timedelta
from mock import patch
import pytest
from data.model import config as _config
from data import model
from data.model.appspecifictoken import create_token, revoke_token, access_valid_token
from data.model.appspecifictoken import gc_expired_tokens, get_expiring_tokens
from data.model.appspecifictoken import get_full_token_string
from util.timedeltastring import convert_to_timedelta
from test.fixtures import *
@pytest.mark.parametrize('expiration', [
(None),
('-1m'),
('-1d'),
('-1w'),
('10m'),
('10d'),
('10w'),
])
def test_gc(expiration, initialized_db):
user = model.user.get_user('devtable')
expiration_date = None
is_expired = False
if expiration:
if expiration[0] == '-':
is_expired = True
expiration_date = datetime.now() - convert_to_timedelta(expiration[1:])
else:
expiration_date = datetime.now() + convert_to_timedelta(expiration)
# Create a token.
token = create_token(user, 'Some token', expiration=expiration_date)
# GC tokens.
gc_expired_tokens(timedelta(seconds=0))
# Ensure the token was GCed if expired and not if it wasn't.
assert (access_valid_token(get_full_token_string(token)) is None) == is_expired
def test_access_token(initialized_db):
user = model.user.get_user('devtable')
# Create a token.
token = create_token(user, 'Some token')
assert token.last_accessed is None
# Lookup the token.
token = access_valid_token(get_full_token_string(token))
assert token.last_accessed is not None
# Revoke the token.
revoke_token(token)
# Ensure it cannot be accessed
assert access_valid_token(get_full_token_string(token)) is None
def test_expiring_soon(initialized_db):
user = model.user.get_user('devtable')
# Create some tokens.
create_token(user, 'Some token')
exp_token = create_token(user, 'Some expiring token', datetime.now() + convert_to_timedelta('1d'))
create_token(user, 'Some other token', expiration=datetime.now() + convert_to_timedelta('2d'))
# Get the token expiring soon.
expiring_soon = get_expiring_tokens(user, convert_to_timedelta('25h'))
assert expiring_soon
assert len(expiring_soon) == 1
assert expiring_soon[0].id == exp_token.id
expiring_soon = get_expiring_tokens(user, convert_to_timedelta('49h'))
assert expiring_soon
assert len(expiring_soon) == 2
@pytest.fixture(scope='function')
def app_config():
with patch.dict(_config.app_config, {}, clear=True):
yield _config.app_config
@pytest.mark.parametrize('expiration', [
(None),
('10m'),
('10d'),
('10w'),
])
@pytest.mark.parametrize('default_expiration', [
(None),
('10m'),
('10d'),
('10w'),
])
def test_create_access_token(expiration, default_expiration, initialized_db, app_config):
user = model.user.get_user('devtable')
expiration_date = datetime.now() + convert_to_timedelta(expiration) if expiration else None
with patch.dict(_config.app_config, {}, clear=True):
app_config['APP_SPECIFIC_TOKEN_EXPIRATION'] = default_expiration
if expiration:
exp_token = create_token(user, 'Some token', expiration=expiration_date)
assert exp_token.expiration == expiration_date
else:
exp_token = create_token(user, 'Some token')
assert (exp_token.expiration is None) == (default_expiration is None)
@pytest.mark.parametrize('invalid_token', [
'',
'foo',
'a' * 40,
'b' * 40,
'%s%s' % ('b' * 40, 'a' * 40),
'%s%s' % ('a' * 39, 'b' * 40),
'%s%s' % ('a' * 40, 'b' * 39),
'%s%s' % ('a' * 40, 'b' * 41),
])
def test_invalid_access_token(invalid_token, initialized_db):
user = model.user.get_user('devtable')
token = access_valid_token(invalid_token)
assert token is None


@ -0,0 +1,107 @@
import pytest
from peewee import JOIN
from playhouse.test_utils import assert_query_count
from data.database import Repository, RepositoryPermission, TeamMember, Namespace
from data.model._basequery import filter_to_repos_for_user
from data.model.organization import get_admin_users
from data.model.user import get_namespace_user
from util.names import parse_robot_username
from test.fixtures import *
def _is_team_member(team, user):
return user.id in [member.user_id for member in
TeamMember.select().where(TeamMember.team == team)]
def _get_visible_repositories_for_user(user, repo_kind='image', include_public=False,
namespace=None):
""" Returns all repositories directly visible to the given user, by either repo permission,
or the user being the admin of a namespace.
"""
for repo in Repository.select():
if repo_kind is not None and repo.kind.name != repo_kind:
continue
if namespace is not None and repo.namespace_user.username != namespace:
continue
if include_public and repo.visibility.name == 'public':
yield repo
continue
# Direct repo permission.
try:
RepositoryPermission.get(repository=repo, user=user)
yield repo
continue
except RepositoryPermission.DoesNotExist:
pass
# Team permission.
found_in_team = False
for perm in RepositoryPermission.select().where(RepositoryPermission.repository == repo):
if perm.team and _is_team_member(perm.team, user):
found_in_team = True
break
if found_in_team:
yield repo
continue
# Org namespace admin permission.
if user in get_admin_users(repo.namespace_user):
yield repo
continue
@pytest.mark.parametrize('username', [
'devtable',
'devtable+dtrobot',
'public',
'reader',
])
@pytest.mark.parametrize('include_public', [
True,
False
])
@pytest.mark.parametrize('filter_to_namespace', [
True,
False
])
@pytest.mark.parametrize('repo_kind', [
None,
'image',
'application',
])
def test_filter_repositories(username, include_public, filter_to_namespace, repo_kind,
initialized_db):
namespace = username if filter_to_namespace else None
if '+' in username and filter_to_namespace:
namespace, _ = parse_robot_username(username)
user = get_namespace_user(username)
query = (Repository
.select()
.distinct()
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.switch(Repository)
.join(RepositoryPermission, JOIN.LEFT_OUTER))
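# The base query joins Namespace and left-outer-joins RepositoryPermission so that
# filter_to_repos_for_user can apply all visibility checks in a single query.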
# Prime the cache.
Repository.kind.get_id('image')
with assert_query_count(1):
found = list(filter_to_repos_for_user(query, user.id,
namespace=namespace,
include_public=include_public,
repo_kind=repo_kind))
expected = list(_get_visible_repositories_for_user(user,
repo_kind=repo_kind,
namespace=namespace,
include_public=include_public))
assert len(found) == len(expected)
assert {r.id for r in found} == {r.id for r in expected}


@ -0,0 +1,107 @@
import pytest
from mock import patch
from data import model
from data.database import BUILD_PHASE, RepositoryBuildTrigger, RepositoryBuild
from data.model.build import (update_trigger_disable_status, create_repository_build,
get_repository_build, update_phase_then_close)
from test.fixtures import *
TEST_FAIL_THRESHOLD = 5
TEST_INTERNAL_ERROR_THRESHOLD = 2
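# Small thresholds, patched into app_config below, so the trigger-disable logic fires after only
# a few recorded failures.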
@pytest.mark.parametrize('starting_failure_count, starting_error_count, status, expected_reason', [
(0, 0, BUILD_PHASE.COMPLETE, None),
(10, 10, BUILD_PHASE.COMPLETE, None),
(TEST_FAIL_THRESHOLD - 1, TEST_INTERNAL_ERROR_THRESHOLD - 1, BUILD_PHASE.COMPLETE, None),
(TEST_FAIL_THRESHOLD - 1, 0, BUILD_PHASE.ERROR, 'successive_build_failures'),
(0, TEST_INTERNAL_ERROR_THRESHOLD - 1, BUILD_PHASE.INTERNAL_ERROR,
'successive_build_internal_errors'),
])
def test_update_trigger_disable_status(starting_failure_count, starting_error_count, status,
expected_reason, initialized_db):
test_config = {
'SUCCESSIVE_TRIGGER_FAILURE_DISABLE_THRESHOLD': TEST_FAIL_THRESHOLD,
'SUCCESSIVE_TRIGGER_INTERNAL_ERROR_DISABLE_THRESHOLD': TEST_INTERNAL_ERROR_THRESHOLD,
}
trigger = model.build.list_build_triggers('devtable', 'building')[0]
trigger.successive_failure_count = starting_failure_count
trigger.successive_internal_error_count = starting_error_count
trigger.enabled = True
trigger.save()
with patch('data.model.config.app_config', test_config):
update_trigger_disable_status(trigger, status)
updated_trigger = RepositoryBuildTrigger.get(uuid=trigger.uuid)
assert updated_trigger.enabled == (expected_reason is None)
if expected_reason is not None:
assert updated_trigger.disabled_reason.name == expected_reason
else:
assert updated_trigger.disabled_reason is None
assert updated_trigger.successive_failure_count == 0
assert updated_trigger.successive_internal_error_count == 0
def test_archivable_build_logs(initialized_db):
# Make sure there are no archivable logs.
result = model.build.get_archivable_build()
assert result is None
# Add a build that cannot (yet) be archived.
repo = model.repository.get_repository('devtable', 'simple')
token = model.token.create_access_token(repo, 'write')
created = RepositoryBuild.create(repository=repo, access_token=token,
phase=model.build.BUILD_PHASE.WAITING,
logs_archived=False, job_config='{}',
display_name='')
# Make sure there are no archivable logs.
result = model.build.get_archivable_build()
assert result is None
# Change the build to being complete.
created.phase = model.build.BUILD_PHASE.COMPLETE
created.save()
# Make sure we now find an archivable build.
result = model.build.get_archivable_build()
assert result.id == created.id
def test_update_build_phase(initialized_db):
build = create_build(model.repository.get_repository("devtable", "building"))
repo_build = get_repository_build(build.uuid)
assert repo_build.phase == BUILD_PHASE.WAITING
assert update_phase_then_close(build.uuid, BUILD_PHASE.COMPLETE)
repo_build = get_repository_build(build.uuid)
assert repo_build.phase == BUILD_PHASE.COMPLETE
repo_build.delete_instance()
assert not update_phase_then_close(repo_build.uuid, BUILD_PHASE.PULLING)
def create_build(repository):
new_token = model.token.create_access_token(repository, 'write', 'build-worker')
repo = 'ci.devtable.com:5000/%s/%s' % (repository.namespace_user.username, repository.name)
job_config = {
'repository': repo,
'docker_tags': ['latest'],
'build_subdir': '',
'trigger_metadata': {
'commit': '3482adc5822c498e8f7db2e361e8d57b3d77ddd9',
'ref': 'refs/heads/master',
'default_branch': 'master'
}
}
build = create_repository_build(repository, new_token, job_config,
'68daeebd-a5b9-457f-80a0-4363b882f8ea',
"build_name")
build.save()
return build

725
data/model/test/test_gc.py Normal file

@ -0,0 +1,725 @@
import hashlib
import pytest
from datetime import datetime, timedelta
from mock import patch
from app import storage, docker_v2_signing_key
from contextlib import contextmanager
from playhouse.test_utils import assert_query_count
from freezegun import freeze_time
from data import model, database
from data.database import (Image, ImageStorage, DerivedStorageForImage, Label, TagManifestLabel,
ApprBlob, Manifest, TagManifestToManifest, ManifestBlob, Tag,
TagToRepositoryTag)
from data.model.oci.test.test_oci_manifest import create_manifest_for_testing
from image.docker.schema1 import DockerSchema1ManifestBuilder
from image.docker.schema2.manifest import DockerSchema2ManifestBuilder
from image.docker.schemas import parse_manifest_from_bytes
from util.bytes import Bytes
from test.fixtures import *
ADMIN_ACCESS_USER = 'devtable'
PUBLIC_USER = 'public'
REPO = 'somerepo'
def _set_tag_expiration_policy(namespace, expiration_s):
namespace_user = model.user.get_user(namespace)
model.user.change_user_tag_expiration(namespace_user, expiration_s)
@pytest.fixture()
def default_tag_policy(initialized_db):
_set_tag_expiration_policy(ADMIN_ACCESS_USER, 0)
_set_tag_expiration_policy(PUBLIC_USER, 0)
def create_image(docker_image_id, repository_obj, username):
preferred = storage.preferred_locations[0]
image = model.image.find_create_or_link_image(docker_image_id, repository_obj, username, {},
preferred)
image.storage.uploading = False
image.storage.save()
# Create derived images as well.
model.image.find_or_create_derived_storage(image, 'squash', preferred)
model.image.find_or_create_derived_storage(image, 'aci', preferred)
# Add some torrent info.
try:
database.TorrentInfo.get(storage=image.storage)
except database.TorrentInfo.DoesNotExist:
model.storage.save_torrent_info(image.storage, 1, 'helloworld')
# Add some additional placements to the image.
for location_name in ['local_eu']:
location = database.ImageStorageLocation.get(name=location_name)
try:
database.ImageStoragePlacement.get(location=location, storage=image.storage)
except database.ImageStoragePlacement.DoesNotExist:
database.ImageStoragePlacement.create(location=location, storage=image.storage)
return image.storage
def store_tag_manifest(namespace, repo_name, tag_name, image_id):
builder = DockerSchema1ManifestBuilder(namespace, repo_name, tag_name)
storage_id_map = {}
try:
image_storage = ImageStorage.select().where(~(ImageStorage.content_checksum >> None)).get()
builder.add_layer(image_storage.content_checksum, '{"id": "foo"}')
storage_id_map[image_storage.content_checksum] = image_storage.id
except ImageStorage.DoesNotExist:
pass
manifest = builder.build(docker_v2_signing_key)
manifest_row, _ = model.tag.store_tag_manifest_for_testing(namespace, repo_name, tag_name,
manifest, image_id, storage_id_map)
return manifest_row
def create_repository(namespace=ADMIN_ACCESS_USER, name=REPO, **kwargs):
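# Each keyword argument maps a tag name to an ordered list of docker image IDs; the images are
# chained parent -> child in list order and the tag points at the last one.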
user = model.user.get_user(namespace)
repo = model.repository.create_repository(namespace, name, user)
# Populate the repository with the tags.
image_map = {}
for tag_name in kwargs:
image_ids = kwargs[tag_name]
parent = None
for image_id in image_ids:
if image_id not in image_map:
image_map[image_id] = create_image(image_id, repo, namespace)
v1_metadata = {
'id': image_id,
}
if parent is not None:
v1_metadata['parent'] = parent.docker_image_id
# Set the ancestors for the image.
parent = model.image.set_image_metadata(image_id, namespace, name, '', '', '', v1_metadata,
parent=parent)
# Set the tag for the image.
tag_manifest = store_tag_manifest(namespace, name, tag_name, image_ids[-1])
# Add some labels to the tag.
model.label.create_manifest_label(tag_manifest, 'foo', 'bar', 'manifest')
model.label.create_manifest_label(tag_manifest, 'meh', 'grah', 'manifest')
return repo
def gc_now(repository):
assert model.gc.garbage_collect_repo(repository)
def delete_tag(repository, tag, perform_gc=True, expect_gc=True):
model.tag.delete_tag(repository.namespace_user.username, repository.name, tag)
if perform_gc:
assert model.gc.garbage_collect_repo(repository) == expect_gc
def move_tag(repository, tag, docker_image_id, expect_gc=True):
model.tag.create_or_update_tag(repository.namespace_user.username, repository.name, tag,
docker_image_id)
assert model.gc.garbage_collect_repo(repository) == expect_gc
def assert_not_deleted(repository, *args):
for docker_image_id in args:
assert model.image.get_image_by_id(repository.namespace_user.username, repository.name,
docker_image_id)
def assert_deleted(repository, *args):
for docker_image_id in args:
try:
# Verify the image is missing when accessed by the repository.
model.image.get_image_by_id(repository.namespace_user.username, repository.name,
docker_image_id)
except model.DataModelException:
continue
assert False, 'Expected image %s to be deleted' % docker_image_id
def _get_dangling_storage_count():
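# A storage row is considered dangling when no Image, ManifestBlob or DerivedStorageForImage row
# references it.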
storage_ids = set([current.id for current in ImageStorage.select()])
referenced_by_image = set([image.storage_id for image in Image.select()])
referenced_by_manifest = set([blob.blob_id for blob in ManifestBlob.select()])
referenced_by_derived = set([derived.derivative_id
for derived in DerivedStorageForImage.select()])
return len(storage_ids - referenced_by_image - referenced_by_derived - referenced_by_manifest)
def _get_dangling_label_count():
return len(_get_dangling_labels())
def _get_dangling_labels():
label_ids = set([current.id for current in Label.select()])
referenced_by_manifest = set([mlabel.label_id for mlabel in TagManifestLabel.select()])
return label_ids - referenced_by_manifest
def _get_dangling_manifest_count():
manifest_ids = set([current.id for current in Manifest.select()])
referenced_by_tag_manifest = set([tmt.manifest_id for tmt in TagManifestToManifest.select()])
return len(manifest_ids - referenced_by_tag_manifest)
@contextmanager
def assert_gc_integrity(expect_storage_removed=True, check_oci_tags=True):
""" Specialized assertion for ensuring that GC cleans up all dangling storages
and labels, invokes the callback for images removed and doesn't invoke the
callback for images *not* removed.
"""
# Add a callback for when images are removed.
removed_image_storages = []
model.config.register_image_cleanup_callback(removed_image_storages.extend)
# Store the number of dangling storages and labels.
existing_storage_count = _get_dangling_storage_count()
existing_label_count = _get_dangling_label_count()
existing_manifest_count = _get_dangling_manifest_count()
yield
# Ensure the number of dangling storages, manifests and labels has not changed.
updated_storage_count = _get_dangling_storage_count()
assert updated_storage_count == existing_storage_count
updated_label_count = _get_dangling_label_count()
assert updated_label_count == existing_label_count, _get_dangling_labels()
updated_manifest_count = _get_dangling_manifest_count()
assert updated_manifest_count == existing_manifest_count
# Ensure that for each call to the image+storage cleanup callback, the image and its
# storage is not found *anywhere* in the database.
for removed_image_and_storage in removed_image_storages:
with pytest.raises(Image.DoesNotExist):
Image.get(id=removed_image_and_storage.id)
# Ensure that image storages are only removed if not shared.
shared = Image.select().where(Image.storage == removed_image_and_storage.storage_id).count()
if shared == 0:
shared = (ManifestBlob
.select()
.where(ManifestBlob.blob == removed_image_and_storage.storage_id)
.count())
if shared == 0:
with pytest.raises(ImageStorage.DoesNotExist):
ImageStorage.get(id=removed_image_and_storage.storage_id)
with pytest.raises(ImageStorage.DoesNotExist):
ImageStorage.get(uuid=removed_image_and_storage.storage.uuid)
# Ensure all CAS storage is in the storage engine.
preferred = storage.preferred_locations[0]
for storage_row in ImageStorage.select():
if storage_row.cas_path:
storage.get_content({preferred}, storage.blob_path(storage_row.content_checksum))
for blob_row in ApprBlob.select():
storage.get_content({preferred}, storage.blob_path(blob_row.digest))
# Ensure there are no dangling OCI tags.
if check_oci_tags:
oci_tags = {t.id for t in Tag.select()}
referenced_oci_tags = {t.tag_id for t in TagToRepositoryTag.select()}
assert not oci_tags - referenced_oci_tags
# Ensure all tags have valid manifests.
for manifest in {t.manifest for t in Tag.select()}:
# Ensure that the manifest's blobs all exist.
found_blobs = {b.blob.content_checksum
for b in ManifestBlob.select().where(ManifestBlob.manifest == manifest)}
parsed = parse_manifest_from_bytes(Bytes.for_string_or_unicode(manifest.manifest_bytes),
manifest.media_type.name)
assert set(parsed.local_blob_digests) == found_blobs
def test_has_garbage(default_tag_policy, initialized_db):
""" Remove all existing repositories, then add one without garbage, check, then add one with
garbage, and check again.
"""
# Delete all existing repos.
for repo in database.Repository.select().order_by(database.Repository.id):
assert model.gc.purge_repository(repo.namespace_user.username, repo.name)
# Change the time machine expiration on the namespace.
(database.User
.update(removed_tag_expiration_s=1000000000)
.where(database.User.username == ADMIN_ACCESS_USER)
.execute())
# Create a repository without any garbage.
repository = create_repository(latest=['i1', 'i2', 'i3'])
# Ensure that no repositories are returned by the has garbage check.
assert model.repository.find_repository_with_garbage(1000000000) is None
# Delete a tag.
delete_tag(repository, 'latest', perform_gc=False)
# There should still not be any repositories with garbage, due to time machine.
assert model.repository.find_repository_with_garbage(1000000000) is None
# Change the time machine expiration on the namespace.
(database.User
.update(removed_tag_expiration_s=0)
.where(database.User.username == ADMIN_ACCESS_USER)
.execute())
# Now we should find the repository for GC.
repository = model.repository.find_repository_with_garbage(0)
assert repository is not None
assert repository.name == REPO
# GC the repository.
assert model.gc.garbage_collect_repo(repository)
# There should now be no repositories with garbage.
assert model.repository.find_repository_with_garbage(0) is None
def test_find_garbage_policy_functions(default_tag_policy, initialized_db):
with assert_query_count(1):
one_policy = model.repository.get_random_gc_policy()
all_policies = model.repository._get_gc_expiration_policies()
assert one_policy in all_policies
def test_one_tag(default_tag_policy, initialized_db):
""" Create a repository with a single tag, then remove that tag and verify that the repository
is now empty. """
with assert_gc_integrity():
repository = create_repository(latest=['i1', 'i2', 'i3'])
delete_tag(repository, 'latest')
assert_deleted(repository, 'i1', 'i2', 'i3')
def test_two_tags_unshared_images(default_tag_policy, initialized_db):
""" Repository has two tags with no shared images between them. """
with assert_gc_integrity():
repository = create_repository(latest=['i1', 'i2', 'i3'], other=['f1', 'f2'])
delete_tag(repository, 'latest')
assert_deleted(repository, 'i1', 'i2', 'i3')
assert_not_deleted(repository, 'f1', 'f2')
def test_two_tags_shared_images(default_tag_policy, initialized_db):
""" Repository has two tags with shared images. Deleting the tag should only remove the
unshared images.
"""
with assert_gc_integrity():
repository = create_repository(latest=['i1', 'i2', 'i3'], other=['i1', 'f1'])
delete_tag(repository, 'latest')
assert_deleted(repository, 'i2', 'i3')
assert_not_deleted(repository, 'i1', 'f1')
def test_unrelated_repositories(default_tag_policy, initialized_db):
""" Two repositories with different images. Removing the tag from one leaves the other's
images intact.
"""
with assert_gc_integrity():
repository1 = create_repository(latest=['i1', 'i2', 'i3'], name='repo1')
repository2 = create_repository(latest=['j1', 'j2', 'j3'], name='repo2')
delete_tag(repository1, 'latest')
assert_deleted(repository1, 'i1', 'i2', 'i3')
assert_not_deleted(repository2, 'j1', 'j2', 'j3')
def test_related_repositories(default_tag_policy, initialized_db):
""" Two repositories with shared images. Removing the tag from one leaves the other's
images intact.
"""
with assert_gc_integrity():
repository1 = create_repository(latest=['i1', 'i2', 'i3'], name='repo1')
repository2 = create_repository(latest=['i1', 'i2', 'j1'], name='repo2')
delete_tag(repository1, 'latest')
assert_deleted(repository1, 'i3')
assert_not_deleted(repository2, 'i1', 'i2', 'j1')
def test_inaccessible_repositories(default_tag_policy, initialized_db):
""" Two repositories under different namespaces should result in the images being deleted
but not completely removed from the database.
"""
with assert_gc_integrity():
repository1 = create_repository(namespace=ADMIN_ACCESS_USER, latest=['i1', 'i2', 'i3'])
repository2 = create_repository(namespace=PUBLIC_USER, latest=['i1', 'i2', 'i3'])
delete_tag(repository1, 'latest')
assert_deleted(repository1, 'i1', 'i2', 'i3')
assert_not_deleted(repository2, 'i1', 'i2', 'i3')
def test_many_multiple_shared_images(default_tag_policy, initialized_db):
""" Repository has multiple tags with shared images. Delete all but one tag.
"""
with assert_gc_integrity():
repository = create_repository(latest=['i1', 'i2', 'i3', 'i4', 'i5', 'i6', 'i7', 'i8', 'j0'],
master=['i1', 'i2', 'i3', 'i4', 'i5', 'i6', 'i7', 'i8', 'j1'])
# Delete tag latest. Should only delete j0, since it is not shared.
delete_tag(repository, 'latest')
assert_deleted(repository, 'j0')
assert_not_deleted(repository, 'i1', 'i2', 'i3', 'i4', 'i5', 'i6', 'i7', 'i8', 'j1')
# Delete tag master. Should delete the rest of the images.
delete_tag(repository, 'master')
assert_deleted(repository, 'i1', 'i2', 'i3', 'i4', 'i5', 'i6', 'i7', 'i8', 'j1')
def test_multiple_shared_images(default_tag_policy, initialized_db):
""" Repository has multiple tags with shared images. Selectively deleting the tags, and
verifying at each step.
"""
with assert_gc_integrity():
repository = create_repository(latest=['i1', 'i2', 'i3'], other=['i1', 'f1', 'f2'],
third=['t1', 't2', 't3'], fourth=['i1', 'f1'])
# Current state:
# latest -> i3->i2->i1
# other -> f2->f1->i1
# third -> t3->t2->t1
# fourth -> f1->i1
# Delete tag other. Should delete f2, since it is not shared.
delete_tag(repository, 'other')
assert_deleted(repository, 'f2')
assert_not_deleted(repository, 'i1', 'i2', 'i3', 't1', 't2', 't3', 'f1')
# Current state:
# latest -> i3->i2->i1
# third -> t3->t2->t1
# fourth -> f1->i1
# Move tag fourth to i3. This should remove f1 since it is no longer referenced.
move_tag(repository, 'fourth', 'i3')
assert_deleted(repository, 'f1')
assert_not_deleted(repository, 'i1', 'i2', 'i3', 't1', 't2', 't3')
# Current state:
# latest -> i3->i2->i1
# third -> t3->t2->t1
# fourth -> i3->i2->i1
# Delete tag 'latest'. This should do nothing since fourth is on the same branch.
delete_tag(repository, 'latest')
assert_not_deleted(repository, 'i1', 'i2', 'i3', 't1', 't2', 't3')
# Current state:
# third -> t3->t2->t1
# fourth -> i3->i2->i1
# Delete tag 'third'. This should remove t1->t3.
delete_tag(repository, 'third')
assert_deleted(repository, 't1', 't2', 't3')
assert_not_deleted(repository, 'i1', 'i2', 'i3')
# Current state:
# fourth -> i3->i2->i1
# Add tag to i1.
move_tag(repository, 'newtag', 'i1', expect_gc=False)
assert_not_deleted(repository, 'i1', 'i2', 'i3')
# Current state:
# fourth -> i3->i2->i1
# newtag -> i1
# Delete tag 'fourth'. This should remove i2 and i3.
delete_tag(repository, 'fourth')
assert_deleted(repository, 'i2', 'i3')
assert_not_deleted(repository, 'i1')
# Current state:
# newtag -> i1
# Delete tag 'newtag'. This should remove the remaining image.
delete_tag(repository, 'newtag')
assert_deleted(repository, 'i1')
# Current state:
# (Empty)
def test_empty_gc(default_tag_policy, initialized_db):
with assert_gc_integrity(expect_storage_removed=False):
repository = create_repository(latest=['i1', 'i2', 'i3'], other=['i1', 'f1', 'f2'],
third=['t1', 't2', 't3'], fourth=['i1', 'f1'])
assert not model.gc.garbage_collect_repo(repository)
assert_not_deleted(repository, 'i1', 'i2', 'i3', 't1', 't2', 't3', 'f1', 'f2')
def test_time_machine_no_gc(default_tag_policy, initialized_db):
""" Repository has two tags with shared images. Deleting the tag should not remove any images
"""
with assert_gc_integrity(expect_storage_removed=False):
repository = create_repository(latest=['i1', 'i2', 'i3'], other=['i1', 'f1'])
_set_tag_expiration_policy(repository.namespace_user.username, 60*60*24)
delete_tag(repository, 'latest', expect_gc=False)
assert_not_deleted(repository, 'i2', 'i3')
assert_not_deleted(repository, 'i1', 'f1')
def test_time_machine_gc(default_tag_policy, initialized_db):
""" Repository has two tags with shared images. Deleting the second tag should cause the images
for the first deleted tag to gc.
"""
now = datetime.utcnow()
with assert_gc_integrity():
with freeze_time(now):
repository = create_repository(latest=['i1', 'i2', 'i3'], other=['i1', 'f1'])
_set_tag_expiration_policy(repository.namespace_user.username, 1)
delete_tag(repository, 'latest', expect_gc=False)
assert_not_deleted(repository, 'i2', 'i3')
assert_not_deleted(repository, 'i1', 'f1')
with freeze_time(now + timedelta(seconds=2)):
# This will cause the images associated with latest to gc
delete_tag(repository, 'other')
assert_deleted(repository, 'i2', 'i3')
assert_not_deleted(repository, 'i1', 'f1')
def test_images_shared_storage(default_tag_policy, initialized_db):
""" Repository with two tags, both with the same shared storage. Deleting the first
tag should delete the first image, but *not* its storage.
"""
with assert_gc_integrity(expect_storage_removed=False):
repository = create_repository()
# Add two tags, each with their own image, but with the same storage.
image_storage = model.storage.create_v1_storage(storage.preferred_locations[0])
first_image = Image.create(docker_image_id='i1',
repository=repository, storage=image_storage,
ancestors='/')
second_image = Image.create(docker_image_id='i2',
repository=repository, storage=image_storage,
ancestors='/')
store_tag_manifest(repository.namespace_user.username, repository.name,
'first', first_image.docker_image_id)
store_tag_manifest(repository.namespace_user.username, repository.name,
'second', second_image.docker_image_id)
# Delete the first tag.
delete_tag(repository, 'first')
assert_deleted(repository, 'i1')
assert_not_deleted(repository, 'i2')
def test_image_with_cas(default_tag_policy, initialized_db):
""" A repository with a tag pointing to an image backed by CAS. Deleting and GCing the tag
should result in the storage and its CAS data being removed.
"""
with assert_gc_integrity(expect_storage_removed=True):
repository = create_repository()
# Create an image storage record under CAS.
content = 'hello world'
digest = 'sha256:' + hashlib.sha256(content).hexdigest()
preferred = storage.preferred_locations[0]
storage.put_content({preferred}, storage.blob_path(digest), content)
image_storage = database.ImageStorage.create(content_checksum=digest, uploading=False)
location = database.ImageStorageLocation.get(name=preferred)
database.ImageStoragePlacement.create(location=location, storage=image_storage)
# Ensure the CAS path exists.
assert storage.exists({preferred}, storage.blob_path(digest))
# Create the image and the tag.
first_image = Image.create(docker_image_id='i1',
repository=repository, storage=image_storage,
ancestors='/')
store_tag_manifest(repository.namespace_user.username, repository.name,
'first', first_image.docker_image_id)
assert_not_deleted(repository, 'i1')
# Delete the tag.
delete_tag(repository, 'first')
assert_deleted(repository, 'i1')
# Ensure the CAS path is gone.
assert not storage.exists({preferred}, storage.blob_path(digest))
def test_images_shared_cas(default_tag_policy, initialized_db):
""" A repository, each two tags, pointing to the same image, which has image storage
with the same *CAS path*, but *distinct records*. Deleting the first tag should delete the
first image, and its storage, but not the file in storage, as it shares its CAS path.
"""
with assert_gc_integrity(expect_storage_removed=True):
repository = create_repository()
# Create two image storage records with the same content checksum.
content = 'hello world'
digest = 'sha256:' + hashlib.sha256(content).hexdigest()
preferred = storage.preferred_locations[0]
storage.put_content({preferred}, storage.blob_path(digest), content)
is1 = database.ImageStorage.create(content_checksum=digest, uploading=False)
is2 = database.ImageStorage.create(content_checksum=digest, uploading=False)
location = database.ImageStorageLocation.get(name=preferred)
database.ImageStoragePlacement.create(location=location, storage=is1)
database.ImageStoragePlacement.create(location=location, storage=is2)
# Ensure the CAS path exists.
assert storage.exists({preferred}, storage.blob_path(digest))
# Create two images in the repository, and two tags, each pointing to one of the storages.
first_image = Image.create(docker_image_id='i1',
repository=repository, storage=is1,
ancestors='/')
second_image = Image.create(docker_image_id='i2',
repository=repository, storage=is2,
ancestors='/')
store_tag_manifest(repository.namespace_user.username, repository.name,
'first', first_image.docker_image_id)
store_tag_manifest(repository.namespace_user.username, repository.name,
'second', second_image.docker_image_id)
assert_not_deleted(repository, 'i1', 'i2')
# Delete the first tag.
delete_tag(repository, 'first')
assert_deleted(repository, 'i1')
assert_not_deleted(repository, 'i2')
# Ensure the CAS path still exists.
assert storage.exists({preferred}, storage.blob_path(digest))
def test_images_shared_cas_with_new_blob_table(default_tag_policy, initialized_db):
""" A repository with a tag and image that shares its CAS path with a record in the new Blob
table. Deleting the first tag should delete the first image, and its storage, but not the
file in storage, as it shares its CAS path with the blob row.
"""
with assert_gc_integrity(expect_storage_removed=True):
repository = create_repository()
# Create two image storage records with the same content checksum.
content = 'hello world'
digest = 'sha256:' + hashlib.sha256(content).hexdigest()
preferred = storage.preferred_locations[0]
storage.put_content({preferred}, storage.blob_path(digest), content)
media_type = database.MediaType.get(name='text/plain')
is1 = database.ImageStorage.create(content_checksum=digest, uploading=False)
database.ApprBlob.create(digest=digest, size=0, media_type=media_type)
location = database.ImageStorageLocation.get(name=preferred)
database.ImageStoragePlacement.create(location=location, storage=is1)
# Ensure the CAS path exists.
assert storage.exists({preferred}, storage.blob_path(digest))
# Create the image in the repository, and the tag.
first_image = Image.create(docker_image_id='i1',
repository=repository, storage=is1,
ancestors='/')
store_tag_manifest(repository.namespace_user.username, repository.name,
'first', first_image.docker_image_id)
assert_not_deleted(repository, 'i1')
# Delete the tag.
delete_tag(repository, 'first')
assert_deleted(repository, 'i1')
# Ensure the CAS path still exists, as it is referenced by the Blob table
assert storage.exists({preferred}, storage.blob_path(digest))
def test_purge_repo(app):
""" Test that app registers delete_metadata function on repository deletions """
with assert_gc_integrity():
with patch('app.tuf_metadata_api') as mock_tuf:
model.gc.purge_repository("ns", "repo")
mock_tuf.delete_metadata.assert_called_with("ns", "repo")
def test_super_long_image_chain_gc(app, default_tag_policy):
""" Test that a super long chain of images all gets properly GCed. """
with assert_gc_integrity():
images = ['i%s' % i for i in range(0, 100)]
repository = create_repository(latest=images)
delete_tag(repository, 'latest')
# Ensure the repository is now empty.
assert_deleted(repository, *images)
def test_manifest_v2_shared_config_and_blobs(app, default_tag_policy):
""" Test that GCing a tag that refers to a V2 manifest with the same config and some shared
blobs as another manifest ensures that the config blob and shared blob are NOT GCed.
"""
repo = model.repository.create_repository('devtable', 'newrepo', None)
manifest1, built1 = create_manifest_for_testing(repo, differentiation_field='1',
include_shared_blob=True)
manifest2, built2 = create_manifest_for_testing(repo, differentiation_field='2',
include_shared_blob=True)
assert set(built1.local_blob_digests).intersection(built2.local_blob_digests)
assert built1.config.digest == built2.config.digest
# Create tags pointing to the manifests.
model.oci.tag.retarget_tag('tag1', manifest1)
model.oci.tag.retarget_tag('tag2', manifest2)
with assert_gc_integrity(expect_storage_removed=True, check_oci_tags=False):
# Delete tag2.
model.oci.tag.delete_tag(repo, 'tag2')
assert model.gc.garbage_collect_repo(repo)
# Ensure the blobs for manifest1 still all exist.
preferred = storage.preferred_locations[0]
for blob_digest in built1.local_blob_digests:
storage_row = ImageStorage.get(content_checksum=blob_digest)
assert storage_row.cas_path
storage.get_content({preferred}, storage.blob_path(storage_row.content_checksum))


@ -0,0 +1,104 @@
import pytest
from collections import defaultdict
from data.model import image, repository
from playhouse.test_utils import assert_query_count
from test.fixtures import *
@pytest.fixture()
def images(initialized_db):
images = image.get_repository_images('devtable', 'simple')
assert len(images)
return images
def test_get_image_with_storage(images, initialized_db):
for current in images:
storage_uuid = current.storage.uuid
with assert_query_count(1):
retrieved = image.get_image_with_storage(current.docker_image_id, storage_uuid)
assert retrieved.id == current.id
assert retrieved.storage.uuid == storage_uuid
def test_get_parent_images(images, initialized_db):
for current in images:
if not len(current.ancestor_id_list()):
continue
with assert_query_count(1):
parent_images = list(image.get_parent_images('devtable', 'simple', current))
assert len(parent_images) == len(current.ancestor_id_list())
assert set(current.ancestor_id_list()) == {i.id for i in parent_images}
for parent in parent_images:
with assert_query_count(0):
assert parent.storage.id
def test_get_image(images, initialized_db):
for current in images:
repo = current.repository
with assert_query_count(1):
found = image.get_image(repo, current.docker_image_id)
assert found.id == current.id
def test_placements(images, initialized_db):
with assert_query_count(1):
placements_map = image.get_placements_for_images(images)
for current in images:
assert current.storage.id in placements_map
with assert_query_count(2):
expected_image, expected_placements = image.get_image_and_placements('devtable', 'simple',
current.docker_image_id)
assert expected_image.id == current.id
assert len(expected_placements) == len(placements_map.get(current.storage.id))
assert ({p.id for p in expected_placements} ==
{p.id for p in placements_map.get(current.storage.id)})
def test_get_repo_image(images, initialized_db):
for current in images:
with assert_query_count(1):
found = image.get_repo_image('devtable', 'simple', current.docker_image_id)
assert found.id == current.id
with assert_query_count(1):
assert found.storage.id
def test_get_repo_image_and_storage(images, initialized_db):
for current in images:
with assert_query_count(1):
found = image.get_repo_image_and_storage('devtable', 'simple', current.docker_image_id)
assert found.id == current.id
with assert_query_count(0):
assert found.storage.id
def test_get_repository_images_without_placements(images, initialized_db):
ancestors_map = defaultdict(list)
for img in images:
current = img.parent
while current is not None:
ancestors_map[current.id].append(img.id)
current = current.parent
for current in images:
repo = current.repository
with assert_query_count(1):
found = list(image.get_repository_images_without_placements(repo, with_ancestor=current))
assert len(found) == len(ancestors_map[current.id]) + 1
assert {i.id for i in found} == set(ancestors_map[current.id] + [current.id])


@ -0,0 +1,215 @@
import pytest
from data import model
from storage.distributedstorage import DistributedStorage
from storage.fakestorage import FakeStorage
from test.fixtures import *
NO_ACCESS_USER = 'freshuser'
READ_ACCESS_USER = 'reader'
ADMIN_ACCESS_USER = 'devtable'
PUBLIC_USER = 'public'
RANDOM_USER = 'randomuser'
OUTSIDE_ORG_USER = 'outsideorg'
ADMIN_ROBOT_USER = 'devtable+dtrobot'
ORGANIZATION = 'buynlarge'
REPO = 'devtable/simple'
PUBLIC_REPO = 'public/publicrepo'
RANDOM_REPO = 'randomuser/randomrepo'
OUTSIDE_ORG_REPO = 'outsideorg/coolrepo'
ORG_REPO = 'buynlarge/orgrepo'
ANOTHER_ORG_REPO = 'buynlarge/anotherorgrepo'
# Note: The shared repo has devtable as admin, public as a writer and reader as a reader.
SHARED_REPO = 'devtable/shared'
@pytest.fixture()
def storage(app):
return DistributedStorage({'local_us': FakeStorage(None)}, preferred_locations=['local_us'])
def createStorage(storage, docker_image_id, repository=REPO, username=ADMIN_ACCESS_USER):
repository_obj = model.repository.get_repository(repository.split('/')[0],
repository.split('/')[1])
preferred = storage.preferred_locations[0]
image = model.image.find_create_or_link_image(docker_image_id, repository_obj, username, {},
preferred)
image.storage.uploading = False
image.storage.save()
return image.storage
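# The helpers below create a second reference to the given docker image ID and assert whether
# find_create_or_link_image reused the existing storage row (shared) or created a new one.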
def assertSameStorage(storage, docker_image_id, existing_storage, repository=REPO,
username=ADMIN_ACCESS_USER):
new_storage = createStorage(storage, docker_image_id, repository, username)
assert existing_storage.id == new_storage.id
def assertDifferentStorage(storage, docker_image_id, existing_storage, repository=REPO,
username=ADMIN_ACCESS_USER):
new_storage = createStorage(storage, docker_image_id, repository, username)
assert existing_storage.id != new_storage.id
def test_same_user(storage, initialized_db):
""" The same user creates two images, each which should be shared in the same repo. This is a
sanity check. """
# Create a reference to a new docker ID => new image.
first_storage_id = createStorage(storage, 'first-image')
# Create a reference to the same docker ID => same image.
assertSameStorage(storage, 'first-image', first_storage_id)
# Create a reference to another new docker ID => new image.
second_storage_id = createStorage(storage, 'second-image')
# Create a reference to that same docker ID => same image.
assertSameStorage(storage, 'second-image', second_storage_id)
# Make sure the images are different.
assert first_storage_id != second_storage_id
def test_no_user_private_repo(storage, initialized_db):
""" If no user is specified (token case usually), then no sharing can occur on a private repo. """
# Create a reference to a new docker ID => new image.
first_storage = createStorage(storage, 'the-image', username=None, repository=SHARED_REPO)
# Create a reference to the same docker ID, but since there is no username => new image.
assertDifferentStorage(storage, 'the-image', first_storage, username=None, repository=RANDOM_REPO)
def test_no_user_public_repo(storage, initialized_db):
""" If no user is specified (token case usually), then no sharing can occur on a private repo except when the image is first public. """
# Create a reference to a new docker ID => new image.
first_storage = createStorage(storage, 'the-image', username=None, repository=PUBLIC_REPO)
# Create a reference to the same docker ID. Since there is no username we would expect a different image, but the first image is public, so => shared image.
assertSameStorage(storage, 'the-image', first_storage, username=None, repository=RANDOM_REPO)
def test_different_user_same_repo(storage, initialized_db):
""" Two different users create the same image in the same repo. """
# Create a reference to a new docker ID under the first user => new image.
first_storage = createStorage(storage, 'the-image', username=PUBLIC_USER, repository=SHARED_REPO)
# Create a reference to the *same* docker ID under the second user => same image.
assertSameStorage(storage, 'the-image', first_storage, username=ADMIN_ACCESS_USER, repository=SHARED_REPO)
def test_different_repo_no_shared_access(storage, initialized_db):
""" Neither user has access to the other user's repository. """
# Create a reference to a new docker ID under the first user => new image.
first_storage_id = createStorage(storage, 'the-image', username=RANDOM_USER, repository=RANDOM_REPO)
# Create a reference to the *same* docker ID under the second user => new image.
second_storage_id = createStorage(storage, 'the-image', username=ADMIN_ACCESS_USER, repository=REPO)
# Verify that the users do not share storage.
assert first_storage_id != second_storage_id
def test_public_than_private(storage, initialized_db):
""" An image is created publicly then used privately, so it should be shared. """
# Create a reference to a new docker ID under the first user => new image.
first_storage = createStorage(storage, 'the-image', username=PUBLIC_USER, repository=PUBLIC_REPO)
# Create a reference to the *same* docker ID under the second user => same image, since the first was public.
assertSameStorage(storage, 'the-image', first_storage, username=ADMIN_ACCESS_USER, repository=REPO)
def test_private_than_public(storage, initialized_db):
""" An image is created privately then used publicly, so it should *not* be shared. """
# Create a reference to a new docker ID under the first user => new image.
first_storage = createStorage(storage, 'the-image', username=ADMIN_ACCESS_USER, repository=REPO)
# Create a reference to the *same* docker ID under the second user => new image, since the first was private.
assertDifferentStorage(storage, 'the-image', first_storage, username=PUBLIC_USER, repository=PUBLIC_REPO)
def test_different_repo_with_access(storage, initialized_db):
""" An image is created in one repo (SHARED_REPO) which the user (PUBLIC_USER) has access to. Later, the
image is created in another repo (PUBLIC_REPO) that the user also has access to. The image should
be shared since the user has access.
"""
# Create the image in the shared repo => new image.
first_storage = createStorage(storage, 'the-image', username=ADMIN_ACCESS_USER, repository=SHARED_REPO)
# Create the image in the other user's repo, but since the user (PUBLIC) still has access to the shared
# repository, they should reuse the storage.
assertSameStorage(storage, 'the-image', first_storage, username=PUBLIC_USER, repository=PUBLIC_REPO)
def test_org_access(storage, initialized_db):
""" An image is accessible by being a member of the organization. """
# Create the new image under the org's repo => new image.
first_storage = createStorage(storage, 'the-image', username=ADMIN_ACCESS_USER, repository=ORG_REPO)
# Create an image under the user's repo, but since the user has access to the organization => shared image.
assertSameStorage(storage, 'the-image', first_storage, username=ADMIN_ACCESS_USER, repository=REPO)
# Ensure that the user's robot does not have access, since it is not on the permissions list for the repo.
assertDifferentStorage(storage, 'the-image', first_storage, username=ADMIN_ROBOT_USER, repository=SHARED_REPO)
def test_org_access_different_user(storage, initialized_db):
""" An image is accessible by being a member of the organization. """
# Create the new image under the org's repo => new image.
first_storage = createStorage(storage, 'the-image', username=ADMIN_ACCESS_USER, repository=ORG_REPO)
# Create an image under a user's repo, but since the user has access to the organization => shared image.
assertSameStorage(storage, 'the-image', first_storage, username=PUBLIC_USER, repository=PUBLIC_REPO)
# Also verify for reader.
assertSameStorage(storage, 'the-image', first_storage, username=READ_ACCESS_USER, repository=PUBLIC_REPO)
def test_org_no_access(storage, initialized_db):
""" An image is not accessible if not a member of the organization. """
# Create the new image under the org's repo => new image.
first_storage = createStorage(storage, 'the-image', username=ADMIN_ACCESS_USER, repository=ORG_REPO)
# Create an image under a user's repo. Since the user is not a member of the organization => new image.
assertDifferentStorage(storage, 'the-image', first_storage, username=RANDOM_USER, repository=RANDOM_REPO)
def test_org_not_team_member_with_access(storage, initialized_db):
""" An image is accessible to a user specifically listed as having permission on the org repo. """
# Create the new image under the org's repo => new image.
first_storage = createStorage(storage, 'the-image', username=ADMIN_ACCESS_USER, repository=ORG_REPO)
# Create an image under a user's repo. Since the user has read access on that repo, they can see the image => shared image.
assertSameStorage(storage, 'the-image', first_storage, username=OUTSIDE_ORG_USER, repository=OUTSIDE_ORG_REPO)
def test_org_not_team_member_with_no_access(storage, initialized_db):
""" A user that has access to one org repo but not another and is not a team member. """
# Create the new image under the org's repo => new image.
first_storage = createStorage(storage, 'the-image', username=ADMIN_ACCESS_USER, repository=ANOTHER_ORG_REPO)
# Create an image under a user's repo. The user doesn't have access to the repo (ANOTHER_ORG_REPO) so => new image.
assertDifferentStorage(storage, 'the-image', first_storage, username=OUTSIDE_ORG_USER, repository=OUTSIDE_ORG_REPO)
def test_no_link_to_uploading(storage, initialized_db):
still_uploading = createStorage(storage, 'an-image', repository=PUBLIC_REPO)
still_uploading.uploading = True
still_uploading.save()
assertDifferentStorage(storage, 'an-image', still_uploading)


@ -0,0 +1,80 @@
import pytest
from data.database import LogEntry3, User
from data.model import config as _config
from data.model.log import log_action
from mock import patch, Mock, DEFAULT, sentinel
from peewee import PeeweeException
@pytest.fixture(scope='function')
def app_config():
with patch.dict(_config.app_config, {}, clear=True):
yield _config.app_config
@pytest.fixture()
def logentry_kind():
kinds = {'pull_repo': 'pull_repo_kind', 'push_repo': 'push_repo_kind'}
with patch('data.model.log.get_log_entry_kinds', return_value=kinds, spec=True):
yield kinds
@pytest.fixture()
def logentry(logentry_kind):
with patch('data.database.LogEntry3.create', spec=True):
yield LogEntry3
@pytest.fixture()
def user():
with patch.multiple('data.database.User', username=DEFAULT, get=DEFAULT, select=DEFAULT) as user:
user['get'].return_value = Mock(id='mock_user_id')
user['select'].return_value.tuples.return_value.get.return_value = ['default_user_id']
yield User
@pytest.mark.parametrize('action_kind', [('pull'), ('oops')])
def test_log_action_unknown_action(action_kind):
''' test unknown action types throw an exception when logged '''
with pytest.raises(Exception):
log_action(action_kind, None)
@pytest.mark.parametrize('user_or_org_name,account_id,account', [
('my_test_org', 'N/A', 'mock_user_id' ),
(None, 'test_account_id', 'test_account_id'),
(None, None, 'default_user_id')
])
@pytest.mark.parametrize('unlogged_pulls_ok,action_kind,db_exception,throws', [
(False, 'pull_repo', None, False),
(False, 'push_repo', None, False),
(False, 'pull_repo', PeeweeException, True ),
(False, 'push_repo', PeeweeException, True ),
(True, 'pull_repo', PeeweeException, False),
(True, 'push_repo', PeeweeException, True ),
(True, 'pull_repo', Exception, True ),
(True, 'push_repo', Exception, True )
])
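# Per the table above, ALLOW_PULLS_WITHOUT_STRICT_LOGGING only swallows database (PeeweeException)
# errors for pull_repo events; push failures and non-database exceptions still propagate.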
def test_log_action(user_or_org_name, account_id, account, unlogged_pulls_ok, action_kind,
db_exception, throws, app_config, logentry, user):
log_args = {
'performer' : Mock(id='TEST_PERFORMER_ID'),
'repository' : Mock(id='TEST_REPO'),
'ip' : 'TEST_IP',
'metadata' : { 'test_key' : 'test_value' },
'timestamp' : 'TEST_TIMESTAMP'
}
app_config['SERVICE_LOG_ACCOUNT_ID'] = account_id
app_config['ALLOW_PULLS_WITHOUT_STRICT_LOGGING'] = unlogged_pulls_ok
logentry.create.side_effect = db_exception
if throws:
with pytest.raises(db_exception):
log_action(action_kind, user_or_org_name, **log_args)
else:
log_action(action_kind, user_or_org_name, **log_args)
logentry.create.assert_called_once_with(kind=action_kind+'_kind', account=account,
performer='TEST_PERFORMER_ID', repository='TEST_REPO',
ip='TEST_IP', metadata_json='{"test_key": "test_value"}',
datetime='TEST_TIMESTAMP')


@ -0,0 +1,51 @@
from app import storage
from data import model, database
from test.fixtures import *
ADMIN_ACCESS_USER = 'devtable'
REPO = 'simple'
def test_store_blob(initialized_db):
location = database.ImageStorageLocation.select().get()
# Create a new blob at a unique digest.
digest = 'somecooldigest'
blob_storage = model.blob.store_blob_record_and_temp_link(ADMIN_ACCESS_USER, REPO, digest,
location, 1024, 0, 5000)
assert blob_storage.content_checksum == digest
assert blob_storage.image_size == 1024
assert blob_storage.uncompressed_size == 5000
# Link to the same digest.
blob_storage2 = model.blob.store_blob_record_and_temp_link(ADMIN_ACCESS_USER, REPO, digest,
location, 2048, 0, 6000)
assert blob_storage2.id == blob_storage.id
# The sizes should be unchanged.
assert blob_storage2.image_size == 1024
assert blob_storage2.uncompressed_size == 5000
# Add a new digest, ensure it has a new record.
otherdigest = 'anotherdigest'
blob_storage3 = model.blob.store_blob_record_and_temp_link(ADMIN_ACCESS_USER, REPO, otherdigest,
location, 1234, 0, 5678)
assert blob_storage3.id != blob_storage.id
assert blob_storage3.image_size == 1234
assert blob_storage3.uncompressed_size == 5678
def test_get_or_create_shared_blob(initialized_db):
shared = model.blob.get_or_create_shared_blob('sha256:abcdef', 'somecontent', storage)
assert shared.content_checksum == 'sha256:abcdef'
again = model.blob.get_or_create_shared_blob('sha256:abcdef', 'somecontent', storage)
assert shared == again
def test_lookup_repo_storages_by_content_checksum(initialized_db):
for image in database.Image.select():
found = model.storage.lookup_repo_storages_by_content_checksum(image.repository,
[image.storage.content_checksum])
assert len(found) == 1
assert found[0].content_checksum == image.storage.content_checksum


@ -0,0 +1,50 @@
import pytest
from data.database import Role
from data.model.modelutil import paginate
from test.fixtures import *
@pytest.mark.parametrize('page_size', [
10,
20,
50,
100,
200,
500,
1000,
])
@pytest.mark.parametrize('descending', [
False,
True,
])
def test_paginate(page_size, descending, initialized_db):
# Add a bunch of rows into a test table (`Role`).
for i in range(0, 522):
Role.create(name='testrole%s' % i)
query = Role.select().where(Role.name ** 'testrole%')
all_matching_roles = list(query)
assert len(all_matching_roles) == 522
# Paginate a query to lookup roles.
collected = []
page_token = None
while True:
results, page_token = paginate(query, Role, limit=page_size, descending=descending,
page_token=page_token)
assert len(results) <= page_size
collected.extend(results)
if page_token is None:
break
assert len(results) == page_size
for index, result in enumerate(results[1:]):
if descending:
assert result.id < results[index].id
else:
assert result.id > results[index].id
assert len(collected) == len(all_matching_roles)
assert {c.id for c in collected} == {a.id for a in all_matching_roles}


@ -0,0 +1,22 @@
import pytest
from data.model.organization import get_organization, get_organizations
from data.model.user import mark_namespace_for_deletion
from data.queue import WorkQueue
from test.fixtures import *
@pytest.mark.parametrize('deleted', [
(True),
(False),
])
def test_get_organizations(deleted, initialized_db):
# Delete an org.
deleted_org = get_organization('sellnsmall')
queue = WorkQueue('testgcnamespace', lambda db: db.transaction())
mark_namespace_for_deletion(deleted_org, [], queue)
orgs = get_organizations(deleted=deleted)
assert orgs
deleted_found = [org for org in orgs if org.id == deleted_org.id]
assert bool(deleted_found) == deleted


@ -0,0 +1,235 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from jsonschema import ValidationError
from data.database import RepoMirrorConfig, RepoMirrorStatus, User
from data import model
from data.model.repo_mirror import (create_mirroring_rule, get_eligible_mirrors, update_sync_status_to_cancel,
MAX_SYNC_RETRIES, release_mirror, enable_mirroring_for_repository)
from data.model.repository import create_repository
from data.model.user import create_robot, create_user_noverify, lookup_robot
from datetime import datetime, timedelta
import pytest
from test.fixtures import *
def create_mirror_repo_robot(rules, repo_name="repo"):
try:
user = User.get(User.username == "mirror")
except User.DoesNotExist:
user = create_user_noverify("mirror", "mirror@example.com", email_required=False)
try:
robot = lookup_robot("mirror+robot")
except model.InvalidRobotException:
robot, _ = create_robot("robot", user)
repo = create_repository("mirror", repo_name, None, repo_kind="image", visibility="public")
repo.save()
rule = model.repo_mirror.create_mirroring_rule(repo, rules)
mirror_kwargs = {
"repository": repo,
"root_rule": rule,
"internal_robot": robot,
"external_reference": "registry.example.com/namespace/repository",
"sync_interval": timedelta(days=1).total_seconds()
}
mirror = enable_mirroring_for_repository(**mirror_kwargs)
mirror.sync_status = RepoMirrorStatus.NEVER_RUN
mirror.sync_start_date = datetime.utcnow() - timedelta(days=1)
mirror.sync_retries_remaining = 3
mirror.save()
return (mirror, repo)
def disable_existing_mirrors():
mirrors = RepoMirrorConfig.select().execute()
for mirror in mirrors:
mirror.is_enabled = False
mirror.save()
def test_eligible_oldest_first(initialized_db):
"""
Eligible mirror candidates should be returned with the oldest (earliest created) first.
"""
disable_existing_mirrors()
mirror_first, repo_first = create_mirror_repo_robot(["updated", "created"], repo_name="first")
mirror_second, repo_second = create_mirror_repo_robot(["updated", "created"], repo_name="second")
mirror_third, repo_third = create_mirror_repo_robot(["updated", "created"], repo_name="third")
candidates = get_eligible_mirrors()
assert len(candidates) == 3
assert candidates[0] == mirror_first
assert candidates[1] == mirror_second
assert candidates[2] == mirror_third
def test_eligible_includes_expired_syncing(initialized_db):
"""
Mirrors that have an end time in the past are eligible even if their state indicates still syncing.
"""
disable_existing_mirrors()
mirror_first, repo_first = create_mirror_repo_robot(["updated", "created"], repo_name="first")
mirror_second, repo_second = create_mirror_repo_robot(["updated", "created"], repo_name="second")
mirror_third, repo_third = create_mirror_repo_robot(["updated", "created"], repo_name="third")
mirror_fourth, repo_third = create_mirror_repo_robot(["updated", "created"], repo_name="fourth")
mirror_second.sync_expiration_date = datetime.utcnow() - timedelta(hours=1)
mirror_second.sync_status = RepoMirrorStatus.SYNCING
mirror_second.save()
mirror_fourth.sync_expiration_date = datetime.utcnow() + timedelta(hours=1)
mirror_fourth.sync_status = RepoMirrorStatus.SYNCING
mirror_fourth.save()
candidates = get_eligible_mirrors()
assert len(candidates) == 3
assert candidates[0] == mirror_first
assert candidates[1] == mirror_second
assert candidates[2] == mirror_third
def test_eligible_includes_immediate(initialized_db):
"""
Mirrors that are SYNC_NOW, regardless of starting time
"""
disable_existing_mirrors()
mirror_first, repo_first = create_mirror_repo_robot(["updated", "created"], repo_name="first")
mirror_second, repo_second = create_mirror_repo_robot(["updated", "created"], repo_name="second")
mirror_third, repo_third = create_mirror_repo_robot(["updated", "created"], repo_name="third")
mirror_fourth, repo_third = create_mirror_repo_robot(["updated", "created"], repo_name="fourth")
mirror_future, _ = create_mirror_repo_robot(["updated", "created"], repo_name="future")
mirror_past, _ = create_mirror_repo_robot(["updated", "created"], repo_name="past")
mirror_future.sync_start_date = datetime.utcnow() + timedelta(hours=6)
mirror_future.sync_status = RepoMirrorStatus.SYNC_NOW
mirror_future.save()
mirror_past.sync_start_date = datetime.utcnow() - timedelta(hours=6)
mirror_past.sync_status = RepoMirrorStatus.SYNC_NOW
mirror_past.save()
mirror_fourth.sync_expiration_date = datetime.utcnow() + timedelta(hours=1)
mirror_fourth.sync_status = RepoMirrorStatus.SYNCING
mirror_fourth.save()
candidates = get_eligible_mirrors()
assert len(candidates) == 5
assert candidates[0] == mirror_first
assert candidates[1] == mirror_second
assert candidates[2] == mirror_third
assert candidates[3] == mirror_past
assert candidates[4] == mirror_future
def test_create_rule_validations(initialized_db):
mirror, repo = create_mirror_repo_robot(["updated", "created"], repo_name="first")
with pytest.raises(ValidationError):
create_mirroring_rule(repo, None)
with pytest.raises(ValidationError):
create_mirroring_rule(repo, "['tag1', 'tag2']")
with pytest.raises(ValidationError):
create_mirroring_rule(repo, ['tag1', 'tag2'], rule_type=None)
def test_long_registry_passwords(initialized_db):
"""
Verify that long passwords, such as the Base64-encoded JWTs used by Red Hat's registry, work as expected.
"""
MAX_PASSWORD_LENGTH = 1024
username = 'a' * MAX_PASSWORD_LENGTH
password = 'b' * MAX_PASSWORD_LENGTH
assert len(username) == MAX_PASSWORD_LENGTH
assert len(password) == MAX_PASSWORD_LENGTH
repo = model.repository.get_repository('devtable', 'mirrored')
assert repo
existing_mirror_conf = model.repo_mirror.get_mirror(repo)
assert existing_mirror_conf
assert model.repo_mirror.change_credentials(repo, username, password)
updated_mirror_conf = model.repo_mirror.get_mirror(repo)
assert updated_mirror_conf
assert updated_mirror_conf.external_registry_username.decrypt() == username
assert updated_mirror_conf.external_registry_password.decrypt() == password
def test_sync_status_to_cancel(initialized_db):
"""
Mirrors in the SYNCING or SYNC_NOW state may be canceled, ending in the NEVER_RUN state.
"""
disable_existing_mirrors()
mirror, repo = create_mirror_repo_robot(["updated", "created"], repo_name="cancel")
mirror.sync_status = RepoMirrorStatus.SYNCING
mirror.save()
updated = update_sync_status_to_cancel(mirror)
assert updated is not None
assert updated.sync_status == RepoMirrorStatus.NEVER_RUN
mirror.sync_status = RepoMirrorStatus.SYNC_NOW
mirror.save()
updated = update_sync_status_to_cancel(mirror)
assert updated is not None
assert updated.sync_status == RepoMirrorStatus.NEVER_RUN
mirror.sync_status = RepoMirrorStatus.FAIL
mirror.save()
updated = update_sync_status_to_cancel(mirror)
assert updated is None
mirror.sync_status = RepoMirrorStatus.NEVER_RUN
mirror.save()
updated = update_sync_status_to_cancel(mirror)
assert updated is None
mirror.sync_status = RepoMirrorStatus.SUCCESS
mirror.save()
updated = update_sync_status_to_cancel(mirror)
assert updated is None
def test_release_mirror(initialized_db):
"""
Releasing a mirror after a failure decrements its remaining retries; once the retries are exhausted, the count resets and the next sync start date is pushed forward.
"""
disable_existing_mirrors()
mirror, repo = create_mirror_repo_robot(["updated", "created"], repo_name="first")
# MySQL rounds the milliseconds on update, so force that rounding to happen now.
query = (RepoMirrorConfig
.update(sync_start_date=mirror.sync_start_date)
.where(RepoMirrorConfig.id == mirror.id))
query.execute()
mirror = RepoMirrorConfig.get_by_id(mirror.id)
original_sync_start_date = mirror.sync_start_date
assert mirror.sync_retries_remaining == 3
mirror = release_mirror(mirror, RepoMirrorStatus.FAIL)
assert mirror.sync_retries_remaining == 2
assert mirror.sync_start_date == original_sync_start_date
mirror = release_mirror(mirror, RepoMirrorStatus.FAIL)
assert mirror.sync_retries_remaining == 1
assert mirror.sync_start_date == original_sync_start_date
mirror = release_mirror(mirror, RepoMirrorStatus.FAIL)
assert mirror.sync_retries_remaining == 3
assert mirror.sync_start_date > original_sync_start_date


@ -0,0 +1,49 @@
import os
from datetime import timedelta
import pytest
from peewee import IntegrityError
from data.model.gc import purge_repository
from data.model.repository import create_repository, is_empty
from data.model.repository import get_filtered_matching_repositories
from test.fixtures import *
def test_duplicate_repository_different_kinds(initialized_db):
# Create an image repo.
create_repository('devtable', 'somenewrepo', None, repo_kind='image')
# Try to create an app repo with the same name, which should fail.
with pytest.raises(IntegrityError):
create_repository('devtable', 'somenewrepo', None, repo_kind='application')
def test_is_empty(initialized_db):
create_repository('devtable', 'somenewrepo', None, repo_kind='image')
assert is_empty('devtable', 'somenewrepo')
assert not is_empty('devtable', 'simple')
@pytest.mark.skipif(os.environ.get('TEST_DATABASE_URI', '').find('mysql') >= 0,
reason='MySQL requires specialized indexing of newly created repos')
@pytest.mark.parametrize('query', [
(''),
('e'),
])
@pytest.mark.parametrize('authed_username', [
(None),
('devtable'),
])
def test_search_pagination(query, authed_username, initialized_db):
# Create some public repos.
repo1 = create_repository('devtable', 'somenewrepo', None, repo_kind='image', visibility='public')
repo2 = create_repository('devtable', 'somenewrepo2', None, repo_kind='image', visibility='public')
repo3 = create_repository('devtable', 'somenewrepo3', None, repo_kind='image', visibility='public')
repositories = get_filtered_matching_repositories(query, filter_username=authed_username)
assert len(repositories) > 3
next_repos = get_filtered_matching_repositories(query, filter_username=authed_username, offset=1)
assert repositories[0].id != next_repos[0].id
assert repositories[1].id == next_repos[0].id


@ -0,0 +1,38 @@
from datetime import date, timedelta
import pytest
from data.database import RepositoryActionCount, RepositorySearchScore
from data.model.repository import create_repository
from data.model.repositoryactioncount import update_repository_score, SEARCH_BUCKETS
from test.fixtures import *
@pytest.mark.parametrize('bucket_sums,expected_score', [
((0, 0, 0, 0), 0),
((1, 6, 24, 152), 100),
((2, 6, 24, 152), 101),
((1, 6, 24, 304), 171),
((100, 480, 24, 152), 703),
((1, 6, 24, 15200), 7131),
((300, 500, 1000, 0), 1733),
((5000, 0, 0, 0), 5434),
])
def test_update_repository_score(bucket_sums, expected_score, initialized_db):
# Create a new repository.
repo = create_repository('devtable', 'somenewrepo', None, repo_kind='image')
# Delete the RepositoryActionCount rows created by create_repository.
RepositoryActionCount.delete().where(RepositoryActionCount.repository == repo).execute()
# Add RAC rows for each of the buckets.
for index, bucket in enumerate(SEARCH_BUCKETS):
for day in range(0, bucket.days):
RepositoryActionCount.create(repository=repo,
count=(bucket_sums[index] / bucket.days * 1.0),
date=date.today() - bucket.delta + timedelta(days=day))
assert update_repository_score(repo)
assert RepositorySearchScore.get(repository=repo).score == expected_score

356
data/model/test/test_tag.py Normal file

@ -0,0 +1,356 @@
import json
from datetime import datetime
from time import time
import pytest
from mock import patch
from app import docker_v2_signing_key
from data.database import (Image, RepositoryTag, ImageStorage, Repository, Manifest, ManifestBlob,
ManifestLegacyImage, TagManifestToManifest, Tag, TagToRepositoryTag)
from data.model.repository import create_repository
from data.model.tag import (list_active_repo_tags, create_or_update_tag, delete_tag,
get_matching_tags, _tag_alive, get_matching_tags_for_images,
change_tag_expiration, get_active_tag, store_tag_manifest_for_testing,
get_most_recent_tag, get_active_tag_for_repo,
create_or_update_tag_for_repo, set_tag_end_ts)
from data.model.image import find_create_or_link_image
from image.docker.schema1 import DockerSchema1ManifestBuilder
from util.timedeltastring import convert_to_timedelta
from test.fixtures import *
def _get_expected_tags(image):
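# The expected tags for an image are the alive, non-hidden tags whose image is this image or one of its descendants (matched via the ancestors path).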
expected_query = (RepositoryTag
.select()
.join(Image)
.where(RepositoryTag.hidden == False)
.where((Image.id == image.id) | (Image.ancestors ** ('%%/%s/%%' % image.id))))
return set([tag.id for tag in _tag_alive(expected_query)])
@pytest.mark.parametrize('max_subqueries,max_image_lookup_count', [
(1, 1),
(10, 10),
(100, 500),
])
def test_get_matching_tags(max_subqueries, max_image_lookup_count, initialized_db):
with patch('data.model.tag._MAX_SUB_QUERIES', max_subqueries):
with patch('data.model.tag._MAX_IMAGE_LOOKUP_COUNT', max_image_lookup_count):
# Test for every image in the test database.
for image in Image.select(Image, ImageStorage).join(ImageStorage):
matching_query = get_matching_tags(image.docker_image_id, image.storage.uuid)
matching_tags = set([tag.id for tag in matching_query])
expected_tags = _get_expected_tags(image)
assert matching_tags == expected_tags, "mismatch for image %s" % image.id
oci_tags = list(Tag
.select()
.join(TagToRepositoryTag)
.where(TagToRepositoryTag.repository_tag << expected_tags))
assert len(oci_tags) == len(expected_tags)
@pytest.mark.parametrize('max_subqueries,max_image_lookup_count', [
(1, 1),
(10, 10),
(100, 500),
])
def test_get_matching_tag_ids_for_images(max_subqueries, max_image_lookup_count, initialized_db):
with patch('data.model.tag._MAX_SUB_QUERIES', max_subqueries):
with patch('data.model.tag._MAX_IMAGE_LOOKUP_COUNT', max_image_lookup_count):
# Try for various sets of the first N images.
for count in [5, 10, 15]:
pairs = []
expected_tags_ids = set()
for image in Image.select(Image, ImageStorage).join(ImageStorage):
if len(pairs) >= count:
break
pairs.append((image.docker_image_id, image.storage.uuid))
expected_tags_ids.update(_get_expected_tags(image))
matching_tags_ids = set([tag.id for tag in get_matching_tags_for_images(pairs)])
assert matching_tags_ids == expected_tags_ids
@pytest.mark.parametrize('max_subqueries,max_image_lookup_count', [
(1, 1),
(10, 10),
(100, 500),
])
def test_get_matching_tag_ids_for_all_images(max_subqueries, max_image_lookup_count, initialized_db):
with patch('data.model.tag._MAX_SUB_QUERIES', max_subqueries):
with patch('data.model.tag._MAX_IMAGE_LOOKUP_COUNT', max_image_lookup_count):
pairs = []
for image in Image.select(Image, ImageStorage).join(ImageStorage):
pairs.append((image.docker_image_id, image.storage.uuid))
expected_tags_ids = set([tag.id for tag in _tag_alive(RepositoryTag.select())])
matching_tags_ids = set([tag.id for tag in get_matching_tags_for_images(pairs)])
# Ensure every alive tag was found.
assert matching_tags_ids == expected_tags_ids
def test_get_matching_tag_ids_images_filtered(initialized_db):
def filter_query(query):
return query.join(Repository).where(Repository.name == 'simple')
filtered_images = filter_query(Image
.select(Image, ImageStorage)
.join(RepositoryTag)
.switch(Image)
.join(ImageStorage)
.switch(Image))
expected_tags_query = _tag_alive(filter_query(RepositoryTag
.select()))
pairs = []
for image in filtered_images:
pairs.append((image.docker_image_id, image.storage.uuid))
matching_tags = get_matching_tags_for_images(pairs, filter_images=filter_query,
filter_tags=filter_query)
expected_tag_ids = set([tag.id for tag in expected_tags_query])
matching_tags_ids = set([tag.id for tag in matching_tags])
# Ensure every alive tag was found.
assert matching_tags_ids == expected_tag_ids
def _get_oci_tag(tag):
return (Tag
.select()
.join(TagToRepositoryTag)
.where(TagToRepositoryTag.repository_tag == tag)).get()
def assert_tags(repository, *args):
tags = list(list_active_repo_tags(repository))
assert len(tags) == len(args)
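# Build a name -> tag map, checking along the way that each active tag is unique, not hidden, and not expired.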
tags_dict = {}
for tag in tags:
assert tag.name not in tags_dict
assert not tag.hidden
assert not tag.lifetime_end_ts or tag.lifetime_end_ts > time()
tags_dict[tag.name] = tag
oci_tag = _get_oci_tag(tag)
assert oci_tag.name == tag.name
assert not oci_tag.hidden
assert oci_tag.reversion == tag.reversion
if tag.lifetime_end_ts:
assert oci_tag.lifetime_end_ms == (tag.lifetime_end_ts * 1000)
else:
assert oci_tag.lifetime_end_ms is None
for expected in args:
assert expected in tags_dict
def test_create_reversion_tag(initialized_db):
repository = create_repository('devtable', 'somenewrepo', None)
manifest = Manifest.get()
image1 = find_create_or_link_image('foobarimage1', repository, None, {}, 'local_us')
footag = create_or_update_tag_for_repo(repository, 'foo', image1.docker_image_id,
oci_manifest=manifest, reversion=True)
assert footag.reversion
oci_tag = _get_oci_tag(footag)
assert oci_tag.name == footag.name
assert not oci_tag.hidden
assert oci_tag.reversion == footag.reversion
def test_list_active_tags(initialized_db):
# Create a new repository.
repository = create_repository('devtable', 'somenewrepo', None)
manifest = Manifest.get()
# Create some images.
image1 = find_create_or_link_image('foobarimage1', repository, None, {}, 'local_us')
image2 = find_create_or_link_image('foobarimage2', repository, None, {}, 'local_us')
# Make sure its tags list is empty.
assert_tags(repository)
# Add some new tags.
footag = create_or_update_tag_for_repo(repository, 'foo', image1.docker_image_id,
oci_manifest=manifest)
bartag = create_or_update_tag_for_repo(repository, 'bar', image1.docker_image_id,
oci_manifest=manifest)
# Since timestamps are stored on a second-granularity, we need to make the tags "start"
# before "now", so when we recreate them below, they don't conflict.
footag.lifetime_start_ts -= 5
footag.save()
bartag.lifetime_start_ts -= 5
bartag.save()
footag_oci = _get_oci_tag(footag)
footag_oci.lifetime_start_ms -= 5000
footag_oci.save()
bartag_oci = _get_oci_tag(bartag)
bartag_oci.lifetime_start_ms -= 5000
bartag_oci.save()
# Make sure they are returned.
assert_tags(repository, 'foo', 'bar')
# Set the expirations to be explicitly empty.
set_tag_end_ts(footag, None)
set_tag_end_ts(bartag, None)
# Make sure they are returned.
assert_tags(repository, 'foo', 'bar')
# Mark a tag as expiring in the far future, and make sure it is still returned.
set_tag_end_ts(footag, footag.lifetime_start_ts + 10000000)
# Make sure they are returned.
assert_tags(repository, 'foo', 'bar')
# Delete a tag and make sure it isn't returned.
footag = delete_tag('devtable', 'somenewrepo', 'foo')
set_tag_end_ts(footag, footag.lifetime_end_ts - 4)
assert_tags(repository, 'bar')
# Add a new foo again.
footag = create_or_update_tag_for_repo(repository, 'foo', image1.docker_image_id,
oci_manifest=manifest)
footag.lifetime_start_ts -= 3
footag.save()
footag_oci = _get_oci_tag(footag)
footag_oci.lifetime_start_ms -= 3000
footag_oci.save()
assert_tags(repository, 'foo', 'bar')
# Mark a tag as expiring in the far future, and make sure it is still returned.
set_tag_end_ts(footag, footag.lifetime_start_ts + 10000000)
# Make sure they are returned.
assert_tags(repository, 'foo', 'bar')
# "Move" foo by updating it and make sure we don't get duplicates.
create_or_update_tag_for_repo(repository, 'foo', image2.docker_image_id, oci_manifest=manifest)
assert_tags(repository, 'foo', 'bar')
@pytest.mark.parametrize('expiration_offset, expected_offset', [
(None, None),
('0s', '1h'),
('30m', '1h'),
('2h', '2h'),
('2w', '2w'),
('200w', '104w'),
])
def test_change_tag_expiration(expiration_offset, expected_offset, initialized_db):
repository = create_repository('devtable', 'somenewrepo', None)
image1 = find_create_or_link_image('foobarimage1', repository, None, {}, 'local_us')
manifest = Manifest.get()
footag = create_or_update_tag_for_repo(repository, 'foo', image1.docker_image_id,
oci_manifest=manifest)
expiration_date = None
if expiration_offset is not None:
expiration_date = datetime.utcnow() + convert_to_timedelta(expiration_offset)
assert change_tag_expiration(footag, expiration_date)
# Lookup the tag again.
footag_updated = get_active_tag('devtable', 'somenewrepo', 'foo')
oci_tag = _get_oci_tag(footag_updated)
if expected_offset is None:
assert footag_updated.lifetime_end_ts is None
assert oci_tag.lifetime_end_ms is None
else:
start_date = datetime.utcfromtimestamp(footag_updated.lifetime_start_ts)
end_date = datetime.utcfromtimestamp(footag_updated.lifetime_end_ts)
expected_end_date = start_date + convert_to_timedelta(expected_offset)
assert (expected_end_date - end_date).total_seconds() < 5 # variance in test
assert oci_tag.lifetime_end_ms == (footag_updated.lifetime_end_ts * 1000)
def random_storages():
return list(ImageStorage.select().where(~(ImageStorage.content_checksum >> None)).limit(10))
def repeated_storages():
storages = list(ImageStorage.select().where(~(ImageStorage.content_checksum >> None)).limit(5))
return storages + storages
@pytest.mark.parametrize('get_storages', [
random_storages,
repeated_storages,
])
def test_store_tag_manifest(get_storages, initialized_db):
# Create a manifest with some layers.
builder = DockerSchema1ManifestBuilder('devtable', 'simple', 'sometag')
storages = get_storages()
assert storages
repo = model.repository.get_repository('devtable', 'simple')
storage_id_map = {}
for index, storage in enumerate(storages):
image_id = 'someimage%s' % index
builder.add_layer(storage.content_checksum, json.dumps({'id': image_id}))
find_create_or_link_image(image_id, repo, 'devtable', {}, 'local_us')
storage_id_map[storage.content_checksum] = storage.id
manifest = builder.build(docker_v2_signing_key)
tag_manifest, _ = store_tag_manifest_for_testing('devtable', 'simple', 'sometag', manifest,
manifest.leaf_layer_v1_image_id, storage_id_map)
# Ensure we have the new-model expected rows.
mapping_row = TagManifestToManifest.get(tag_manifest=tag_manifest)
assert mapping_row.manifest is not None
assert mapping_row.manifest.manifest_bytes == manifest.bytes.as_encoded_str()
assert mapping_row.manifest.digest == str(manifest.digest)
blob_rows = {m.blob_id for m in
ManifestBlob.select().where(ManifestBlob.manifest == mapping_row.manifest)}
assert blob_rows == {s.id for s in storages}
assert ManifestLegacyImage.get(manifest=mapping_row.manifest).image == tag_manifest.tag.image
def test_get_most_recent_tag(initialized_db):
# Create a hidden tag that is the most recent.
repo = model.repository.get_repository('devtable', 'simple')
image = model.tag.get_tag_image('devtable', 'simple', 'latest')
model.tag.create_temporary_hidden_tag(repo, image, 10000000)
# Ensure we find a non-hidden tag.
found = model.tag.get_most_recent_tag(repo)
assert not found.hidden
def test_get_active_tag_for_repo(initialized_db):
repo = model.repository.get_repository('devtable', 'simple')
image = model.tag.get_tag_image('devtable', 'simple', 'latest')
hidden_tag = model.tag.create_temporary_hidden_tag(repo, image, 10000000)
# Ensure get_active_tag_for_repo cannot find it.
assert model.tag.get_active_tag_for_repo(repo, hidden_tag) is None
assert model.tag.get_active_tag_for_repo(repo, 'latest') is not None


@ -0,0 +1,61 @@
import pytest
from data.model.team import (add_or_invite_to_team, create_team, confirm_team_invite,
list_team_users, validate_team_name)
from data.model.organization import create_organization
from data.model.user import get_user, create_user_noverify
from test.fixtures import *
@pytest.mark.parametrize('name, is_valid', [
('', False),
('f', False),
('fo', True),
('f' * 255, True),
('f' * 256, False),
(' ', False),
('helloworld', True),
('hello_world', True),
('hello-world', True),
('hello world', False),
('HelloWorld', False),
])
def test_validate_team_name(name, is_valid):
result, _ = validate_team_name(name)
assert result == is_valid
def is_in_team(team, user):
return user.username in {u.username for u in list_team_users(team)}
def test_invite_to_team(initialized_db):
first_user = get_user('devtable')
second_user = create_user_noverify('newuser', 'foo@example.com')
def run_invite_flow(orgname):
# Create an org owned by `devtable`.
org = create_organization(orgname, orgname + '@example.com', first_user)
# Create another team and add `devtable` to it. Since `devtable` is already
# in the org, the user should be added directly rather than invited.
other_team = create_team('otherteam', org, 'admin')
invite = add_or_invite_to_team(first_user, other_team, user_obj=first_user)
assert invite is None
assert is_in_team(other_team, first_user)
# Try to add `newuser` to the team, which should require an invite.
invite = add_or_invite_to_team(first_user, other_team, user_obj=second_user)
assert invite is not None
assert not is_in_team(other_team, second_user)
# Accept the invite.
confirm_team_invite(invite.invite_token, second_user)
assert is_in_team(other_team, second_user)
# Run for a new org.
run_invite_flow('firstorg')
# Create another org and repeat, ensuring the same operations perform the same way.
run_invite_flow('secondorg')


@ -0,0 +1,205 @@
from datetime import datetime
import pytest
from mock import patch
from data.database import EmailConfirmation, User, DeletedNamespace
from data.model.organization import get_organization
from data.model.notification import create_notification
from data.model.team import create_team, add_user_to_team
from data.model.user import create_user_noverify, validate_reset_code, get_active_users
from data.model.user import mark_namespace_for_deletion, delete_namespace_via_marker
from data.model.user import create_robot, lookup_robot, list_namespace_robots
from data.model.user import get_pull_credentials, retrieve_robot_token, verify_robot
from data.model.user import InvalidRobotException, delete_robot, get_matching_users
from data.model.repository import create_repository
from data.fields import Credential
from data.queue import WorkQueue
from util.timedeltastring import convert_to_timedelta
from util.security.token import encode_public_private_token
from test.fixtures import *
def test_create_user_with_expiration(initialized_db):
with patch('data.model.config.app_config', {'DEFAULT_TAG_EXPIRATION': '1h'}):
user = create_user_noverify('foobar', 'foo@example.com', email_required=False)
assert user.removed_tag_expiration_s == 60 * 60
@pytest.mark.parametrize('token_lifetime, time_since', [
('1m', '2m'),
('2m', '1m'),
('1h', '1m'),
])
def test_validation_code(token_lifetime, time_since, initialized_db):
user = create_user_noverify('foobar', 'foo@example.com', email_required=False)
created = datetime.now() - convert_to_timedelta(time_since)
verification_code, unhashed = Credential.generate()
confirmation = EmailConfirmation.create(user=user, pw_reset=True,
created=created, verification_code=verification_code)
encoded = encode_public_private_token(confirmation.code, unhashed)
with patch('data.model.config.app_config', {'USER_RECOVERY_TOKEN_LIFETIME': token_lifetime}):
result = validate_reset_code(encoded)
expect_success = convert_to_timedelta(token_lifetime) >= convert_to_timedelta(time_since)
assert expect_success == (result is not None)
@pytest.mark.parametrize('disabled', [
(True),
(False),
])
@pytest.mark.parametrize('deleted', [
(True),
(False),
])
def test_get_active_users(disabled, deleted, initialized_db):
# Delete a user.
deleted_user = model.user.get_user('public')
queue = WorkQueue('testgcnamespace', lambda db: db.transaction())
mark_namespace_for_deletion(deleted_user, [], queue)
users = get_active_users(disabled=disabled, deleted=deleted)
deleted_found = [user for user in users if user.id == deleted_user.id]
assert bool(deleted_found) == (deleted and disabled)
for user in users:
if not disabled:
assert user.enabled
def test_mark_namespace_for_deletion(initialized_db):
def create_transaction(db):
return db.transaction()
# Create a user and then mark it for deletion.
user = create_user_noverify('foobar', 'foo@example.com', email_required=False)
# Add some robots.
create_robot('foo', user)
create_robot('bar', user)
assert lookup_robot('foobar+foo') is not None
assert lookup_robot('foobar+bar') is not None
assert len(list(list_namespace_robots('foobar'))) == 2
# Mark the user for deletion.
queue = WorkQueue('testgcnamespace', create_transaction)
mark_namespace_for_deletion(user, [], queue)
# Ensure the old user row still exists but has been renamed away from 'foobar'.
older_user = User.get(id=user.id)
assert older_user.username != 'foobar'
# Ensure the robots are deleted.
with pytest.raises(InvalidRobotException):
assert lookup_robot('foobar+foo')
with pytest.raises(InvalidRobotException):
assert lookup_robot('foobar+bar')
assert len(list(list_namespace_robots(older_user.username))) == 0
# Ensure we can create a user with the same namespace again.
new_user = create_user_noverify('foobar', 'foo@example.com', email_required=False)
assert new_user.id != user.id
# Ensure the old user row still exists but has been renamed away from 'foobar'.
assert User.get(id=user.id).username != 'foobar'
def test_delete_namespace_via_marker(initialized_db):
def create_transaction(db):
return db.transaction()
# Create a user and then mark it for deletion.
user = create_user_noverify('foobar', 'foo@example.com', email_required=False)
# Add some repositories.
create_repository('foobar', 'somerepo', user)
create_repository('foobar', 'anotherrepo', user)
# Mark the user for deletion.
queue = WorkQueue('testgcnamespace', create_transaction)
marker_id = mark_namespace_for_deletion(user, [], queue)
# Delete the user.
delete_namespace_via_marker(marker_id, [])
# Ensure the user was actually deleted.
with pytest.raises(User.DoesNotExist):
User.get(id=user.id)
with pytest.raises(DeletedNamespace.DoesNotExist):
DeletedNamespace.get(id=marker_id)
def test_delete_robot(initialized_db):
# Create a robot account.
user = create_user_noverify('foobar', 'foo@example.com', email_required=False)
robot, _ = create_robot('foo', user)
# Add some notifications and other rows pointing to the robot.
create_notification('repo_push', robot)
team = create_team('someteam', get_organization('buynlarge'), 'member')
add_user_to_team(robot, team)
# Ensure the robot exists.
assert lookup_robot(robot.username).id == robot.id
# Delete the robot.
delete_robot(robot.username)
# Ensure it is gone.
with pytest.raises(InvalidRobotException):
lookup_robot(robot.username)
def test_get_matching_users(initialized_db):
# Exact match.
for user in User.select().where(User.organization == False, User.robot == False):
assert list(get_matching_users(user.username))[0].username == user.username
# Prefix matching.
for user in User.select().where(User.organization == False, User.robot == False):
assert user.username in [r.username for r in get_matching_users(user.username[:2])]
def test_get_matching_users_with_same_prefix(initialized_db):
# Create a bunch of users with the same prefix.
for index in range(0, 20):
create_user_noverify('foo%s' % index, 'foo%s@example.com' % index, email_required=False)
# For each user, ensure that a lookup of the exact name returns that user first.
for index in range(0, 20):
username = 'foo%s' % index
assert list(get_matching_users(username))[0].username == username
# Prefix matching.
found = list(get_matching_users('foo', limit=50))
assert len(found) == 20
def test_robot(initialized_db):
# Create a robot account.
user = create_user_noverify('foobar', 'foo@example.com', email_required=False)
robot, token = create_robot('foo', user)
assert retrieve_robot_token(robot) == token
# Ensure we can retrieve its information.
found = lookup_robot('foobar+foo')
assert found == robot
creds = get_pull_credentials('foobar+foo')
assert creds is not None
assert creds['username'] == 'foobar+foo'
assert creds['password'] == token
assert verify_robot('foobar+foo', token) == robot
with pytest.raises(InvalidRobotException):
assert verify_robot('foobar+foo', 'someothertoken')
with pytest.raises(InvalidRobotException):
assert verify_robot('foobar+unknownbot', token)


@ -0,0 +1,89 @@
from data import model
from test.fixtures import *
NO_ACCESS_USER = 'freshuser'
READ_ACCESS_USER = 'reader'
ADMIN_ACCESS_USER = 'devtable'
PUBLIC_USER = 'public'
RANDOM_USER = 'randomuser'
OUTSIDE_ORG_USER = 'outsideorg'
ADMIN_ROBOT_USER = 'devtable+dtrobot'
ORGANIZATION = 'buynlarge'
SIMPLE_REPO = 'simple'
PUBLIC_REPO = 'publicrepo'
RANDOM_REPO = 'randomrepo'
OUTSIDE_ORG_REPO = 'coolrepo'
ORG_REPO = 'orgrepo'
ANOTHER_ORG_REPO = 'anotherorgrepo'
# Note: The shared repo has devtable as admin, public as a writer and reader as a reader.
SHARED_REPO = 'shared'
def assertDoesNotHaveRepo(username, name):
repos = list(model.repository.get_visible_repositories(username))
names = [repo.name for repo in repos]
assert name not in names
def assertHasRepo(username, name):
repos = list(model.repository.get_visible_repositories(username))
names = [repo.name for repo in repos]
assert name in names
def test_noaccess(initialized_db):
repos = list(model.repository.get_visible_repositories(NO_ACCESS_USER))
names = [repo.name for repo in repos]
assert not names
# Try retrieving public repos now.
repos = list(model.repository.get_visible_repositories(NO_ACCESS_USER, include_public=True))
names = [repo.name for repo in repos]
assert PUBLIC_REPO in names
def test_public(initialized_db):
assertHasRepo(PUBLIC_USER, PUBLIC_REPO)
assertHasRepo(PUBLIC_USER, SHARED_REPO)
assertDoesNotHaveRepo(PUBLIC_USER, SIMPLE_REPO)
assertDoesNotHaveRepo(PUBLIC_USER, RANDOM_REPO)
assertDoesNotHaveRepo(PUBLIC_USER, OUTSIDE_ORG_REPO)
def test_reader(initialized_db):
assertHasRepo(READ_ACCESS_USER, SHARED_REPO)
assertHasRepo(READ_ACCESS_USER, ORG_REPO)
assertDoesNotHaveRepo(READ_ACCESS_USER, SIMPLE_REPO)
assertDoesNotHaveRepo(READ_ACCESS_USER, RANDOM_REPO)
assertDoesNotHaveRepo(READ_ACCESS_USER, OUTSIDE_ORG_REPO)
assertDoesNotHaveRepo(READ_ACCESS_USER, PUBLIC_REPO)
def test_random(initialized_db):
assertHasRepo(RANDOM_USER, RANDOM_REPO)
assertDoesNotHaveRepo(RANDOM_USER, SIMPLE_REPO)
assertDoesNotHaveRepo(RANDOM_USER, SHARED_REPO)
assertDoesNotHaveRepo(RANDOM_USER, ORG_REPO)
assertDoesNotHaveRepo(RANDOM_USER, ANOTHER_ORG_REPO)
assertDoesNotHaveRepo(RANDOM_USER, PUBLIC_REPO)
def test_admin(initialized_db):
assertHasRepo(ADMIN_ACCESS_USER, SIMPLE_REPO)
assertHasRepo(ADMIN_ACCESS_USER, SHARED_REPO)
assertHasRepo(ADMIN_ACCESS_USER, ORG_REPO)
assertHasRepo(ADMIN_ACCESS_USER, ANOTHER_ORG_REPO)
assertDoesNotHaveRepo(ADMIN_ACCESS_USER, OUTSIDE_ORG_REPO)

105
data/model/token.py Normal file

@ -0,0 +1,105 @@
import logging
from peewee import JOIN
from active_migration import ActiveDataMigration, ERTMigrationFlags
from data.database import (AccessToken, AccessTokenKind, Repository, Namespace, Role,
RepositoryBuildTrigger)
from data.model import DataModelException, _basequery, InvalidTokenException
logger = logging.getLogger(__name__)
ACCESS_TOKEN_NAME_PREFIX_LENGTH = 32
ACCESS_TOKEN_CODE_MINIMUM_LENGTH = 32
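# A full delegate token string is the 32-character token name prefix followed by the secret
# code; the name is used for lookup and the remainder is checked against the stored code.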
def create_access_token(repo, role, kind=None, friendly_name=None):
role = Role.get(Role.name == role)
kind_ref = None
if kind is not None:
kind_ref = AccessTokenKind.get(AccessTokenKind.name == kind)
new_token = AccessToken.create(repository=repo, temporary=True, role=role, kind=kind_ref,
friendly_name=friendly_name)
if ActiveDataMigration.has_flag(ERTMigrationFlags.WRITE_OLD_FIELDS):
new_token.code = new_token.token_name + new_token.token_code.decrypt()
new_token.save()
return new_token
def create_delegate_token(namespace_name, repository_name, friendly_name,
role='read'):
read_only = Role.get(name=role)
repo = _basequery.get_existing_repository(namespace_name, repository_name)
new_token = AccessToken.create(repository=repo, role=read_only,
friendly_name=friendly_name, temporary=False)
if ActiveDataMigration.has_flag(ERTMigrationFlags.WRITE_OLD_FIELDS):
new_token.code = new_token.token_name + new_token.token_code.decrypt()
new_token.save()
return new_token
def load_token_data(code):
""" Load the permissions for any token by code. """
token_name = code[:ACCESS_TOKEN_NAME_PREFIX_LENGTH]
token_code = code[ACCESS_TOKEN_NAME_PREFIX_LENGTH:]
if not token_name or not token_code:
raise InvalidTokenException('Invalid delegate token code: %s' % code)
# Try loading by name and then comparing the code.
assert token_name
try:
found = (AccessToken
.select(AccessToken, Repository, Namespace, Role)
.join(Role)
.switch(AccessToken)
.join(Repository)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.where(AccessToken.token_name == token_name)
.get())
assert token_code
if found.token_code is None or not found.token_code.matches(token_code):
raise InvalidTokenException('Invalid delegate token code: %s' % code)
assert len(token_code) >= ACCESS_TOKEN_CODE_MINIMUM_LENGTH
return found
except AccessToken.DoesNotExist:
pass
# Legacy: Try loading the full code directly.
# TODO(remove-unenc): Remove this once migrated.
if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
try:
return (AccessToken
.select(AccessToken, Repository, Namespace, Role)
.join(Role)
.switch(AccessToken)
.join(Repository)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.where(AccessToken.code == code)
.get())
except AccessToken.DoesNotExist:
raise InvalidTokenException('Invalid delegate token code: %s' % code)
raise InvalidTokenException('Invalid delegate token code: %s' % code)
def get_full_token_string(token):
""" Returns the full string to use for this token to login. """
if ActiveDataMigration.has_flag(ERTMigrationFlags.READ_OLD_FIELDS):
if token.token_name is None:
return token.code
assert token.token_name
token_code = token.token_code.decrypt()
assert len(token.token_name) == ACCESS_TOKEN_NAME_PREFIX_LENGTH
assert len(token_code) >= ACCESS_TOKEN_CODE_MINIMUM_LENGTH
return '%s%s' % (token.token_name, token_code)
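# Illustrative usage sketch (with a hypothetical `token` row returned by create_access_token
# above): the name/code split is expected to round-trip through load_token_data.
#
#   full = get_full_token_string(token)
#   assert full[:ACCESS_TOKEN_NAME_PREFIX_LENGTH] == token.token_name
#   assert load_token_data(full).id == token.id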

1217
data/model/user.py Normal file

File diff suppressed because it is too large