diff --git a/data/database.py b/data/database.py index aba8a578d..8ddc4564a 100644 --- a/data/database.py +++ b/data/database.py @@ -29,6 +29,16 @@ SCHEME_RANDOM_FUNCTION = { 'postgresql+psycopg2': fn.Random, } +def real_for_update(query): + return query.for_update() + +def null_for_update(query): + return query + +SCHEME_SPECIALIZED_FOR_UPDATE = { + 'sqlite': null_for_update, +} + class CallableProxy(Proxy): def __call__(self, *args, **kwargs): if self.obj is None: @@ -68,6 +78,7 @@ class UseThenDisconnect(object): db = Proxy() read_slave = Proxy() db_random_func = CallableProxy() +db_for_update = CallableProxy() def validate_database_url(url, connect_timeout=5): @@ -105,6 +116,8 @@ def configure(config_object): parsed_write_uri = make_url(write_db_uri) db_random_func.initialize(SCHEME_RANDOM_FUNCTION[parsed_write_uri.drivername]) + db_for_update.initialize(SCHEME_SPECIALIZED_FOR_UPDATE.get(parsed_write_uri.drivername, + real_for_update)) read_slave_uri = config_object.get('DB_READ_SLAVE_URI', None) if read_slave_uri is not None: diff --git a/data/model/legacy.py b/data/model/legacy.py index f8c04e04c..a4739fc25 100644 --- a/data/model/legacy.py +++ b/data/model/legacy.py @@ -14,7 +14,7 @@ from data.database import (User, Repository, Image, AccessToken, Role, Repositor ExternalNotificationEvent, ExternalNotificationMethod, RepositoryNotification, RepositoryAuthorizedEmail, TeamMemberInvite, DerivedImageStorage, ImageStorageTransformation, random_string_generator, - db, BUILD_PHASE, QuayUserField, validate_database_url) + db, BUILD_PHASE, QuayUserField, validate_database_url, db_for_update) from peewee import JOIN_LEFT_OUTER, fn from util.validation import (validate_username, validate_email, validate_password, INVALID_PASSWORD_MESSAGE) @@ -295,6 +295,9 @@ def delete_robot(robot_username): def _list_entity_robots(entity_name): + """ Return the list of robots for the specified entity. This MUST return a query, not a + materialized list so that callers can use db_for_update. + """ return (User .select() .join(FederatedLogin) @@ -903,14 +906,17 @@ def change_password(user, new_password): delete_notifications_by_kind(user, 'password_required') -def change_username(user, new_username): +def change_username(user_id, new_username): (username_valid, username_issue) = validate_username(new_username) if not username_valid: raise InvalidUsernameException('Invalid username %s: %s' % (new_username, username_issue)) with config.app_config['DB_TRANSACTION_FACTORY'](db): + # Reload the user for update + user = db_for_update(User.select().where(User.id == user_id)).get() + # Rename the robots - for robot in _list_entity_robots(user.username): + for robot in db_for_update(_list_entity_robots(user.username)): _, robot_shortname = parse_robot_username(robot.username) new_robot_name = format_robot_username(new_username, robot_shortname) robot.username = new_robot_name @@ -1251,9 +1257,9 @@ def _find_or_link_image(existing_image, repository, username, translations, pref storage.locations = {placement.location.name for placement in storage.imagestorageplacement_set} - new_image = Image.create(docker_image_id=existing_image.docker_image_id, - repository=repository, storage=storage, - ancestors=new_image_ancestry) + new_image = Image.create(docker_image_id=existing_image.docker_image_id, + repository=repository, storage=storage, + ancestors=new_image_ancestry) logger.debug('Storing translation %s -> %s', existing_image.id, new_image.id) translations[existing_image.id] = new_image.id @@ -1403,7 +1409,7 @@ def set_image_metadata(docker_image_id, namespace_name, repository_name, created Image.docker_image_id == docker_image_id)) try: - fetched = query.get() + fetched = db_for_update(query).get() except Image.DoesNotExist: raise DataModelException('No image with specified id and repository') diff --git a/data/queue.py b/data/queue.py index e560e7cd9..52ccd9770 100644 --- a/data/queue.py +++ b/data/queue.py @@ -1,6 +1,6 @@ from datetime import datetime, timedelta -from data.database import QueueItem, db +from data.database import QueueItem, db, db_for_update from util.morecollections import AttrDict @@ -41,6 +41,9 @@ class WorkQueue(object): def _name_match_query(self): return '%s%%' % self._canonical_name([self._queue_name] + self._canonical_name_match_list) + def _item_by_id_for_update(self, queue_id): + return db_for_update(QueueItem.select().where(QueueItem.id == queue_id)).get() + def update_metrics(self): if self._reporter is None: return @@ -91,7 +94,7 @@ class WorkQueue(object): item = None try: - db_item = avail.order_by(QueueItem.id).get() + db_item = db_for_update(avail.order_by(QueueItem.id)).get() db_item.available = False db_item.processing_expires = now + timedelta(seconds=processing_time) db_item.retries_remaining -= 1 @@ -111,14 +114,14 @@ class WorkQueue(object): def complete(self, completed_item): with self._transaction_factory(db): - completed_item_obj = QueueItem.get(QueueItem.id == completed_item.id) + completed_item_obj = self._item_by_id_for_update(completed_item.id) completed_item_obj.delete_instance() self._currently_processing = False def incomplete(self, incomplete_item, retry_after=300, restore_retry=False): with self._transaction_factory(db): retry_date = datetime.utcnow() + timedelta(seconds=retry_after) - incomplete_item_obj = QueueItem.get(QueueItem.id == incomplete_item.id) + incomplete_item_obj = self._item_by_id_for_update(incomplete_item.id) incomplete_item_obj.available_after = retry_date incomplete_item_obj.available = True @@ -128,12 +131,12 @@ class WorkQueue(object): incomplete_item_obj.save() self._currently_processing = False - def extend_processing(self, seconds_from_now, minimum_extension=MINIMUM_EXTENSION): + def extend_processing(self, item, seconds_from_now, minimum_extension=MINIMUM_EXTENSION): with self._transaction_factory(db): - queue_item = QueueItem.get(QueueItem.id == self.id) + queue_item = self._item_by_id_for_update(item.id) new_expiration = datetime.utcnow() + timedelta(seconds=seconds_from_now) # Only actually write the new expiration to the db if it moves the expiration some minimum if new_expiration - queue_item.processing_expires > minimum_extension: queue_item.processing_expires = new_expiration - queue_item.save() \ No newline at end of file + queue_item.save() diff --git a/endpoints/api/user.py b/endpoints/api/user.py index b713b3ff8..cffffacac 100644 --- a/endpoints/api/user.py +++ b/endpoints/api/user.py @@ -246,7 +246,7 @@ class User(ApiResource): # Username already used raise request_error(message='Username is already in use') - model.change_username(user, new_username) + model.change_username(user.id, new_username) except model.InvalidPasswordException, ex: raise request_error(exception=ex)