Fix the encrypted token migration issue encountered on HEAD
This change ensures there is better messaging around the encrypted token migration, including a new phase to use for new installations, and fixes an issue encountered when running database migrations for new installations
This commit is contained in:
parent
a0f7c4f396
commit
a54fb1b23a
7 changed files with 194 additions and 163 deletions
|
@ -378,7 +378,7 @@ def configure(config_object, testing=False):
|
||||||
real_for_update))
|
real_for_update))
|
||||||
db_concat_func.initialize(SCHEME_SPECIALIZED_CONCAT.get(parsed_write_uri.drivername,
|
db_concat_func.initialize(SCHEME_SPECIALIZED_CONCAT.get(parsed_write_uri.drivername,
|
||||||
function_concat))
|
function_concat))
|
||||||
db_encrypter.initialize(FieldEncrypter(config_object['DATABASE_SECRET_KEY']))
|
db_encrypter.initialize(FieldEncrypter(config_object.get('DATABASE_SECRET_KEY')))
|
||||||
|
|
||||||
read_replicas = config_object.get('DB_READ_REPLICAS', None)
|
read_replicas = config_object.get('DB_READ_REPLICAS', None)
|
||||||
is_read_only = config_object.get('REGISTRY_STATE', 'normal') == 'readonly'
|
is_read_only = config_object.get('REGISTRY_STATE', 'normal') == 'readonly'
|
||||||
|
|
|
@ -59,11 +59,15 @@ class FieldEncrypter(object):
|
||||||
and the application.
|
and the application.
|
||||||
"""
|
"""
|
||||||
def __init__(self, secret_key, version='v0'):
|
def __init__(self, secret_key, version='v0'):
|
||||||
self._secret_key = convert_secret_key(secret_key)
|
# NOTE: secret_key will be None when the system is being first initialized, so we allow that
|
||||||
|
# case here, but make sure to assert that it is *not* None below if any encryption is actually
|
||||||
|
# needed.
|
||||||
|
self._secret_key = convert_secret_key(secret_key) if secret_key is not None else None
|
||||||
self._encryption_version = _VERSIONS[version]
|
self._encryption_version = _VERSIONS[version]
|
||||||
|
|
||||||
def encrypt_value(self, value, field_max_length=None):
|
def encrypt_value(self, value, field_max_length=None):
|
||||||
""" Encrypts the value using the current version of encryption. """
|
""" Encrypts the value using the current version of encryption. """
|
||||||
|
assert self._secret_key is not None
|
||||||
encrypted_value = self._encryption_version.encrypt(self._secret_key, value, field_max_length)
|
encrypted_value = self._encryption_version.encrypt(self._secret_key, value, field_max_length)
|
||||||
return '%s%s%s' % (self._encryption_version.prefix, _SEPARATOR, encrypted_value)
|
return '%s%s%s' % (self._encryption_version.prefix, _SEPARATOR, encrypted_value)
|
||||||
|
|
||||||
|
@ -71,6 +75,7 @@ class FieldEncrypter(object):
|
||||||
""" Decrypts the value, returning it. If the value cannot be decrypted
|
""" Decrypts the value, returning it. If the value cannot be decrypted
|
||||||
raises a DecryptionFailureException.
|
raises a DecryptionFailureException.
|
||||||
"""
|
"""
|
||||||
|
assert self._secret_key is not None
|
||||||
if _SEPARATOR not in value:
|
if _SEPARATOR not in value:
|
||||||
raise DecryptionFailureException('Invalid encrypted value')
|
raise DecryptionFailureException('Invalid encrypted value')
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod, abstractproperty
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from six import add_metaclass
|
from six import add_metaclass
|
||||||
|
|
||||||
|
@ -92,6 +92,10 @@ class MigrationTester(object):
|
||||||
"""
|
"""
|
||||||
TestDataType = DataTypes
|
TestDataType = DataTypes
|
||||||
|
|
||||||
|
@abstractproperty
|
||||||
|
def is_testing(self):
|
||||||
|
""" Returns whether we are currently under a migration test. """
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def populate_table(self, table_name, fields):
|
def populate_table(self, table_name, fields):
|
||||||
""" Called to populate a table with the given fields filled in with testing data. """
|
""" Called to populate a table with the given fields filled in with testing data. """
|
||||||
|
@ -107,6 +111,10 @@ class NoopTester(MigrationTester):
|
||||||
|
|
||||||
|
|
||||||
class PopulateTestDataTester(MigrationTester):
|
class PopulateTestDataTester(MigrationTester):
|
||||||
|
@property
|
||||||
|
def is_testing(self):
|
||||||
|
return True
|
||||||
|
|
||||||
def populate_table(self, table_name, fields):
|
def populate_table(self, table_name, fields):
|
||||||
columns = {field_name: field_type() for field_name, field_type in fields}
|
columns = {field_name: field_type() for field_name, field_type in fields}
|
||||||
field_name_vars = [':' + field_name for field_name, _ in fields]
|
field_name_vars = [':' + field_name for field_name, _ in fields]
|
||||||
|
|
|
@ -80,11 +80,13 @@ def upgrade(tables, tester, progress_reporter):
|
||||||
|
|
||||||
op.add_column('repomirrorconfig', sa.Column('external_reference', sa.Text(), nullable=True))
|
op.add_column('repomirrorconfig', sa.Column('external_reference', sa.Text(), nullable=True))
|
||||||
|
|
||||||
for repo_mirror in _iterate(RepoMirrorConfig, (RepoMirrorConfig.external_reference >> None)):
|
from app import app
|
||||||
repo = '%s/%s/%s' % (repo_mirror.external_registry, repo_mirror.external_namespace, repo_mirror.external_repository)
|
if app.config.get('SETUP_COMPLETE', False) or tester.is_testing:
|
||||||
logger.info('migrating %s' % repo)
|
for repo_mirror in _iterate(RepoMirrorConfig, (RepoMirrorConfig.external_reference >> None)):
|
||||||
repo_mirror.external_reference = repo
|
repo = '%s/%s/%s' % (repo_mirror.external_registry, repo_mirror.external_namespace, repo_mirror.external_repository)
|
||||||
repo_mirror.save()
|
logger.info('migrating %s' % repo)
|
||||||
|
repo_mirror.external_reference = repo
|
||||||
|
repo_mirror.save()
|
||||||
|
|
||||||
op.drop_column('repomirrorconfig', 'external_registry')
|
op.drop_column('repomirrorconfig', 'external_registry')
|
||||||
op.drop_column('repomirrorconfig', 'external_namespace')
|
op.drop_column('repomirrorconfig', 'external_namespace')
|
||||||
|
@ -109,14 +111,16 @@ def downgrade(tables, tester, progress_reporter):
|
||||||
op.add_column('repomirrorconfig', sa.Column('external_namespace', sa.String(length=255), nullable=True))
|
op.add_column('repomirrorconfig', sa.Column('external_namespace', sa.String(length=255), nullable=True))
|
||||||
op.add_column('repomirrorconfig', sa.Column('external_repository', sa.String(length=255), nullable=True))
|
op.add_column('repomirrorconfig', sa.Column('external_repository', sa.String(length=255), nullable=True))
|
||||||
|
|
||||||
logger.info('Restoring columns from external_reference')
|
from app import app
|
||||||
for repo_mirror in _iterate(RepoMirrorConfig, (RepoMirrorConfig.external_registry >> None)):
|
if app.config.get('SETUP_COMPLETE', False):
|
||||||
logger.info('Restoring %s' % repo_mirror.external_reference)
|
logger.info('Restoring columns from external_reference')
|
||||||
parts = repo_mirror.external_reference.split('/', 2)
|
for repo_mirror in _iterate(RepoMirrorConfig, (RepoMirrorConfig.external_registry >> None)):
|
||||||
repo_mirror.external_registry = parts[0] if len(parts) >= 1 else 'DOWNGRADE-FAILED'
|
logger.info('Restoring %s' % repo_mirror.external_reference)
|
||||||
repo_mirror.external_namespace = parts[1] if len(parts) >= 2 else 'DOWNGRADE-FAILED'
|
parts = repo_mirror.external_reference.split('/', 2)
|
||||||
repo_mirror.external_repository = parts[2] if len(parts) >= 3 else 'DOWNGRADE-FAILED'
|
repo_mirror.external_registry = parts[0] if len(parts) >= 1 else 'DOWNGRADE-FAILED'
|
||||||
repo_mirror.save()
|
repo_mirror.external_namespace = parts[1] if len(parts) >= 2 else 'DOWNGRADE-FAILED'
|
||||||
|
repo_mirror.external_repository = parts[2] if len(parts) >= 3 else 'DOWNGRADE-FAILED'
|
||||||
|
repo_mirror.save()
|
||||||
|
|
||||||
op.drop_column('repomirrorconfig', 'external_reference')
|
op.drop_column('repomirrorconfig', 'external_reference')
|
||||||
|
|
||||||
|
|
|
@ -98,157 +98,159 @@ class OAuthApplication(BaseModel):
|
||||||
def upgrade(tables, tester, progress_reporter):
|
def upgrade(tables, tester, progress_reporter):
|
||||||
op = ProgressWrapper(original_op, progress_reporter)
|
op = ProgressWrapper(original_op, progress_reporter)
|
||||||
|
|
||||||
# Empty all access token names to fix the bug where we put the wrong name and code
|
from app import app
|
||||||
# in for some tokens.
|
if app.config.get('SETUP_COMPLETE', False) or tester.is_testing:
|
||||||
AccessToken.update(token_name=None).where(AccessToken.token_name >> None).execute()
|
# Empty all access token names to fix the bug where we put the wrong name and code
|
||||||
|
# in for some tokens.
|
||||||
|
AccessToken.update(token_name=None).where(AccessToken.token_name >> None).execute()
|
||||||
|
|
||||||
# AccessToken.
|
# AccessToken.
|
||||||
logger.info('Backfilling encrypted credentials for access tokens')
|
logger.info('Backfilling encrypted credentials for access tokens')
|
||||||
for access_token in _iterate(AccessToken, ((AccessToken.token_name >> None) |
|
for access_token in _iterate(AccessToken, ((AccessToken.token_name >> None) |
|
||||||
(AccessToken.token_name == ''))):
|
(AccessToken.token_name == ''))):
|
||||||
logger.info('Backfilling encrypted credentials for access token %s', access_token.id)
|
logger.info('Backfilling encrypted credentials for access token %s', access_token.id)
|
||||||
assert access_token.code is not None
|
assert access_token.code is not None
|
||||||
assert access_token.code[:ACCESS_TOKEN_NAME_PREFIX_LENGTH]
|
assert access_token.code[:ACCESS_TOKEN_NAME_PREFIX_LENGTH]
|
||||||
assert access_token.code[ACCESS_TOKEN_NAME_PREFIX_LENGTH:]
|
assert access_token.code[ACCESS_TOKEN_NAME_PREFIX_LENGTH:]
|
||||||
|
|
||||||
token_name = access_token.code[:ACCESS_TOKEN_NAME_PREFIX_LENGTH]
|
token_name = access_token.code[:ACCESS_TOKEN_NAME_PREFIX_LENGTH]
|
||||||
token_code = _decrypted(access_token.code[ACCESS_TOKEN_NAME_PREFIX_LENGTH:])
|
token_code = _decrypted(access_token.code[ACCESS_TOKEN_NAME_PREFIX_LENGTH:])
|
||||||
|
|
||||||
(AccessToken
|
(AccessToken
|
||||||
.update(token_name=token_name, token_code=token_code)
|
.update(token_name=token_name, token_code=token_code)
|
||||||
.where(AccessToken.id == access_token.id, AccessToken.code == access_token.code)
|
.where(AccessToken.id == access_token.id, AccessToken.code == access_token.code)
|
||||||
.execute())
|
.execute())
|
||||||
|
|
||||||
assert AccessToken.select().where(AccessToken.token_name >> None).count() == 0
|
assert AccessToken.select().where(AccessToken.token_name >> None).count() == 0
|
||||||
|
|
||||||
# Robots.
|
# Robots.
|
||||||
logger.info('Backfilling encrypted credentials for robots')
|
logger.info('Backfilling encrypted credentials for robots')
|
||||||
while True:
|
while True:
|
||||||
has_row = False
|
has_row = False
|
||||||
query = (User
|
query = (User
|
||||||
.select()
|
.select()
|
||||||
.join(RobotAccountToken, JOIN.LEFT_OUTER)
|
.join(RobotAccountToken, JOIN.LEFT_OUTER)
|
||||||
.where(User.robot == True, RobotAccountToken.id >> None)
|
.where(User.robot == True, RobotAccountToken.id >> None)
|
||||||
.limit(BATCH_SIZE))
|
.limit(BATCH_SIZE))
|
||||||
|
|
||||||
for robot_user in query:
|
for robot_user in query:
|
||||||
logger.info('Backfilling encrypted credentials for robot %s', robot_user.id)
|
logger.info('Backfilling encrypted credentials for robot %s', robot_user.id)
|
||||||
has_row = True
|
has_row = True
|
||||||
try:
|
try:
|
||||||
RobotAccountToken.create(robot_account=robot_user,
|
RobotAccountToken.create(robot_account=robot_user,
|
||||||
token=_decrypted(robot_user.email),
|
token=_decrypted(robot_user.email),
|
||||||
fully_migrated=False)
|
fully_migrated=False)
|
||||||
except IntegrityError:
|
except IntegrityError:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not has_row:
|
||||||
break
|
break
|
||||||
|
|
||||||
if not has_row:
|
# RepositoryBuildTrigger
|
||||||
break
|
logger.info('Backfilling encrypted credentials for repo build triggers')
|
||||||
|
for repo_build_trigger in _iterate(RepositoryBuildTrigger,
|
||||||
|
(RepositoryBuildTrigger.fully_migrated == False)):
|
||||||
|
logger.info('Backfilling encrypted credentials for repo build trigger %s',
|
||||||
|
repo_build_trigger.id)
|
||||||
|
|
||||||
# RepositoryBuildTrigger
|
(RepositoryBuildTrigger
|
||||||
logger.info('Backfilling encrypted credentials for repo build triggers')
|
.update(secure_auth_token=_decrypted(repo_build_trigger.auth_token),
|
||||||
for repo_build_trigger in _iterate(RepositoryBuildTrigger,
|
secure_private_key=_decrypted(repo_build_trigger.private_key),
|
||||||
(RepositoryBuildTrigger.fully_migrated == False)):
|
fully_migrated=True)
|
||||||
logger.info('Backfilling encrypted credentials for repo build trigger %s',
|
.where(RepositoryBuildTrigger.id == repo_build_trigger.id,
|
||||||
repo_build_trigger.id)
|
RepositoryBuildTrigger.uuid == repo_build_trigger.uuid)
|
||||||
|
.execute())
|
||||||
|
|
||||||
(RepositoryBuildTrigger
|
assert (RepositoryBuildTrigger
|
||||||
.update(secure_auth_token=_decrypted(repo_build_trigger.auth_token),
|
.select()
|
||||||
secure_private_key=_decrypted(repo_build_trigger.private_key),
|
.where(RepositoryBuildTrigger.fully_migrated == False)
|
||||||
fully_migrated=True)
|
.count()) == 0
|
||||||
.where(RepositoryBuildTrigger.id == repo_build_trigger.id,
|
|
||||||
RepositoryBuildTrigger.uuid == repo_build_trigger.uuid)
|
|
||||||
.execute())
|
|
||||||
|
|
||||||
assert (RepositoryBuildTrigger
|
# AppSpecificAuthToken
|
||||||
.select()
|
logger.info('Backfilling encrypted credentials for app specific auth tokens')
|
||||||
.where(RepositoryBuildTrigger.fully_migrated == False)
|
for token in _iterate(AppSpecificAuthToken, ((AppSpecificAuthToken.token_name >> None) |
|
||||||
.count()) == 0
|
(AppSpecificAuthToken.token_name == '') |
|
||||||
|
(AppSpecificAuthToken.token_secret >> None))):
|
||||||
|
logger.info('Backfilling encrypted credentials for app specific auth %s',
|
||||||
|
token.id)
|
||||||
|
assert token.token_code[AST_TOKEN_NAME_PREFIX_LENGTH:]
|
||||||
|
|
||||||
# AppSpecificAuthToken
|
token_name = token.token_code[:AST_TOKEN_NAME_PREFIX_LENGTH]
|
||||||
logger.info('Backfilling encrypted credentials for app specific auth tokens')
|
token_secret = _decrypted(token.token_code[AST_TOKEN_NAME_PREFIX_LENGTH:])
|
||||||
for token in _iterate(AppSpecificAuthToken, ((AppSpecificAuthToken.token_name >> None) |
|
assert token_name
|
||||||
(AppSpecificAuthToken.token_name == '') |
|
assert token_secret
|
||||||
(AppSpecificAuthToken.token_secret >> None))):
|
|
||||||
logger.info('Backfilling encrypted credentials for app specific auth %s',
|
|
||||||
token.id)
|
|
||||||
assert token.token_code[AST_TOKEN_NAME_PREFIX_LENGTH:]
|
|
||||||
|
|
||||||
token_name = token.token_code[:AST_TOKEN_NAME_PREFIX_LENGTH]
|
(AppSpecificAuthToken
|
||||||
token_secret = _decrypted(token.token_code[AST_TOKEN_NAME_PREFIX_LENGTH:])
|
.update(token_name=token_name,
|
||||||
assert token_name
|
token_secret=token_secret)
|
||||||
assert token_secret
|
.where(AppSpecificAuthToken.id == token.id,
|
||||||
|
AppSpecificAuthToken.token_code == token.token_code)
|
||||||
|
.execute())
|
||||||
|
|
||||||
(AppSpecificAuthToken
|
assert (AppSpecificAuthToken
|
||||||
.update(token_name=token_name,
|
.select()
|
||||||
token_secret=token_secret)
|
.where(AppSpecificAuthToken.token_name >> None)
|
||||||
.where(AppSpecificAuthToken.id == token.id,
|
.count()) == 0
|
||||||
AppSpecificAuthToken.token_code == token.token_code)
|
|
||||||
.execute())
|
|
||||||
|
|
||||||
assert (AppSpecificAuthToken
|
# OAuthAccessToken
|
||||||
.select()
|
logger.info('Backfilling credentials for OAuth access tokens')
|
||||||
.where(AppSpecificAuthToken.token_name >> None)
|
for token in _iterate(OAuthAccessToken, ((OAuthAccessToken.token_name >> None) |
|
||||||
.count()) == 0
|
(OAuthAccessToken.token_name == ''))):
|
||||||
|
logger.info('Backfilling credentials for OAuth access token %s', token.id)
|
||||||
|
token_name = token.access_token[:OAUTH_ACCESS_TOKEN_PREFIX_LENGTH]
|
||||||
|
token_code = Credential.from_string(token.access_token[OAUTH_ACCESS_TOKEN_PREFIX_LENGTH:])
|
||||||
|
assert token_name
|
||||||
|
assert token.access_token[OAUTH_ACCESS_TOKEN_PREFIX_LENGTH:]
|
||||||
|
|
||||||
# OAuthAccessToken
|
(OAuthAccessToken
|
||||||
logger.info('Backfilling credentials for OAuth access tokens')
|
.update(token_name=token_name,
|
||||||
for token in _iterate(OAuthAccessToken, ((OAuthAccessToken.token_name >> None) |
|
token_code=token_code)
|
||||||
(OAuthAccessToken.token_name == ''))):
|
.where(OAuthAccessToken.id == token.id,
|
||||||
logger.info('Backfilling credentials for OAuth access token %s', token.id)
|
OAuthAccessToken.access_token == token.access_token)
|
||||||
token_name = token.access_token[:OAUTH_ACCESS_TOKEN_PREFIX_LENGTH]
|
.execute())
|
||||||
token_code = Credential.from_string(token.access_token[OAUTH_ACCESS_TOKEN_PREFIX_LENGTH:])
|
|
||||||
assert token_name
|
|
||||||
assert token.access_token[OAUTH_ACCESS_TOKEN_PREFIX_LENGTH:]
|
|
||||||
|
|
||||||
(OAuthAccessToken
|
assert (OAuthAccessToken
|
||||||
.update(token_name=token_name,
|
.select()
|
||||||
token_code=token_code)
|
.where(OAuthAccessToken.token_name >> None)
|
||||||
.where(OAuthAccessToken.id == token.id,
|
.count()) == 0
|
||||||
OAuthAccessToken.access_token == token.access_token)
|
|
||||||
.execute())
|
|
||||||
|
|
||||||
assert (OAuthAccessToken
|
# OAuthAuthorizationCode
|
||||||
.select()
|
logger.info('Backfilling credentials for OAuth auth code')
|
||||||
.where(OAuthAccessToken.token_name >> None)
|
for code in _iterate(OAuthAuthorizationCode, ((OAuthAuthorizationCode.code_name >> None) |
|
||||||
.count()) == 0
|
(OAuthAuthorizationCode.code_name == ''))):
|
||||||
|
logger.info('Backfilling credentials for OAuth auth code %s', code.id)
|
||||||
|
user_code = code.code or random_string_generator(AUTHORIZATION_CODE_PREFIX_LENGTH * 2)()
|
||||||
|
code_name = user_code[:AUTHORIZATION_CODE_PREFIX_LENGTH]
|
||||||
|
code_credential = Credential.from_string(user_code[AUTHORIZATION_CODE_PREFIX_LENGTH:])
|
||||||
|
assert code_name
|
||||||
|
assert user_code[AUTHORIZATION_CODE_PREFIX_LENGTH:]
|
||||||
|
|
||||||
# OAuthAuthorizationCode
|
(OAuthAuthorizationCode
|
||||||
logger.info('Backfilling credentials for OAuth auth code')
|
.update(code_name=code_name, code_credential=code_credential)
|
||||||
for code in _iterate(OAuthAuthorizationCode, ((OAuthAuthorizationCode.code_name >> None) |
|
.where(OAuthAuthorizationCode.id == code.id)
|
||||||
(OAuthAuthorizationCode.code_name == ''))):
|
.execute())
|
||||||
logger.info('Backfilling credentials for OAuth auth code %s', code.id)
|
|
||||||
user_code = code.code or random_string_generator(AUTHORIZATION_CODE_PREFIX_LENGTH * 2)()
|
|
||||||
code_name = user_code[:AUTHORIZATION_CODE_PREFIX_LENGTH]
|
|
||||||
code_credential = Credential.from_string(user_code[AUTHORIZATION_CODE_PREFIX_LENGTH:])
|
|
||||||
assert code_name
|
|
||||||
assert user_code[AUTHORIZATION_CODE_PREFIX_LENGTH:]
|
|
||||||
|
|
||||||
(OAuthAuthorizationCode
|
assert (OAuthAuthorizationCode
|
||||||
.update(code_name=code_name, code_credential=code_credential)
|
.select()
|
||||||
.where(OAuthAuthorizationCode.id == code.id)
|
.where(OAuthAuthorizationCode.code_name >> None)
|
||||||
.execute())
|
.count()) == 0
|
||||||
|
|
||||||
assert (OAuthAuthorizationCode
|
# OAuthApplication
|
||||||
.select()
|
logger.info('Backfilling secret for OAuth applications')
|
||||||
.where(OAuthAuthorizationCode.code_name >> None)
|
for app in _iterate(OAuthApplication, OAuthApplication.fully_migrated == False):
|
||||||
.count()) == 0
|
logger.info('Backfilling secret for OAuth application %s', app.id)
|
||||||
|
client_secret = app.client_secret or str(uuid.uuid4())
|
||||||
|
secure_client_secret = _decrypted(client_secret)
|
||||||
|
|
||||||
# OAuthApplication
|
(OAuthApplication
|
||||||
logger.info('Backfilling secret for OAuth applications')
|
.update(secure_client_secret=secure_client_secret, fully_migrated=True)
|
||||||
for app in _iterate(OAuthApplication, OAuthApplication.fully_migrated == False):
|
.where(OAuthApplication.id == app.id, OAuthApplication.fully_migrated == False)
|
||||||
logger.info('Backfilling secret for OAuth application %s', app.id)
|
.execute())
|
||||||
client_secret = app.client_secret or str(uuid.uuid4())
|
|
||||||
secure_client_secret = _decrypted(client_secret)
|
|
||||||
|
|
||||||
(OAuthApplication
|
assert (OAuthApplication
|
||||||
.update(secure_client_secret=secure_client_secret, fully_migrated=True)
|
.select()
|
||||||
.where(OAuthApplication.id == app.id, OAuthApplication.fully_migrated == False)
|
.where(OAuthApplication.fully_migrated == False)
|
||||||
.execute())
|
.count()) == 0
|
||||||
|
|
||||||
assert (OAuthApplication
|
|
||||||
.select()
|
|
||||||
.where(OAuthApplication.fully_migrated == False)
|
|
||||||
.count()) == 0
|
|
||||||
|
|
||||||
# Adjust existing fields to be nullable.
|
# Adjust existing fields to be nullable.
|
||||||
op.alter_column('accesstoken', 'code', nullable=True, existing_type=sa.String(length=255))
|
op.alter_column('accesstoken', 'code', nullable=True, existing_type=sa.String(length=255))
|
||||||
|
@ -271,10 +273,6 @@ def upgrade(tables, tester, progress_reporter):
|
||||||
|
|
||||||
def downgrade(tables, tester, progress_reporter):
|
def downgrade(tables, tester, progress_reporter):
|
||||||
op = ProgressWrapper(original_op, progress_reporter)
|
op = ProgressWrapper(original_op, progress_reporter)
|
||||||
op.alter_column('accesstoken', 'code', nullable=False, existing_type=sa.String(length=255))
|
|
||||||
op.alter_column('oauthaccesstoken', 'access_token', nullable=False, existing_type=sa.String(length=255))
|
|
||||||
op.alter_column('oauthauthorizationcode', 'code', nullable=False, existing_type=sa.String(length=255))
|
|
||||||
op.alter_column('appspecificauthtoken', 'token_code', nullable=False, existing_type=sa.String(length=255))
|
|
||||||
|
|
||||||
op.alter_column('accesstoken', 'token_name', nullable=True, existing_type=sa.String(length=255))
|
op.alter_column('accesstoken', 'token_name', nullable=True, existing_type=sa.String(length=255))
|
||||||
op.alter_column('accesstoken', 'token_code', nullable=True, existing_type=sa.String(length=255))
|
op.alter_column('accesstoken', 'token_code', nullable=True, existing_type=sa.String(length=255))
|
||||||
|
|
|
@ -39,22 +39,24 @@ def upgrade(tables, tester, progress_reporter):
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
# Overwrite all plaintext robot credentials.
|
# Overwrite all plaintext robot credentials.
|
||||||
while True:
|
from app import app
|
||||||
try:
|
if app.config.get('SETUP_COMPLETE', False) or tester.is_testing:
|
||||||
robot_account_token = RobotAccountToken.get(fully_migrated=False)
|
while True:
|
||||||
robot_account = robot_account_token.robot_account
|
try:
|
||||||
|
robot_account_token = RobotAccountToken.get(fully_migrated=False)
|
||||||
|
robot_account = robot_account_token.robot_account
|
||||||
|
|
||||||
robot_account.email = str(uuid.uuid4())
|
robot_account.email = str(uuid.uuid4())
|
||||||
robot_account.save()
|
robot_account.save()
|
||||||
|
|
||||||
federated_login = FederatedLogin.get(user=robot_account)
|
federated_login = FederatedLogin.get(user=robot_account)
|
||||||
federated_login.service_ident = 'robot:%s' % robot_account.id
|
federated_login.service_ident = 'robot:%s' % robot_account.id
|
||||||
federated_login.save()
|
federated_login.save()
|
||||||
|
|
||||||
robot_account_token.fully_migrated = True
|
robot_account_token.fully_migrated = True
|
||||||
robot_account_token.save()
|
robot_account_token.save()
|
||||||
except RobotAccountToken.DoesNotExist:
|
except RobotAccountToken.DoesNotExist:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
def downgrade(tables, tester, progress_reporter):
|
def downgrade(tables, tester, progress_reporter):
|
||||||
|
|
|
@ -30,21 +30,35 @@ class NullDataMigration(DataMigration):
|
||||||
|
|
||||||
class DefinedDataMigration(DataMigration):
|
class DefinedDataMigration(DataMigration):
|
||||||
def __init__(self, name, env_var, phases):
|
def __init__(self, name, env_var, phases):
|
||||||
|
assert phases
|
||||||
|
|
||||||
self.name = name
|
self.name = name
|
||||||
self.phases = {phase.name: phase for phase in phases}
|
self.phases = {phase.name: phase for phase in phases}
|
||||||
|
|
||||||
|
# Add a synthetic phase for new installations that skips the entire migration.
|
||||||
|
self.phases['new-installation'] = phases[-1]._replace(name='new-installation',
|
||||||
|
alembic_revision='head')
|
||||||
|
|
||||||
phase_name = os.getenv(env_var)
|
phase_name = os.getenv(env_var)
|
||||||
if phase_name is None:
|
if phase_name is None:
|
||||||
msg = 'Missing env var `%s` for data migration `%s`' % (env_var, self.name)
|
msg = 'Missing env var `%s` for data migration `%s`. %s' % (env_var, self.name,
|
||||||
|
self._error_suffix)
|
||||||
raise Exception(msg)
|
raise Exception(msg)
|
||||||
|
|
||||||
current_phase = self.phases.get(phase_name)
|
current_phase = self.phases.get(phase_name)
|
||||||
if current_phase is None:
|
if current_phase is None:
|
||||||
msg = 'Unknown phase `%s` for data migration `%s`' % (phase_name, self.name)
|
msg = 'Unknown phase `%s` for data migration `%s`. %s' % (phase_name, self.name,
|
||||||
|
self._error_suffix)
|
||||||
raise Exception(msg)
|
raise Exception(msg)
|
||||||
|
|
||||||
self.current_phase = current_phase
|
self.current_phase = current_phase
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _error_suffix(self):
|
||||||
|
message = 'Available values for this migration: %s. ' % (self.phases.keys())
|
||||||
|
message += 'If this is a new installation, please use `new-installation`.'
|
||||||
|
return message
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def alembic_migration_revision(self):
|
def alembic_migration_revision(self):
|
||||||
assert self.current_phase
|
assert self.current_phase
|
||||||
|
|
Reference in a new issue