Fix the encrypted token migration issue encountered on HEAD

This change improves the messaging around the encrypted token migration, adds a new phase to use for new installations, and fixes an issue encountered when running database migrations on new installations.
This commit is contained in:
Joseph Schorr 2019-11-14 14:25:38 -05:00
parent a0f7c4f396
commit a54fb1b23a
7 changed files with 194 additions and 163 deletions

View file

@ -378,7 +378,7 @@ def configure(config_object, testing=False):
real_for_update)) real_for_update))
db_concat_func.initialize(SCHEME_SPECIALIZED_CONCAT.get(parsed_write_uri.drivername, db_concat_func.initialize(SCHEME_SPECIALIZED_CONCAT.get(parsed_write_uri.drivername,
function_concat)) function_concat))
db_encrypter.initialize(FieldEncrypter(config_object['DATABASE_SECRET_KEY'])) db_encrypter.initialize(FieldEncrypter(config_object.get('DATABASE_SECRET_KEY')))
read_replicas = config_object.get('DB_READ_REPLICAS', None) read_replicas = config_object.get('DB_READ_REPLICAS', None)
is_read_only = config_object.get('REGISTRY_STATE', 'normal') == 'readonly' is_read_only = config_object.get('REGISTRY_STATE', 'normal') == 'readonly'

View file

@ -59,11 +59,15 @@ class FieldEncrypter(object):
and the application. and the application.
""" """
def __init__(self, secret_key, version='v0'): def __init__(self, secret_key, version='v0'):
self._secret_key = convert_secret_key(secret_key) # NOTE: secret_key will be None when the system is being first initialized, so we allow that
# case here, but make sure to assert that it is *not* None below if any encryption is actually
# needed.
self._secret_key = convert_secret_key(secret_key) if secret_key is not None else None
self._encryption_version = _VERSIONS[version] self._encryption_version = _VERSIONS[version]
def encrypt_value(self, value, field_max_length=None): def encrypt_value(self, value, field_max_length=None):
""" Encrypts the value using the current version of encryption. """ """ Encrypts the value using the current version of encryption. """
assert self._secret_key is not None
encrypted_value = self._encryption_version.encrypt(self._secret_key, value, field_max_length) encrypted_value = self._encryption_version.encrypt(self._secret_key, value, field_max_length)
return '%s%s%s' % (self._encryption_version.prefix, _SEPARATOR, encrypted_value) return '%s%s%s' % (self._encryption_version.prefix, _SEPARATOR, encrypted_value)
@ -71,6 +75,7 @@ class FieldEncrypter(object):
""" Decrypts the value, returning it. If the value cannot be decrypted """ Decrypts the value, returning it. If the value cannot be decrypted
raises a DecryptionFailureException. raises a DecryptionFailureException.
""" """
assert self._secret_key is not None
if _SEPARATOR not in value: if _SEPARATOR not in value:
raise DecryptionFailureException('Invalid encrypted value') raise DecryptionFailureException('Invalid encrypted value')

View file

@ -2,7 +2,7 @@ import json
import logging import logging
import uuid import uuid
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod, abstractproperty
from datetime import datetime from datetime import datetime
from six import add_metaclass from six import add_metaclass
@ -92,6 +92,10 @@ class MigrationTester(object):
""" """
TestDataType = DataTypes TestDataType = DataTypes
@abstractproperty
def is_testing(self):
""" Returns whether we are currently under a migration test. """
@abstractmethod @abstractmethod
def populate_table(self, table_name, fields): def populate_table(self, table_name, fields):
""" Called to populate a table with the given fields filled in with testing data. """ """ Called to populate a table with the given fields filled in with testing data. """
@ -107,6 +111,10 @@ class NoopTester(MigrationTester):
class PopulateTestDataTester(MigrationTester): class PopulateTestDataTester(MigrationTester):
@property
def is_testing(self):
return True
def populate_table(self, table_name, fields): def populate_table(self, table_name, fields):
columns = {field_name: field_type() for field_name, field_type in fields} columns = {field_name: field_type() for field_name, field_type in fields}
field_name_vars = [':' + field_name for field_name, _ in fields] field_name_vars = [':' + field_name for field_name, _ in fields]

View file

@ -80,11 +80,13 @@ def upgrade(tables, tester, progress_reporter):
op.add_column('repomirrorconfig', sa.Column('external_reference', sa.Text(), nullable=True)) op.add_column('repomirrorconfig', sa.Column('external_reference', sa.Text(), nullable=True))
for repo_mirror in _iterate(RepoMirrorConfig, (RepoMirrorConfig.external_reference >> None)): from app import app
repo = '%s/%s/%s' % (repo_mirror.external_registry, repo_mirror.external_namespace, repo_mirror.external_repository) if app.config.get('SETUP_COMPLETE', False) or tester.is_testing:
logger.info('migrating %s' % repo) for repo_mirror in _iterate(RepoMirrorConfig, (RepoMirrorConfig.external_reference >> None)):
repo_mirror.external_reference = repo repo = '%s/%s/%s' % (repo_mirror.external_registry, repo_mirror.external_namespace, repo_mirror.external_repository)
repo_mirror.save() logger.info('migrating %s' % repo)
repo_mirror.external_reference = repo
repo_mirror.save()
op.drop_column('repomirrorconfig', 'external_registry') op.drop_column('repomirrorconfig', 'external_registry')
op.drop_column('repomirrorconfig', 'external_namespace') op.drop_column('repomirrorconfig', 'external_namespace')
@ -109,14 +111,16 @@ def downgrade(tables, tester, progress_reporter):
op.add_column('repomirrorconfig', sa.Column('external_namespace', sa.String(length=255), nullable=True)) op.add_column('repomirrorconfig', sa.Column('external_namespace', sa.String(length=255), nullable=True))
op.add_column('repomirrorconfig', sa.Column('external_repository', sa.String(length=255), nullable=True)) op.add_column('repomirrorconfig', sa.Column('external_repository', sa.String(length=255), nullable=True))
logger.info('Restoring columns from external_reference') from app import app
for repo_mirror in _iterate(RepoMirrorConfig, (RepoMirrorConfig.external_registry >> None)): if app.config.get('SETUP_COMPLETE', False):
logger.info('Restoring %s' % repo_mirror.external_reference) logger.info('Restoring columns from external_reference')
parts = repo_mirror.external_reference.split('/', 2) for repo_mirror in _iterate(RepoMirrorConfig, (RepoMirrorConfig.external_registry >> None)):
repo_mirror.external_registry = parts[0] if len(parts) >= 1 else 'DOWNGRADE-FAILED' logger.info('Restoring %s' % repo_mirror.external_reference)
repo_mirror.external_namespace = parts[1] if len(parts) >= 2 else 'DOWNGRADE-FAILED' parts = repo_mirror.external_reference.split('/', 2)
repo_mirror.external_repository = parts[2] if len(parts) >= 3 else 'DOWNGRADE-FAILED' repo_mirror.external_registry = parts[0] if len(parts) >= 1 else 'DOWNGRADE-FAILED'
repo_mirror.save() repo_mirror.external_namespace = parts[1] if len(parts) >= 2 else 'DOWNGRADE-FAILED'
repo_mirror.external_repository = parts[2] if len(parts) >= 3 else 'DOWNGRADE-FAILED'
repo_mirror.save()
op.drop_column('repomirrorconfig', 'external_reference') op.drop_column('repomirrorconfig', 'external_reference')

View file

@ -98,157 +98,159 @@ class OAuthApplication(BaseModel):
def upgrade(tables, tester, progress_reporter): def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter) op = ProgressWrapper(original_op, progress_reporter)
# Empty all access token names to fix the bug where we put the wrong name and code from app import app
# in for some tokens. if app.config.get('SETUP_COMPLETE', False) or tester.is_testing:
AccessToken.update(token_name=None).where(AccessToken.token_name >> None).execute() # Empty all access token names to fix the bug where we put the wrong name and code
# in for some tokens.
AccessToken.update(token_name=None).where(AccessToken.token_name >> None).execute()
# AccessToken. # AccessToken.
logger.info('Backfilling encrypted credentials for access tokens') logger.info('Backfilling encrypted credentials for access tokens')
for access_token in _iterate(AccessToken, ((AccessToken.token_name >> None) | for access_token in _iterate(AccessToken, ((AccessToken.token_name >> None) |
(AccessToken.token_name == ''))): (AccessToken.token_name == ''))):
logger.info('Backfilling encrypted credentials for access token %s', access_token.id) logger.info('Backfilling encrypted credentials for access token %s', access_token.id)
assert access_token.code is not None assert access_token.code is not None
assert access_token.code[:ACCESS_TOKEN_NAME_PREFIX_LENGTH] assert access_token.code[:ACCESS_TOKEN_NAME_PREFIX_LENGTH]
assert access_token.code[ACCESS_TOKEN_NAME_PREFIX_LENGTH:] assert access_token.code[ACCESS_TOKEN_NAME_PREFIX_LENGTH:]
token_name = access_token.code[:ACCESS_TOKEN_NAME_PREFIX_LENGTH] token_name = access_token.code[:ACCESS_TOKEN_NAME_PREFIX_LENGTH]
token_code = _decrypted(access_token.code[ACCESS_TOKEN_NAME_PREFIX_LENGTH:]) token_code = _decrypted(access_token.code[ACCESS_TOKEN_NAME_PREFIX_LENGTH:])
(AccessToken (AccessToken
.update(token_name=token_name, token_code=token_code) .update(token_name=token_name, token_code=token_code)
.where(AccessToken.id == access_token.id, AccessToken.code == access_token.code) .where(AccessToken.id == access_token.id, AccessToken.code == access_token.code)
.execute()) .execute())
assert AccessToken.select().where(AccessToken.token_name >> None).count() == 0 assert AccessToken.select().where(AccessToken.token_name >> None).count() == 0
# Robots. # Robots.
logger.info('Backfilling encrypted credentials for robots') logger.info('Backfilling encrypted credentials for robots')
while True: while True:
has_row = False has_row = False
query = (User query = (User
.select() .select()
.join(RobotAccountToken, JOIN.LEFT_OUTER) .join(RobotAccountToken, JOIN.LEFT_OUTER)
.where(User.robot == True, RobotAccountToken.id >> None) .where(User.robot == True, RobotAccountToken.id >> None)
.limit(BATCH_SIZE)) .limit(BATCH_SIZE))
for robot_user in query: for robot_user in query:
logger.info('Backfilling encrypted credentials for robot %s', robot_user.id) logger.info('Backfilling encrypted credentials for robot %s', robot_user.id)
has_row = True has_row = True
try: try:
RobotAccountToken.create(robot_account=robot_user, RobotAccountToken.create(robot_account=robot_user,
token=_decrypted(robot_user.email), token=_decrypted(robot_user.email),
fully_migrated=False) fully_migrated=False)
except IntegrityError: except IntegrityError:
break
if not has_row:
break break
if not has_row: # RepositoryBuildTrigger
break logger.info('Backfilling encrypted credentials for repo build triggers')
for repo_build_trigger in _iterate(RepositoryBuildTrigger,
(RepositoryBuildTrigger.fully_migrated == False)):
logger.info('Backfilling encrypted credentials for repo build trigger %s',
repo_build_trigger.id)
# RepositoryBuildTrigger (RepositoryBuildTrigger
logger.info('Backfilling encrypted credentials for repo build triggers') .update(secure_auth_token=_decrypted(repo_build_trigger.auth_token),
for repo_build_trigger in _iterate(RepositoryBuildTrigger, secure_private_key=_decrypted(repo_build_trigger.private_key),
(RepositoryBuildTrigger.fully_migrated == False)): fully_migrated=True)
logger.info('Backfilling encrypted credentials for repo build trigger %s', .where(RepositoryBuildTrigger.id == repo_build_trigger.id,
repo_build_trigger.id) RepositoryBuildTrigger.uuid == repo_build_trigger.uuid)
.execute())
(RepositoryBuildTrigger assert (RepositoryBuildTrigger
.update(secure_auth_token=_decrypted(repo_build_trigger.auth_token), .select()
secure_private_key=_decrypted(repo_build_trigger.private_key), .where(RepositoryBuildTrigger.fully_migrated == False)
fully_migrated=True) .count()) == 0
.where(RepositoryBuildTrigger.id == repo_build_trigger.id,
RepositoryBuildTrigger.uuid == repo_build_trigger.uuid)
.execute())
assert (RepositoryBuildTrigger # AppSpecificAuthToken
.select() logger.info('Backfilling encrypted credentials for app specific auth tokens')
.where(RepositoryBuildTrigger.fully_migrated == False) for token in _iterate(AppSpecificAuthToken, ((AppSpecificAuthToken.token_name >> None) |
.count()) == 0 (AppSpecificAuthToken.token_name == '') |
(AppSpecificAuthToken.token_secret >> None))):
logger.info('Backfilling encrypted credentials for app specific auth %s',
token.id)
assert token.token_code[AST_TOKEN_NAME_PREFIX_LENGTH:]
# AppSpecificAuthToken token_name = token.token_code[:AST_TOKEN_NAME_PREFIX_LENGTH]
logger.info('Backfilling encrypted credentials for app specific auth tokens') token_secret = _decrypted(token.token_code[AST_TOKEN_NAME_PREFIX_LENGTH:])
for token in _iterate(AppSpecificAuthToken, ((AppSpecificAuthToken.token_name >> None) | assert token_name
(AppSpecificAuthToken.token_name == '') | assert token_secret
(AppSpecificAuthToken.token_secret >> None))):
logger.info('Backfilling encrypted credentials for app specific auth %s',
token.id)
assert token.token_code[AST_TOKEN_NAME_PREFIX_LENGTH:]
token_name = token.token_code[:AST_TOKEN_NAME_PREFIX_LENGTH] (AppSpecificAuthToken
token_secret = _decrypted(token.token_code[AST_TOKEN_NAME_PREFIX_LENGTH:]) .update(token_name=token_name,
assert token_name token_secret=token_secret)
assert token_secret .where(AppSpecificAuthToken.id == token.id,
AppSpecificAuthToken.token_code == token.token_code)
.execute())
(AppSpecificAuthToken assert (AppSpecificAuthToken
.update(token_name=token_name, .select()
token_secret=token_secret) .where(AppSpecificAuthToken.token_name >> None)
.where(AppSpecificAuthToken.id == token.id, .count()) == 0
AppSpecificAuthToken.token_code == token.token_code)
.execute())
assert (AppSpecificAuthToken # OAuthAccessToken
.select() logger.info('Backfilling credentials for OAuth access tokens')
.where(AppSpecificAuthToken.token_name >> None) for token in _iterate(OAuthAccessToken, ((OAuthAccessToken.token_name >> None) |
.count()) == 0 (OAuthAccessToken.token_name == ''))):
logger.info('Backfilling credentials for OAuth access token %s', token.id)
token_name = token.access_token[:OAUTH_ACCESS_TOKEN_PREFIX_LENGTH]
token_code = Credential.from_string(token.access_token[OAUTH_ACCESS_TOKEN_PREFIX_LENGTH:])
assert token_name
assert token.access_token[OAUTH_ACCESS_TOKEN_PREFIX_LENGTH:]
# OAuthAccessToken (OAuthAccessToken
logger.info('Backfilling credentials for OAuth access tokens') .update(token_name=token_name,
for token in _iterate(OAuthAccessToken, ((OAuthAccessToken.token_name >> None) | token_code=token_code)
(OAuthAccessToken.token_name == ''))): .where(OAuthAccessToken.id == token.id,
logger.info('Backfilling credentials for OAuth access token %s', token.id) OAuthAccessToken.access_token == token.access_token)
token_name = token.access_token[:OAUTH_ACCESS_TOKEN_PREFIX_LENGTH] .execute())
token_code = Credential.from_string(token.access_token[OAUTH_ACCESS_TOKEN_PREFIX_LENGTH:])
assert token_name
assert token.access_token[OAUTH_ACCESS_TOKEN_PREFIX_LENGTH:]
(OAuthAccessToken assert (OAuthAccessToken
.update(token_name=token_name, .select()
token_code=token_code) .where(OAuthAccessToken.token_name >> None)
.where(OAuthAccessToken.id == token.id, .count()) == 0
OAuthAccessToken.access_token == token.access_token)
.execute())
assert (OAuthAccessToken # OAuthAuthorizationCode
.select() logger.info('Backfilling credentials for OAuth auth code')
.where(OAuthAccessToken.token_name >> None) for code in _iterate(OAuthAuthorizationCode, ((OAuthAuthorizationCode.code_name >> None) |
.count()) == 0 (OAuthAuthorizationCode.code_name == ''))):
logger.info('Backfilling credentials for OAuth auth code %s', code.id)
user_code = code.code or random_string_generator(AUTHORIZATION_CODE_PREFIX_LENGTH * 2)()
code_name = user_code[:AUTHORIZATION_CODE_PREFIX_LENGTH]
code_credential = Credential.from_string(user_code[AUTHORIZATION_CODE_PREFIX_LENGTH:])
assert code_name
assert user_code[AUTHORIZATION_CODE_PREFIX_LENGTH:]
# OAuthAuthorizationCode (OAuthAuthorizationCode
logger.info('Backfilling credentials for OAuth auth code') .update(code_name=code_name, code_credential=code_credential)
for code in _iterate(OAuthAuthorizationCode, ((OAuthAuthorizationCode.code_name >> None) | .where(OAuthAuthorizationCode.id == code.id)
(OAuthAuthorizationCode.code_name == ''))): .execute())
logger.info('Backfilling credentials for OAuth auth code %s', code.id)
user_code = code.code or random_string_generator(AUTHORIZATION_CODE_PREFIX_LENGTH * 2)()
code_name = user_code[:AUTHORIZATION_CODE_PREFIX_LENGTH]
code_credential = Credential.from_string(user_code[AUTHORIZATION_CODE_PREFIX_LENGTH:])
assert code_name
assert user_code[AUTHORIZATION_CODE_PREFIX_LENGTH:]
(OAuthAuthorizationCode assert (OAuthAuthorizationCode
.update(code_name=code_name, code_credential=code_credential) .select()
.where(OAuthAuthorizationCode.id == code.id) .where(OAuthAuthorizationCode.code_name >> None)
.execute()) .count()) == 0
assert (OAuthAuthorizationCode # OAuthApplication
.select() logger.info('Backfilling secret for OAuth applications')
.where(OAuthAuthorizationCode.code_name >> None) for app in _iterate(OAuthApplication, OAuthApplication.fully_migrated == False):
.count()) == 0 logger.info('Backfilling secret for OAuth application %s', app.id)
client_secret = app.client_secret or str(uuid.uuid4())
secure_client_secret = _decrypted(client_secret)
# OAuthApplication (OAuthApplication
logger.info('Backfilling secret for OAuth applications') .update(secure_client_secret=secure_client_secret, fully_migrated=True)
for app in _iterate(OAuthApplication, OAuthApplication.fully_migrated == False): .where(OAuthApplication.id == app.id, OAuthApplication.fully_migrated == False)
logger.info('Backfilling secret for OAuth application %s', app.id) .execute())
client_secret = app.client_secret or str(uuid.uuid4())
secure_client_secret = _decrypted(client_secret)
(OAuthApplication assert (OAuthApplication
.update(secure_client_secret=secure_client_secret, fully_migrated=True) .select()
.where(OAuthApplication.id == app.id, OAuthApplication.fully_migrated == False) .where(OAuthApplication.fully_migrated == False)
.execute()) .count()) == 0
assert (OAuthApplication
.select()
.where(OAuthApplication.fully_migrated == False)
.count()) == 0
# Adjust existing fields to be nullable. # Adjust existing fields to be nullable.
op.alter_column('accesstoken', 'code', nullable=True, existing_type=sa.String(length=255)) op.alter_column('accesstoken', 'code', nullable=True, existing_type=sa.String(length=255))
@ -271,10 +273,6 @@ def upgrade(tables, tester, progress_reporter):
def downgrade(tables, tester, progress_reporter): def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter) op = ProgressWrapper(original_op, progress_reporter)
op.alter_column('accesstoken', 'code', nullable=False, existing_type=sa.String(length=255))
op.alter_column('oauthaccesstoken', 'access_token', nullable=False, existing_type=sa.String(length=255))
op.alter_column('oauthauthorizationcode', 'code', nullable=False, existing_type=sa.String(length=255))
op.alter_column('appspecificauthtoken', 'token_code', nullable=False, existing_type=sa.String(length=255))
op.alter_column('accesstoken', 'token_name', nullable=True, existing_type=sa.String(length=255)) op.alter_column('accesstoken', 'token_name', nullable=True, existing_type=sa.String(length=255))
op.alter_column('accesstoken', 'token_code', nullable=True, existing_type=sa.String(length=255)) op.alter_column('accesstoken', 'token_code', nullable=True, existing_type=sa.String(length=255))

View file

@ -39,22 +39,24 @@ def upgrade(tables, tester, progress_reporter):
# ### end Alembic commands ### # ### end Alembic commands ###
# Overwrite all plaintext robot credentials. # Overwrite all plaintext robot credentials.
while True: from app import app
try: if app.config.get('SETUP_COMPLETE', False) or tester.is_testing:
robot_account_token = RobotAccountToken.get(fully_migrated=False) while True:
robot_account = robot_account_token.robot_account try:
robot_account_token = RobotAccountToken.get(fully_migrated=False)
robot_account = robot_account_token.robot_account
robot_account.email = str(uuid.uuid4()) robot_account.email = str(uuid.uuid4())
robot_account.save() robot_account.save()
federated_login = FederatedLogin.get(user=robot_account) federated_login = FederatedLogin.get(user=robot_account)
federated_login.service_ident = 'robot:%s' % robot_account.id federated_login.service_ident = 'robot:%s' % robot_account.id
federated_login.save() federated_login.save()
robot_account_token.fully_migrated = True robot_account_token.fully_migrated = True
robot_account_token.save() robot_account_token.save()
except RobotAccountToken.DoesNotExist: except RobotAccountToken.DoesNotExist:
break break
def downgrade(tables, tester, progress_reporter): def downgrade(tables, tester, progress_reporter):

View file

@ -30,21 +30,35 @@ class NullDataMigration(DataMigration):
class DefinedDataMigration(DataMigration): class DefinedDataMigration(DataMigration):
def __init__(self, name, env_var, phases): def __init__(self, name, env_var, phases):
assert phases
self.name = name self.name = name
self.phases = {phase.name: phase for phase in phases} self.phases = {phase.name: phase for phase in phases}
# Add a synthetic phase for new installations that skips the entire migration.
self.phases['new-installation'] = phases[-1]._replace(name='new-installation',
alembic_revision='head')
phase_name = os.getenv(env_var) phase_name = os.getenv(env_var)
if phase_name is None: if phase_name is None:
msg = 'Missing env var `%s` for data migration `%s`' % (env_var, self.name) msg = 'Missing env var `%s` for data migration `%s`. %s' % (env_var, self.name,
self._error_suffix)
raise Exception(msg) raise Exception(msg)
current_phase = self.phases.get(phase_name) current_phase = self.phases.get(phase_name)
if current_phase is None: if current_phase is None:
msg = 'Unknown phase `%s` for data migration `%s`' % (phase_name, self.name) msg = 'Unknown phase `%s` for data migration `%s`. %s' % (phase_name, self.name,
self._error_suffix)
raise Exception(msg) raise Exception(msg)
self.current_phase = current_phase self.current_phase = current_phase
@property
def _error_suffix(self):
message = 'Available values for this migration: %s. ' % (self.phases.keys())
message += 'If this is a new installation, please use `new-installation`.'
return message
@property @property
def alembic_migration_revision(self): def alembic_migration_revision(self):
assert self.current_phase assert self.current_phase