Add tests for user files endpoint and add regex filter

This commit is contained in:
Joseph Schorr 2017-04-21 16:13:18 -04:00
parent 4bb725dce0
commit 5bba018009
2 changed files with 61 additions and 22 deletions

View file

@ -1,15 +1,53 @@
import pytest import pytest
from data.userfiles import DelegateUserfiles from mock import Mock
from io import BytesIO
@pytest.mark.parametrize('path,expected', [ from data.userfiles import DelegateUserfiles, Userfiles
('foo', 'test/foo'), from test.fixtures import app, appconfig, database_uri, init_db_path, sqlitedb_file
('bar', 'test/bar'),
('/bar', 'test/bar'),
('../foo', 'test/foo'), @pytest.mark.parametrize('prefix,path,expected', [
('foo/bar/baz', 'test/baz'), ('test', 'foo', 'test/foo'),
('foo/../baz', 'test/baz'), ('test', 'bar', 'test/bar'),
('test', '/bar', 'test/bar'),
('test', '../foo', 'test/foo'),
('test', 'foo/bar/baz', 'test/baz'),
('test', 'foo/../baz', 'test/baz'),
(None, 'foo', 'foo'),
(None, 'foo/bar/baz', 'baz'),
]) ])
def test_filepath(path, expected): def test_filepath(prefix, path, expected):
userfiles = DelegateUserfiles(None, None, 'local_us', 'test') userfiles = DelegateUserfiles(None, None, 'local_us', prefix)
assert userfiles.get_file_id_path(path) == expected assert userfiles.get_file_id_path(path) == expected
def test_lookup_userfile(app, client):
uuid = 'deadbeef-dead-beef-dead-beefdeadbeef'
bad_uuid = 'deadduck-dead-duck-dead-duckdeadduck'
upper_uuid = 'DEADBEEF-DEAD-BEEF-DEAD-BEEFDEADBEEF'
def _stream_read_file(locations, path):
if path.find(uuid) > 0 or path.find(upper_uuid) > 0:
return BytesIO("hello world")
raise IOError('Not found!')
storage_mock = Mock()
storage_mock.stream_read_file = _stream_read_file
app.config['USERFILES_PATH'] = 'foo'
Userfiles(app, distributed_storage=storage_mock)
rv = client.open('/userfiles/' + uuid, method='GET')
assert rv.status_code == 200
rv = client.open('/userfiles/' + upper_uuid, method='GET')
assert rv.status_code == 200
rv = client.open('/userfiles/' + bad_uuid, method='GET')
assert rv.status_code == 404
rv = client.open('/userfiles/foo/bar/baz', method='GET')
assert rv.status_code == 404

View file

@ -75,7 +75,7 @@ class DelegateUserfiles(object):
def get_file_id_path(self, file_id): def get_file_id_path(self, file_id):
# Note: We use basename here to prevent paths with ..'s and absolute paths. # Note: We use basename here to prevent paths with ..'s and absolute paths.
return os.path.join(self._prefix, os.path.basename(file_id)) return os.path.join(self._prefix or '', os.path.basename(file_id))
def prepare_for_drop(self, mime_type, requires_cors=True): def prepare_for_drop(self, mime_type, requires_cors=True):
""" Returns a signed URL to upload a file to our bucket. """ """ Returns a signed URL to upload a file to our bucket. """
@ -137,12 +137,12 @@ class Userfiles(object):
location = app.config.get('USERFILES_LOCATION') location = app.config.get('USERFILES_LOCATION')
path = app.config.get('USERFILES_PATH', None) path = app.config.get('USERFILES_PATH', None)
if path is not None:
handler_name = 'userfiles_handlers' handler_name = 'userfiles_handlers'
userfiles = DelegateUserfiles(app, distributed_storage, location, path, userfiles = DelegateUserfiles(app, distributed_storage, location, path,
handler_name=handler_name) handler_name=handler_name)
app.add_url_rule('/userfiles/<file_id>', app.add_url_rule('/userfiles/<regex("[0-9a-zA-Z-]+"):file_id>',
view_func=UserfilesHandlers.as_view(handler_name, view_func=UserfilesHandlers.as_view(handler_name,
distributed_storage=distributed_storage, distributed_storage=distributed_storage,
location=location, location=location,
@ -151,6 +151,7 @@ class Userfiles(object):
# register extension with app # register extension with app
app.extensions = getattr(app, 'extensions', {}) app.extensions = getattr(app, 'extensions', {})
app.extensions['userfiles'] = userfiles app.extensions['userfiles'] = userfiles
return userfiles return userfiles
def __getattr__(self, name): def __getattr__(self, name):