Fix the encrypted token migration issue encountered on HEAD

This change ensures there is better messaging around the encrypted token migration, including a new phase to use for new installations, and fixes an issue encountered when running database migrations for new installations
This commit is contained in:
Joseph Schorr 2019-11-14 14:25:38 -05:00
parent a0f7c4f396
commit a54fb1b23a
7 changed files with 194 additions and 163 deletions

View file

@ -378,7 +378,7 @@ def configure(config_object, testing=False):
real_for_update)) real_for_update))
db_concat_func.initialize(SCHEME_SPECIALIZED_CONCAT.get(parsed_write_uri.drivername, db_concat_func.initialize(SCHEME_SPECIALIZED_CONCAT.get(parsed_write_uri.drivername,
function_concat)) function_concat))
db_encrypter.initialize(FieldEncrypter(config_object['DATABASE_SECRET_KEY'])) db_encrypter.initialize(FieldEncrypter(config_object.get('DATABASE_SECRET_KEY')))
read_replicas = config_object.get('DB_READ_REPLICAS', None) read_replicas = config_object.get('DB_READ_REPLICAS', None)
is_read_only = config_object.get('REGISTRY_STATE', 'normal') == 'readonly' is_read_only = config_object.get('REGISTRY_STATE', 'normal') == 'readonly'

View file

@ -59,11 +59,15 @@ class FieldEncrypter(object):
and the application. and the application.
""" """
def __init__(self, secret_key, version='v0'): def __init__(self, secret_key, version='v0'):
self._secret_key = convert_secret_key(secret_key) # NOTE: secret_key will be None when the system is being first initialized, so we allow that
# case here, but make sure to assert that it is *not* None below if any encryption is actually
# needed.
self._secret_key = convert_secret_key(secret_key) if secret_key is not None else None
self._encryption_version = _VERSIONS[version] self._encryption_version = _VERSIONS[version]
def encrypt_value(self, value, field_max_length=None): def encrypt_value(self, value, field_max_length=None):
""" Encrypts the value using the current version of encryption. """ """ Encrypts the value using the current version of encryption. """
assert self._secret_key is not None
encrypted_value = self._encryption_version.encrypt(self._secret_key, value, field_max_length) encrypted_value = self._encryption_version.encrypt(self._secret_key, value, field_max_length)
return '%s%s%s' % (self._encryption_version.prefix, _SEPARATOR, encrypted_value) return '%s%s%s' % (self._encryption_version.prefix, _SEPARATOR, encrypted_value)
@ -71,6 +75,7 @@ class FieldEncrypter(object):
""" Decrypts the value, returning it. If the value cannot be decrypted """ Decrypts the value, returning it. If the value cannot be decrypted
raises a DecryptionFailureException. raises a DecryptionFailureException.
""" """
assert self._secret_key is not None
if _SEPARATOR not in value: if _SEPARATOR not in value:
raise DecryptionFailureException('Invalid encrypted value') raise DecryptionFailureException('Invalid encrypted value')

View file

@ -2,7 +2,7 @@ import json
import logging import logging
import uuid import uuid
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod, abstractproperty
from datetime import datetime from datetime import datetime
from six import add_metaclass from six import add_metaclass
@ -92,6 +92,10 @@ class MigrationTester(object):
""" """
TestDataType = DataTypes TestDataType = DataTypes
@abstractproperty
def is_testing(self):
""" Returns whether we are currently under a migration test. """
@abstractmethod @abstractmethod
def populate_table(self, table_name, fields): def populate_table(self, table_name, fields):
""" Called to populate a table with the given fields filled in with testing data. """ """ Called to populate a table with the given fields filled in with testing data. """
@ -107,6 +111,10 @@ class NoopTester(MigrationTester):
class PopulateTestDataTester(MigrationTester): class PopulateTestDataTester(MigrationTester):
@property
def is_testing(self):
return True
def populate_table(self, table_name, fields): def populate_table(self, table_name, fields):
columns = {field_name: field_type() for field_name, field_type in fields} columns = {field_name: field_type() for field_name, field_type in fields}
field_name_vars = [':' + field_name for field_name, _ in fields] field_name_vars = [':' + field_name for field_name, _ in fields]

View file

@ -80,6 +80,8 @@ def upgrade(tables, tester, progress_reporter):
op.add_column('repomirrorconfig', sa.Column('external_reference', sa.Text(), nullable=True)) op.add_column('repomirrorconfig', sa.Column('external_reference', sa.Text(), nullable=True))
from app import app
if app.config.get('SETUP_COMPLETE', False) or tester.is_testing:
for repo_mirror in _iterate(RepoMirrorConfig, (RepoMirrorConfig.external_reference >> None)): for repo_mirror in _iterate(RepoMirrorConfig, (RepoMirrorConfig.external_reference >> None)):
repo = '%s/%s/%s' % (repo_mirror.external_registry, repo_mirror.external_namespace, repo_mirror.external_repository) repo = '%s/%s/%s' % (repo_mirror.external_registry, repo_mirror.external_namespace, repo_mirror.external_repository)
logger.info('migrating %s' % repo) logger.info('migrating %s' % repo)
@ -109,6 +111,8 @@ def downgrade(tables, tester, progress_reporter):
op.add_column('repomirrorconfig', sa.Column('external_namespace', sa.String(length=255), nullable=True)) op.add_column('repomirrorconfig', sa.Column('external_namespace', sa.String(length=255), nullable=True))
op.add_column('repomirrorconfig', sa.Column('external_repository', sa.String(length=255), nullable=True)) op.add_column('repomirrorconfig', sa.Column('external_repository', sa.String(length=255), nullable=True))
from app import app
if app.config.get('SETUP_COMPLETE', False):
logger.info('Restoring columns from external_reference') logger.info('Restoring columns from external_reference')
for repo_mirror in _iterate(RepoMirrorConfig, (RepoMirrorConfig.external_registry >> None)): for repo_mirror in _iterate(RepoMirrorConfig, (RepoMirrorConfig.external_registry >> None)):
logger.info('Restoring %s' % repo_mirror.external_reference) logger.info('Restoring %s' % repo_mirror.external_reference)

View file

@ -98,6 +98,8 @@ class OAuthApplication(BaseModel):
def upgrade(tables, tester, progress_reporter): def upgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter) op = ProgressWrapper(original_op, progress_reporter)
from app import app
if app.config.get('SETUP_COMPLETE', False) or tester.is_testing:
# Empty all access token names to fix the bug where we put the wrong name and code # Empty all access token names to fix the bug where we put the wrong name and code
# in for some tokens. # in for some tokens.
AccessToken.update(token_name=None).where(AccessToken.token_name >> None).execute() AccessToken.update(token_name=None).where(AccessToken.token_name >> None).execute()
@ -271,10 +273,6 @@ def upgrade(tables, tester, progress_reporter):
def downgrade(tables, tester, progress_reporter): def downgrade(tables, tester, progress_reporter):
op = ProgressWrapper(original_op, progress_reporter) op = ProgressWrapper(original_op, progress_reporter)
op.alter_column('accesstoken', 'code', nullable=False, existing_type=sa.String(length=255))
op.alter_column('oauthaccesstoken', 'access_token', nullable=False, existing_type=sa.String(length=255))
op.alter_column('oauthauthorizationcode', 'code', nullable=False, existing_type=sa.String(length=255))
op.alter_column('appspecificauthtoken', 'token_code', nullable=False, existing_type=sa.String(length=255))
op.alter_column('accesstoken', 'token_name', nullable=True, existing_type=sa.String(length=255)) op.alter_column('accesstoken', 'token_name', nullable=True, existing_type=sa.String(length=255))
op.alter_column('accesstoken', 'token_code', nullable=True, existing_type=sa.String(length=255)) op.alter_column('accesstoken', 'token_code', nullable=True, existing_type=sa.String(length=255))

View file

@ -39,6 +39,8 @@ def upgrade(tables, tester, progress_reporter):
# ### end Alembic commands ### # ### end Alembic commands ###
# Overwrite all plaintext robot credentials. # Overwrite all plaintext robot credentials.
from app import app
if app.config.get('SETUP_COMPLETE', False) or tester.is_testing:
while True: while True:
try: try:
robot_account_token = RobotAccountToken.get(fully_migrated=False) robot_account_token = RobotAccountToken.get(fully_migrated=False)

View file

@ -30,21 +30,35 @@ class NullDataMigration(DataMigration):
class DefinedDataMigration(DataMigration): class DefinedDataMigration(DataMigration):
def __init__(self, name, env_var, phases): def __init__(self, name, env_var, phases):
assert phases
self.name = name self.name = name
self.phases = {phase.name: phase for phase in phases} self.phases = {phase.name: phase for phase in phases}
# Add a synthetic phase for new installations that skips the entire migration.
self.phases['new-installation'] = phases[-1]._replace(name='new-installation',
alembic_revision='head')
phase_name = os.getenv(env_var) phase_name = os.getenv(env_var)
if phase_name is None: if phase_name is None:
msg = 'Missing env var `%s` for data migration `%s`' % (env_var, self.name) msg = 'Missing env var `%s` for data migration `%s`. %s' % (env_var, self.name,
self._error_suffix)
raise Exception(msg) raise Exception(msg)
current_phase = self.phases.get(phase_name) current_phase = self.phases.get(phase_name)
if current_phase is None: if current_phase is None:
msg = 'Unknown phase `%s` for data migration `%s`' % (phase_name, self.name) msg = 'Unknown phase `%s` for data migration `%s`. %s' % (phase_name, self.name,
self._error_suffix)
raise Exception(msg) raise Exception(msg)
self.current_phase = current_phase self.current_phase = current_phase
@property
def _error_suffix(self):
message = 'Available values for this migration: %s. ' % (self.phases.keys())
message += 'If this is a new installation, please use `new-installation`.'
return message
@property @property
def alembic_migration_revision(self): def alembic_migration_revision(self):
assert self.current_phase assert self.current_phase