diff --git a/data/database.py b/data/database.py index a3d038e4c..62c59e6e0 100644 --- a/data/database.py +++ b/data/database.py @@ -378,7 +378,7 @@ def configure(config_object, testing=False): real_for_update)) db_concat_func.initialize(SCHEME_SPECIALIZED_CONCAT.get(parsed_write_uri.drivername, function_concat)) - db_encrypter.initialize(FieldEncrypter(config_object['DATABASE_SECRET_KEY'])) + db_encrypter.initialize(FieldEncrypter(config_object.get('DATABASE_SECRET_KEY'))) read_replicas = config_object.get('DB_READ_REPLICAS', None) is_read_only = config_object.get('REGISTRY_STATE', 'normal') == 'readonly' diff --git a/data/encryption.py b/data/encryption.py index 83a90860a..429f09827 100644 --- a/data/encryption.py +++ b/data/encryption.py @@ -59,11 +59,15 @@ class FieldEncrypter(object): and the application. """ def __init__(self, secret_key, version='v0'): - self._secret_key = convert_secret_key(secret_key) + # NOTE: secret_key will be None when the system is being first initialized, so we allow that + # case here, but make sure to assert that it is *not* None below if any encryption is actually + # needed. + self._secret_key = convert_secret_key(secret_key) if secret_key is not None else None self._encryption_version = _VERSIONS[version] def encrypt_value(self, value, field_max_length=None): """ Encrypts the value using the current version of encryption. """ + assert self._secret_key is not None encrypted_value = self._encryption_version.encrypt(self._secret_key, value, field_max_length) return '%s%s%s' % (self._encryption_version.prefix, _SEPARATOR, encrypted_value) @@ -71,6 +75,7 @@ class FieldEncrypter(object): """ Decrypts the value, returning it. If the value cannot be decrypted raises a DecryptionFailureException. """ + assert self._secret_key is not None if _SEPARATOR not in value: raise DecryptionFailureException('Invalid encrypted value') diff --git a/data/migrations/tester.py b/data/migrations/tester.py index 01e862909..2643b80e2 100644 --- a/data/migrations/tester.py +++ b/data/migrations/tester.py @@ -2,7 +2,7 @@ import json import logging import uuid -from abc import ABCMeta, abstractmethod +from abc import ABCMeta, abstractmethod, abstractproperty from datetime import datetime from six import add_metaclass @@ -92,6 +92,10 @@ class MigrationTester(object): """ TestDataType = DataTypes + @abstractproperty + def is_testing(self): + """ Returns whether we are currently under a migration test. """ + @abstractmethod def populate_table(self, table_name, fields): """ Called to populate a table with the given fields filled in with testing data. """ @@ -107,6 +111,10 @@ class NoopTester(MigrationTester): class PopulateTestDataTester(MigrationTester): + @property + def is_testing(self): + return True + def populate_table(self, table_name, fields): columns = {field_name: field_type() for field_name, field_type in fields} field_name_vars = [':' + field_name for field_name, _ in fields] diff --git a/data/migrations/versions/34c8ef052ec9_repo_mirror_columns.py b/data/migrations/versions/34c8ef052ec9_repo_mirror_columns.py index 3a48b9b45..2b73b8afa 100644 --- a/data/migrations/versions/34c8ef052ec9_repo_mirror_columns.py +++ b/data/migrations/versions/34c8ef052ec9_repo_mirror_columns.py @@ -80,11 +80,13 @@ def upgrade(tables, tester, progress_reporter): op.add_column('repomirrorconfig', sa.Column('external_reference', sa.Text(), nullable=True)) - for repo_mirror in _iterate(RepoMirrorConfig, (RepoMirrorConfig.external_reference >> None)): - repo = '%s/%s/%s' % (repo_mirror.external_registry, repo_mirror.external_namespace, repo_mirror.external_repository) - logger.info('migrating %s' % repo) - repo_mirror.external_reference = repo - repo_mirror.save() + from app import app + if app.config.get('SETUP_COMPLETE', False) or tester.is_testing: + for repo_mirror in _iterate(RepoMirrorConfig, (RepoMirrorConfig.external_reference >> None)): + repo = '%s/%s/%s' % (repo_mirror.external_registry, repo_mirror.external_namespace, repo_mirror.external_repository) + logger.info('migrating %s' % repo) + repo_mirror.external_reference = repo + repo_mirror.save() op.drop_column('repomirrorconfig', 'external_registry') op.drop_column('repomirrorconfig', 'external_namespace') @@ -109,14 +111,16 @@ def downgrade(tables, tester, progress_reporter): op.add_column('repomirrorconfig', sa.Column('external_namespace', sa.String(length=255), nullable=True)) op.add_column('repomirrorconfig', sa.Column('external_repository', sa.String(length=255), nullable=True)) - logger.info('Restoring columns from external_reference') - for repo_mirror in _iterate(RepoMirrorConfig, (RepoMirrorConfig.external_registry >> None)): - logger.info('Restoring %s' % repo_mirror.external_reference) - parts = repo_mirror.external_reference.split('/', 2) - repo_mirror.external_registry = parts[0] if len(parts) >= 1 else 'DOWNGRADE-FAILED' - repo_mirror.external_namespace = parts[1] if len(parts) >= 2 else 'DOWNGRADE-FAILED' - repo_mirror.external_repository = parts[2] if len(parts) >= 3 else 'DOWNGRADE-FAILED' - repo_mirror.save() + from app import app + if app.config.get('SETUP_COMPLETE', False): + logger.info('Restoring columns from external_reference') + for repo_mirror in _iterate(RepoMirrorConfig, (RepoMirrorConfig.external_registry >> None)): + logger.info('Restoring %s' % repo_mirror.external_reference) + parts = repo_mirror.external_reference.split('/', 2) + repo_mirror.external_registry = parts[0] if len(parts) >= 1 else 'DOWNGRADE-FAILED' + repo_mirror.external_namespace = parts[1] if len(parts) >= 2 else 'DOWNGRADE-FAILED' + repo_mirror.external_repository = parts[2] if len(parts) >= 3 else 'DOWNGRADE-FAILED' + repo_mirror.save() op.drop_column('repomirrorconfig', 'external_reference') diff --git a/data/migrations/versions/703298a825c2_backfill_new_encrypted_fields.py b/data/migrations/versions/703298a825c2_backfill_new_encrypted_fields.py index 4a1416eaa..43459af40 100644 --- a/data/migrations/versions/703298a825c2_backfill_new_encrypted_fields.py +++ b/data/migrations/versions/703298a825c2_backfill_new_encrypted_fields.py @@ -98,157 +98,159 @@ class OAuthApplication(BaseModel): def upgrade(tables, tester, progress_reporter): op = ProgressWrapper(original_op, progress_reporter) - # Empty all access token names to fix the bug where we put the wrong name and code - # in for some tokens. - AccessToken.update(token_name=None).where(AccessToken.token_name >> None).execute() + from app import app + if app.config.get('SETUP_COMPLETE', False) or tester.is_testing: + # Empty all access token names to fix the bug where we put the wrong name and code + # in for some tokens. + AccessToken.update(token_name=None).where(AccessToken.token_name >> None).execute() - # AccessToken. - logger.info('Backfilling encrypted credentials for access tokens') - for access_token in _iterate(AccessToken, ((AccessToken.token_name >> None) | - (AccessToken.token_name == ''))): - logger.info('Backfilling encrypted credentials for access token %s', access_token.id) - assert access_token.code is not None - assert access_token.code[:ACCESS_TOKEN_NAME_PREFIX_LENGTH] - assert access_token.code[ACCESS_TOKEN_NAME_PREFIX_LENGTH:] + # AccessToken. + logger.info('Backfilling encrypted credentials for access tokens') + for access_token in _iterate(AccessToken, ((AccessToken.token_name >> None) | + (AccessToken.token_name == ''))): + logger.info('Backfilling encrypted credentials for access token %s', access_token.id) + assert access_token.code is not None + assert access_token.code[:ACCESS_TOKEN_NAME_PREFIX_LENGTH] + assert access_token.code[ACCESS_TOKEN_NAME_PREFIX_LENGTH:] - token_name = access_token.code[:ACCESS_TOKEN_NAME_PREFIX_LENGTH] - token_code = _decrypted(access_token.code[ACCESS_TOKEN_NAME_PREFIX_LENGTH:]) + token_name = access_token.code[:ACCESS_TOKEN_NAME_PREFIX_LENGTH] + token_code = _decrypted(access_token.code[ACCESS_TOKEN_NAME_PREFIX_LENGTH:]) - (AccessToken - .update(token_name=token_name, token_code=token_code) - .where(AccessToken.id == access_token.id, AccessToken.code == access_token.code) - .execute()) + (AccessToken + .update(token_name=token_name, token_code=token_code) + .where(AccessToken.id == access_token.id, AccessToken.code == access_token.code) + .execute()) - assert AccessToken.select().where(AccessToken.token_name >> None).count() == 0 + assert AccessToken.select().where(AccessToken.token_name >> None).count() == 0 - # Robots. - logger.info('Backfilling encrypted credentials for robots') - while True: - has_row = False - query = (User - .select() - .join(RobotAccountToken, JOIN.LEFT_OUTER) - .where(User.robot == True, RobotAccountToken.id >> None) - .limit(BATCH_SIZE)) + # Robots. + logger.info('Backfilling encrypted credentials for robots') + while True: + has_row = False + query = (User + .select() + .join(RobotAccountToken, JOIN.LEFT_OUTER) + .where(User.robot == True, RobotAccountToken.id >> None) + .limit(BATCH_SIZE)) - for robot_user in query: - logger.info('Backfilling encrypted credentials for robot %s', robot_user.id) - has_row = True - try: - RobotAccountToken.create(robot_account=robot_user, - token=_decrypted(robot_user.email), - fully_migrated=False) - except IntegrityError: + for robot_user in query: + logger.info('Backfilling encrypted credentials for robot %s', robot_user.id) + has_row = True + try: + RobotAccountToken.create(robot_account=robot_user, + token=_decrypted(robot_user.email), + fully_migrated=False) + except IntegrityError: + break + + if not has_row: break - if not has_row: - break + # RepositoryBuildTrigger + logger.info('Backfilling encrypted credentials for repo build triggers') + for repo_build_trigger in _iterate(RepositoryBuildTrigger, + (RepositoryBuildTrigger.fully_migrated == False)): + logger.info('Backfilling encrypted credentials for repo build trigger %s', + repo_build_trigger.id) - # RepositoryBuildTrigger - logger.info('Backfilling encrypted credentials for repo build triggers') - for repo_build_trigger in _iterate(RepositoryBuildTrigger, - (RepositoryBuildTrigger.fully_migrated == False)): - logger.info('Backfilling encrypted credentials for repo build trigger %s', - repo_build_trigger.id) + (RepositoryBuildTrigger + .update(secure_auth_token=_decrypted(repo_build_trigger.auth_token), + secure_private_key=_decrypted(repo_build_trigger.private_key), + fully_migrated=True) + .where(RepositoryBuildTrigger.id == repo_build_trigger.id, + RepositoryBuildTrigger.uuid == repo_build_trigger.uuid) + .execute()) - (RepositoryBuildTrigger - .update(secure_auth_token=_decrypted(repo_build_trigger.auth_token), - secure_private_key=_decrypted(repo_build_trigger.private_key), - fully_migrated=True) - .where(RepositoryBuildTrigger.id == repo_build_trigger.id, - RepositoryBuildTrigger.uuid == repo_build_trigger.uuid) - .execute()) + assert (RepositoryBuildTrigger + .select() + .where(RepositoryBuildTrigger.fully_migrated == False) + .count()) == 0 - assert (RepositoryBuildTrigger - .select() - .where(RepositoryBuildTrigger.fully_migrated == False) - .count()) == 0 + # AppSpecificAuthToken + logger.info('Backfilling encrypted credentials for app specific auth tokens') + for token in _iterate(AppSpecificAuthToken, ((AppSpecificAuthToken.token_name >> None) | + (AppSpecificAuthToken.token_name == '') | + (AppSpecificAuthToken.token_secret >> None))): + logger.info('Backfilling encrypted credentials for app specific auth %s', + token.id) + assert token.token_code[AST_TOKEN_NAME_PREFIX_LENGTH:] - # AppSpecificAuthToken - logger.info('Backfilling encrypted credentials for app specific auth tokens') - for token in _iterate(AppSpecificAuthToken, ((AppSpecificAuthToken.token_name >> None) | - (AppSpecificAuthToken.token_name == '') | - (AppSpecificAuthToken.token_secret >> None))): - logger.info('Backfilling encrypted credentials for app specific auth %s', - token.id) - assert token.token_code[AST_TOKEN_NAME_PREFIX_LENGTH:] + token_name = token.token_code[:AST_TOKEN_NAME_PREFIX_LENGTH] + token_secret = _decrypted(token.token_code[AST_TOKEN_NAME_PREFIX_LENGTH:]) + assert token_name + assert token_secret - token_name = token.token_code[:AST_TOKEN_NAME_PREFIX_LENGTH] - token_secret = _decrypted(token.token_code[AST_TOKEN_NAME_PREFIX_LENGTH:]) - assert token_name - assert token_secret + (AppSpecificAuthToken + .update(token_name=token_name, + token_secret=token_secret) + .where(AppSpecificAuthToken.id == token.id, + AppSpecificAuthToken.token_code == token.token_code) + .execute()) - (AppSpecificAuthToken - .update(token_name=token_name, - token_secret=token_secret) - .where(AppSpecificAuthToken.id == token.id, - AppSpecificAuthToken.token_code == token.token_code) - .execute()) + assert (AppSpecificAuthToken + .select() + .where(AppSpecificAuthToken.token_name >> None) + .count()) == 0 - assert (AppSpecificAuthToken - .select() - .where(AppSpecificAuthToken.token_name >> None) - .count()) == 0 + # OAuthAccessToken + logger.info('Backfilling credentials for OAuth access tokens') + for token in _iterate(OAuthAccessToken, ((OAuthAccessToken.token_name >> None) | + (OAuthAccessToken.token_name == ''))): + logger.info('Backfilling credentials for OAuth access token %s', token.id) + token_name = token.access_token[:OAUTH_ACCESS_TOKEN_PREFIX_LENGTH] + token_code = Credential.from_string(token.access_token[OAUTH_ACCESS_TOKEN_PREFIX_LENGTH:]) + assert token_name + assert token.access_token[OAUTH_ACCESS_TOKEN_PREFIX_LENGTH:] - # OAuthAccessToken - logger.info('Backfilling credentials for OAuth access tokens') - for token in _iterate(OAuthAccessToken, ((OAuthAccessToken.token_name >> None) | - (OAuthAccessToken.token_name == ''))): - logger.info('Backfilling credentials for OAuth access token %s', token.id) - token_name = token.access_token[:OAUTH_ACCESS_TOKEN_PREFIX_LENGTH] - token_code = Credential.from_string(token.access_token[OAUTH_ACCESS_TOKEN_PREFIX_LENGTH:]) - assert token_name - assert token.access_token[OAUTH_ACCESS_TOKEN_PREFIX_LENGTH:] + (OAuthAccessToken + .update(token_name=token_name, + token_code=token_code) + .where(OAuthAccessToken.id == token.id, + OAuthAccessToken.access_token == token.access_token) + .execute()) - (OAuthAccessToken - .update(token_name=token_name, - token_code=token_code) - .where(OAuthAccessToken.id == token.id, - OAuthAccessToken.access_token == token.access_token) - .execute()) + assert (OAuthAccessToken + .select() + .where(OAuthAccessToken.token_name >> None) + .count()) == 0 - assert (OAuthAccessToken - .select() - .where(OAuthAccessToken.token_name >> None) - .count()) == 0 + # OAuthAuthorizationCode + logger.info('Backfilling credentials for OAuth auth code') + for code in _iterate(OAuthAuthorizationCode, ((OAuthAuthorizationCode.code_name >> None) | + (OAuthAuthorizationCode.code_name == ''))): + logger.info('Backfilling credentials for OAuth auth code %s', code.id) + user_code = code.code or random_string_generator(AUTHORIZATION_CODE_PREFIX_LENGTH * 2)() + code_name = user_code[:AUTHORIZATION_CODE_PREFIX_LENGTH] + code_credential = Credential.from_string(user_code[AUTHORIZATION_CODE_PREFIX_LENGTH:]) + assert code_name + assert user_code[AUTHORIZATION_CODE_PREFIX_LENGTH:] - # OAuthAuthorizationCode - logger.info('Backfilling credentials for OAuth auth code') - for code in _iterate(OAuthAuthorizationCode, ((OAuthAuthorizationCode.code_name >> None) | - (OAuthAuthorizationCode.code_name == ''))): - logger.info('Backfilling credentials for OAuth auth code %s', code.id) - user_code = code.code or random_string_generator(AUTHORIZATION_CODE_PREFIX_LENGTH * 2)() - code_name = user_code[:AUTHORIZATION_CODE_PREFIX_LENGTH] - code_credential = Credential.from_string(user_code[AUTHORIZATION_CODE_PREFIX_LENGTH:]) - assert code_name - assert user_code[AUTHORIZATION_CODE_PREFIX_LENGTH:] + (OAuthAuthorizationCode + .update(code_name=code_name, code_credential=code_credential) + .where(OAuthAuthorizationCode.id == code.id) + .execute()) - (OAuthAuthorizationCode - .update(code_name=code_name, code_credential=code_credential) - .where(OAuthAuthorizationCode.id == code.id) - .execute()) + assert (OAuthAuthorizationCode + .select() + .where(OAuthAuthorizationCode.code_name >> None) + .count()) == 0 - assert (OAuthAuthorizationCode - .select() - .where(OAuthAuthorizationCode.code_name >> None) - .count()) == 0 + # OAuthApplication + logger.info('Backfilling secret for OAuth applications') + for app in _iterate(OAuthApplication, OAuthApplication.fully_migrated == False): + logger.info('Backfilling secret for OAuth application %s', app.id) + client_secret = app.client_secret or str(uuid.uuid4()) + secure_client_secret = _decrypted(client_secret) - # OAuthApplication - logger.info('Backfilling secret for OAuth applications') - for app in _iterate(OAuthApplication, OAuthApplication.fully_migrated == False): - logger.info('Backfilling secret for OAuth application %s', app.id) - client_secret = app.client_secret or str(uuid.uuid4()) - secure_client_secret = _decrypted(client_secret) + (OAuthApplication + .update(secure_client_secret=secure_client_secret, fully_migrated=True) + .where(OAuthApplication.id == app.id, OAuthApplication.fully_migrated == False) + .execute()) - (OAuthApplication - .update(secure_client_secret=secure_client_secret, fully_migrated=True) - .where(OAuthApplication.id == app.id, OAuthApplication.fully_migrated == False) - .execute()) - - assert (OAuthApplication - .select() - .where(OAuthApplication.fully_migrated == False) - .count()) == 0 + assert (OAuthApplication + .select() + .where(OAuthApplication.fully_migrated == False) + .count()) == 0 # Adjust existing fields to be nullable. op.alter_column('accesstoken', 'code', nullable=True, existing_type=sa.String(length=255)) @@ -271,10 +273,6 @@ def upgrade(tables, tester, progress_reporter): def downgrade(tables, tester, progress_reporter): op = ProgressWrapper(original_op, progress_reporter) - op.alter_column('accesstoken', 'code', nullable=False, existing_type=sa.String(length=255)) - op.alter_column('oauthaccesstoken', 'access_token', nullable=False, existing_type=sa.String(length=255)) - op.alter_column('oauthauthorizationcode', 'code', nullable=False, existing_type=sa.String(length=255)) - op.alter_column('appspecificauthtoken', 'token_code', nullable=False, existing_type=sa.String(length=255)) op.alter_column('accesstoken', 'token_name', nullable=True, existing_type=sa.String(length=255)) op.alter_column('accesstoken', 'token_code', nullable=True, existing_type=sa.String(length=255)) diff --git a/data/migrations/versions/c059b952ed76_remove_unencrypted_fields_and_data.py b/data/migrations/versions/c059b952ed76_remove_unencrypted_fields_and_data.py index 15d1ac8b6..4854630bf 100644 --- a/data/migrations/versions/c059b952ed76_remove_unencrypted_fields_and_data.py +++ b/data/migrations/versions/c059b952ed76_remove_unencrypted_fields_and_data.py @@ -39,22 +39,24 @@ def upgrade(tables, tester, progress_reporter): # ### end Alembic commands ### # Overwrite all plaintext robot credentials. - while True: - try: - robot_account_token = RobotAccountToken.get(fully_migrated=False) - robot_account = robot_account_token.robot_account + from app import app + if app.config.get('SETUP_COMPLETE', False) or tester.is_testing: + while True: + try: + robot_account_token = RobotAccountToken.get(fully_migrated=False) + robot_account = robot_account_token.robot_account - robot_account.email = str(uuid.uuid4()) - robot_account.save() + robot_account.email = str(uuid.uuid4()) + robot_account.save() - federated_login = FederatedLogin.get(user=robot_account) - federated_login.service_ident = 'robot:%s' % robot_account.id - federated_login.save() + federated_login = FederatedLogin.get(user=robot_account) + federated_login.service_ident = 'robot:%s' % robot_account.id + federated_login.save() - robot_account_token.fully_migrated = True - robot_account_token.save() - except RobotAccountToken.DoesNotExist: - break + robot_account_token.fully_migrated = True + robot_account_token.save() + except RobotAccountToken.DoesNotExist: + break def downgrade(tables, tester, progress_reporter): diff --git a/data/migrationutil.py b/data/migrationutil.py index db34e1882..a433605f5 100644 --- a/data/migrationutil.py +++ b/data/migrationutil.py @@ -30,21 +30,35 @@ class NullDataMigration(DataMigration): class DefinedDataMigration(DataMigration): def __init__(self, name, env_var, phases): + assert phases + self.name = name self.phases = {phase.name: phase for phase in phases} + # Add a synthetic phase for new installations that skips the entire migration. + self.phases['new-installation'] = phases[-1]._replace(name='new-installation', + alembic_revision='head') + phase_name = os.getenv(env_var) if phase_name is None: - msg = 'Missing env var `%s` for data migration `%s`' % (env_var, self.name) + msg = 'Missing env var `%s` for data migration `%s`. %s' % (env_var, self.name, + self._error_suffix) raise Exception(msg) current_phase = self.phases.get(phase_name) if current_phase is None: - msg = 'Unknown phase `%s` for data migration `%s`' % (phase_name, self.name) + msg = 'Unknown phase `%s` for data migration `%s`. %s' % (phase_name, self.name, + self._error_suffix) raise Exception(msg) self.current_phase = current_phase + @property + def _error_suffix(self): + message = 'Available values for this migration: %s. ' % (self.phases.keys()) + message += 'If this is a new installation, please use `new-installation`.' + return message + @property def alembic_migration_revision(self): assert self.current_phase