Refactor JWT auth to not import app locally

This commit is contained in:
Joseph Schorr 2015-06-12 17:58:19 -04:00
parent ee154c37a8
commit 331c300893
4 changed files with 27 additions and 24 deletions

View file

@ -54,18 +54,18 @@ class JWTAuthUsers(object):
""" Delegates authentication to a REST endpoint that returns JWTs. """
PUBLIC_KEY_FILENAME = 'jwt-authn.cert'
def __init__(self, exists_url, verify_url, issuer, public_key_path=None):
from app import OVERRIDE_CONFIG_DIRECTORY
def __init__(self, exists_url, verify_url, issuer, override_config_dir, http_client,
public_key_path=None):
self.verify_url = verify_url
self.exists_url = exists_url
self.issuer = issuer
self.client = http_client
default_key_path = os.path.join(OVERRIDE_CONFIG_DIRECTORY, JWTAuthUsers.PUBLIC_KEY_FILENAME)
default_key_path = os.path.join(override_config_dir, JWTAuthUsers.PUBLIC_KEY_FILENAME)
public_key_path = public_key_path or default_key_path
if not os.path.exists(public_key_path):
error_message = ('JWT Authentication public key file "%s" not found in directory %s' %
(JWTAuthUsers.PUBLIC_KEY_FILENAME, OVERRIDE_CONFIG_DIRECTORY))
(JWTAuthUsers.PUBLIC_KEY_FILENAME, override_config_dir))
raise Exception(error_message)
@ -73,9 +73,7 @@ class JWTAuthUsers(object):
self.public_key = public_key_file.read()
def verify_user(self, username_or_email, password, create_new_user=True):
from app import app
client = app.config['HTTPCLIENT']
result = client.get(self.verify_url, timeout=2, auth=(username_or_email, password))
result = self.client.get(self.verify_url, timeout=2, auth=(username_or_email, password))
if result.status_code != 200:
return (None, result.text or 'Invalid username or password')
@ -112,9 +110,7 @@ class JWTAuthUsers(object):
return _get_federated_user(payload['sub'], payload['email'], 'jwtauthn', create_new_user)
def user_exists(self, username):
from app import app
client = app.config['HTTPCLIENT']
result = client.get(self.exists_url, auth=(username, ''), timeout=2)
result = self.client.get(self.exists_url, auth=(username, ''), timeout=2)
if result.status_code / 500 >= 1:
raise Exception('Internal Error when trying to check if user exists: %s' % result.text)
@ -310,14 +306,17 @@ class LDAPUsers(object):
class UserAuthentication(object):
def __init__(self, app=None):
def __init__(self, app=None, override_config_dir=None):
self.app_secret_key = None
self.app = app
if app is not None:
self.state = self.init_app(app)
self.state = self.init_app(app, override_config_dir)
else:
self.state = None
def init_app(self, app):
def init_app(self, app, override_config_dir):
self.app_secret_key = app.config['SECRET_KEY']
authentication_type = app.config.get('AUTHENTICATION_TYPE', 'Database')
if authentication_type == 'Database':
@ -336,24 +335,24 @@ class UserAuthentication(object):
verify_url = app.config.get('JWT_VERIFY_ENDPOINT')
exists_url = app.config.get('JWT_EXISTS_ENDPOINT')
issuer = app.config.get('JWT_AUTH_ISSUER')
users = JWTAuthUsers(exists_url, verify_url, issuer)
users = JWTAuthUsers(exists_url, verify_url, issuer, override_config_dir,
app.config['HTTPCLIENT'])
else:
raise RuntimeError('Unknown authentication type: %s' % authentication_type)
# register extension with app
app.extensions = getattr(app, 'extensions', {})
app.extensions['authentication'] = users
return users
def _get_secret_key(self):
""" Returns the secret key to use for encrypting and decrypting. """
from app import app
app_secret_key = app.config['SECRET_KEY']
secret_key = None
# First try parsing the key as an int.
try:
big_int = int(app_secret_key)
big_int = int(self.app_secret_key)
secret_key = str(bytearray.fromhex('{:02x}'.format(big_int)))
except ValueError:
pass
@ -361,12 +360,12 @@ class UserAuthentication(object):
# Next try parsing it as an UUID.
if secret_key is None:
try:
secret_key = uuid.UUID(app_secret_key).bytes
secret_key = uuid.UUID(self.app_secret_key).bytes
except ValueError:
pass
if secret_key is None:
secret_key = str(bytearray(map(ord, app_secret_key)))
secret_key = str(bytearray(map(ord, self.app_secret_key)))
# Otherwise, use the bytes directly.
return ''.join(itertools.islice(itertools.cycle(secret_key), 32))